diff --git a/codex-rs/codex-api/src/endpoint/models.rs b/codex-rs/codex-api/src/endpoint/models.rs
index b15f07fca2a..7f21c776352 100644
--- a/codex-rs/codex-api/src/endpoint/models.rs
+++ b/codex-rs/codex-api/src/endpoint/models.rs
@@ -5,6 +5,7 @@ use crate::provider::Provider;
 use crate::telemetry::run_with_request_telemetry;
 use codex_client::HttpTransport;
 use codex_client::RequestTelemetry;
+use codex_protocol::openai_models::ModelInfo;
 use codex_protocol::openai_models::ModelsResponse;
 use http::HeaderMap;
 use http::Method;
@@ -41,7 +42,7 @@ impl ModelsClient {
         &self,
         client_version: &str,
         extra_headers: HeaderMap,
-    ) -> Result<ModelsResponse, ApiError> {
+    ) -> Result<(Vec<ModelInfo>, Option<String>), ApiError> {
         let builder = || {
             let mut req = self.provider.build_request(Method::GET, self.path());
             req.headers.extend(extra_headers.clone());
@@ -66,7 +67,7 @@ impl ModelsClient {
             .and_then(|value| value.to_str().ok())
             .map(ToString::to_string);
 
-        let ModelsResponse { models, etag } = serde_json::from_slice::<ModelsResponse>(&resp.body)
+        let ModelsResponse { models } = serde_json::from_slice::<ModelsResponse>(&resp.body)
             .map_err(|e| {
                 ApiError::Stream(format!(
                     "failed to decode models response: {e}; body: {}",
@@ -74,9 +75,7 @@ impl ModelsClient {
                 ))
             })?;
 
-        let etag = header_etag.unwrap_or(etag);
-
-        Ok(ModelsResponse { models, etag })
+        Ok((models, header_etag))
     }
 }
 
@@ -102,16 +101,15 @@ mod tests {
     struct CapturingTransport {
         last_request: Arc<Mutex<Option<Request>>>,
         body: Arc<ModelsResponse>,
+        response_etag: Arc<Option<String>>,
     }
 
     impl Default for CapturingTransport {
         fn default() -> Self {
             Self {
                 last_request: Arc::new(Mutex::new(None)),
-                body: Arc::new(ModelsResponse {
-                    models: Vec::new(),
-                    etag: String::new(),
-                }),
+                body: Arc::new(ModelsResponse { models: Vec::new() }),
+                response_etag: Arc::new(None),
             }
         }
     }
@@ -122,8 +120,8 @@ mod tests {
             *self.last_request.lock().unwrap() = Some(req);
             let body = serde_json::to_vec(&*self.body).unwrap();
             let mut headers = HeaderMap::new();
-            if !self.body.etag.is_empty() {
-                headers.insert(ETAG, self.body.etag.parse().unwrap());
+            if let Some(etag) = self.response_etag.as_ref().as_deref() {
+                headers.insert(ETAG, etag.parse().unwrap());
             }
             Ok(Response {
                 status: StatusCode::OK,
@@ -166,14 +164,12 @@ mod tests {
 
     #[tokio::test]
     async fn appends_client_version_query() {
-        let response = ModelsResponse {
-            models: Vec::new(),
-            etag: String::new(),
-        };
+        let response = ModelsResponse { models: Vec::new() };
 
         let transport = CapturingTransport {
             last_request: Arc::new(Mutex::new(None)),
             body: Arc::new(response),
+            response_etag: Arc::new(None),
         };
 
         let client = ModelsClient::new(
@@ -182,12 +178,12 @@ mod tests {
             DummyAuth,
         );
 
-        let result = client
+        let (models, _etag) = client
            .list_models("0.99.0", HeaderMap::new())
            .await
            .expect("request should succeed");
 
-        assert_eq!(result.models.len(), 0);
+        assert_eq!(models.len(), 0);
 
         let url = transport
             .last_request
@@ -232,12 +228,12 @@ mod tests {
                 }))
                 .unwrap(),
             ],
-            etag: String::new(),
         };
 
         let transport = CapturingTransport {
             last_request: Arc::new(Mutex::new(None)),
             body: Arc::new(response),
+            response_etag: Arc::new(None),
         };
 
         let client = ModelsClient::new(
@@ -246,27 +242,25 @@ mod tests {
             DummyAuth,
         );
 
-        let result = client
+        let (models, _etag) = client
            .list_models("0.99.0", HeaderMap::new())
            .await
            .expect("request should succeed");
 
-        assert_eq!(result.models.len(), 1);
-        assert_eq!(result.models[0].slug, "gpt-test");
-        assert_eq!(result.models[0].supported_in_api, true);
-        assert_eq!(result.models[0].priority, 1);
+        assert_eq!(models.len(), 1);
+        assert_eq!(models[0].slug, "gpt-test");
+        assert_eq!(models[0].supported_in_api, true);
+        assert_eq!(models[0].priority, 1);
     }
 
     #[tokio::test]
     async fn list_models_includes_etag() {
-        let response = ModelsResponse {
-            models: Vec::new(),
-            etag: "\"abc\"".to_string(),
-        };
+        let response = ModelsResponse { models: Vec::new() };
 
         let transport = CapturingTransport {
             last_request: Arc::new(Mutex::new(None)),
             body: Arc::new(response),
+            response_etag: Arc::new(Some("\"abc\"".to_string())),
         };
 
         let client = ModelsClient::new(
@@ -275,12 +269,12 @@ mod tests {
             DummyAuth,
         );
 
-        let result = client
+        let (models, etag) = client
            .list_models("0.1.0", HeaderMap::new())
            .await
            .expect("request should succeed");
 
-        assert_eq!(result.models.len(), 0);
-        assert_eq!(result.etag, "\"abc\"");
+        assert_eq!(models.len(), 0);
+        assert_eq!(etag.as_deref(), Some("\"abc\""));
     }
 }
diff --git a/codex-rs/codex-api/tests/models_integration.rs b/codex-rs/codex-api/tests/models_integration.rs
index 93baffd3560..0b3e95ee303 100644
--- a/codex-rs/codex-api/tests/models_integration.rs
+++ b/codex-rs/codex-api/tests/models_integration.rs
@@ -90,7 +90,6 @@ async fn models_client_hits_models_endpoint() {
             reasoning_summary_format: ReasoningSummaryFormat::None,
             experimental_supported_tools: Vec::new(),
         }],
-        etag: String::new(),
     };
 
     Mock::given(method("GET"))
@@ -106,13 +105,13 @@ async fn models_client_hits_models_endpoint() {
     let transport = ReqwestTransport::new(reqwest::Client::new());
     let client = ModelsClient::new(transport, provider(&base_url), DummyAuth);
 
-    let result = client
+    let (models, _etag) = client
        .list_models("0.1.0", HeaderMap::new())
        .await
        .expect("models request should succeed");
 
-    assert_eq!(result.models.len(), 1);
-    assert_eq!(result.models[0].slug, "gpt-test");
+    assert_eq!(models.len(), 1);
+    assert_eq!(models[0].slug, "gpt-test");
 
     let received = server
         .received_requests()
diff --git a/codex-rs/core/src/api_bridge.rs b/codex-rs/core/src/api_bridge.rs
index 79fd67d6501..c7da55da853 100644
--- a/codex-rs/core/src/api_bridge.rs
+++ b/codex-rs/core/src/api_bridge.rs
@@ -67,6 +67,11 @@ pub(crate) fn map_api_error(err: ApiError) -> CodexErr {
                 status,
                 request_id: extract_request_id(headers.as_ref()),
             })
+        } else if status == http::StatusCode::PRECONDITION_FAILED
+            && body_text
+                .contains("Models catalog has changed. Please refresh your models list.")
+        {
+            CodexErr::OutdatedModels
         } else {
             CodexErr::UnexpectedStatus(UnexpectedResponseError {
                 status,
diff --git a/codex-rs/core/src/client.rs b/codex-rs/core/src/client.rs
index aaf3b0ea353..3426da78c17 100644
--- a/codex-rs/core/src/client.rs
+++ b/codex-rs/core/src/client.rs
@@ -33,6 +33,7 @@ use http::StatusCode as HttpStatusCode;
 use reqwest::StatusCode;
 use serde_json::Value;
 use std::time::Duration;
+use tokio::sync::RwLock;
 use tokio::sync::mpsc;
 use tracing::warn;
 
@@ -53,11 +54,12 @@ use crate::openai_models::model_family::ModelFamily;
 use crate::tools::spec::create_tools_json_for_chat_completions_api;
 use crate::tools::spec::create_tools_json_for_responses_api;
 
-#[derive(Debug, Clone)]
+#[derive(Debug)]
 pub struct ModelClient {
     config: Arc<Config>,
     auth_manager: Option<Arc<AuthManager>>,
-    model_family: ModelFamily,
+    model_family: RwLock<ModelFamily>,
+    models_etag: RwLock<Option<String>>,
     otel_manager: OtelManager,
     provider: ModelProviderInfo,
     conversation_id: ConversationId,
@@ -72,6 +74,7 @@ impl ModelClient {
         config: Arc<Config>,
         auth_manager: Option<Arc<AuthManager>>,
         model_family: ModelFamily,
+        models_etag: Option<String>,
         otel_manager: OtelManager,
         provider: ModelProviderInfo,
         effort: Option<ReasoningEffortConfig>,
@@ -82,7 +85,8 @@ impl ModelClient {
         Self {
             config,
             auth_manager,
-            model_family,
+            model_family: RwLock::new(model_family),
+            models_etag: RwLock::new(models_etag),
             otel_manager,
             provider,
             conversation_id,
@@ -92,8 +96,8 @@ impl ModelClient {
         }
     }
 
-    pub fn get_model_context_window(&self) -> Option<u64> {
-        let model_family = self.get_model_family();
+    pub async fn get_model_context_window(&self) -> Option<u64> {
+        let model_family = self.get_model_family().await;
         let effective_context_window_percent = model_family.effective_context_window_percent;
         model_family
             .context_window
@@ -146,7 +150,7 @@ impl ModelClient {
         }
 
         let auth_manager = self.auth_manager.clone();
-        let model_family = self.get_model_family();
+        let model_family = self.get_model_family().await;
         let instructions = prompt.get_full_instructions(&model_family).into_owned();
         let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?;
         let api_prompt = build_api_prompt(prompt, instructions, tools_json);
@@ -167,7 +171,7 @@ impl ModelClient {
 
         let stream_result = client
             .stream_prompt(
-                &self.get_model(),
+                &self.get_model().await,
                 &api_prompt,
                 Some(conversation_id.clone()),
                 Some(session_source.clone()),
@@ -200,7 +204,7 @@ impl ModelClient {
         }
 
         let auth_manager = self.auth_manager.clone();
-        let model_family = self.get_model_family();
+        let model_family = self.get_model_family().await;
         let instructions = prompt.get_full_instructions(&model_family).into_owned();
         let tools_json: Vec<Value> = create_tools_json_for_responses_api(&prompt.tools)?;
 
@@ -262,11 +266,14 @@ impl ModelClient {
             store_override: None,
             conversation_id: Some(conversation_id.clone()),
             session_source: Some(session_source.clone()),
-            extra_headers: beta_feature_headers(&self.config),
+            extra_headers: beta_feature_headers(
+                &self.config,
+                self.get_models_etag().await.clone(),
+            ),
         };
 
         let stream_result = client
-            .stream_prompt(&self.get_model(), &api_prompt, options)
+            .stream_prompt(&self.get_model().await, &api_prompt, options)
             .await;
 
         match stream_result {
@@ -297,13 +304,25 @@ impl ModelClient {
     }
 
     /// Returns the currently configured model slug.
-    pub fn get_model(&self) -> String {
-        self.get_model_family().get_model_slug().to_string()
+    pub async fn get_model(&self) -> String {
+        self.get_model_family().await.get_model_slug().to_string()
     }
 
     /// Returns the currently configured model family.
-    pub fn get_model_family(&self) -> ModelFamily {
-        self.model_family.clone()
+    pub async fn get_model_family(&self) -> ModelFamily {
+        self.model_family.read().await.clone()
+    }
+
+    pub async fn get_models_etag(&self) -> Option<String> {
+        self.models_etag.read().await.clone()
+    }
+
+    pub async fn update_models_etag(&self, etag: Option<String>) {
+        *self.models_etag.write().await = etag;
+    }
+
+    pub async fn update_model_family(&self, model_family: ModelFamily) {
+        *self.model_family.write().await = model_family;
     }
 
     /// Returns the current reasoning effort setting.
@@ -340,10 +359,10 @@ impl ModelClient {
             .with_telemetry(Some(request_telemetry));
 
         let instructions = prompt
-            .get_full_instructions(&self.get_model_family())
+            .get_full_instructions(&self.get_model_family().await)
             .into_owned();
         let payload = ApiCompactionInput {
-            model: &self.get_model(),
+            model: &self.get_model().await,
             input: &prompt.input,
             instructions: &instructions,
         };
@@ -398,7 +417,7 @@ fn build_api_prompt(prompt: &Prompt, instructions: String, tools_json: Vec<Value
     }
 }
 
-fn beta_feature_headers(config: &Config) -> ApiHeaderMap {
+fn beta_feature_headers(config: &Config, models_etag: Option<String>) -> ApiHeaderMap {
     let enabled = FEATURES
         .iter()
         .filter_map(|spec| {
@@ -416,6 +435,11 @@ fn beta_feature_headers(config: &Config) -> ApiHeaderMap {
     {
         headers.insert("x-codex-beta-features", header_value);
     }
+    if let Some(etag) = models_etag
+        && let Ok(header_value) = HeaderValue::from_str(&etag)
+    {
+        headers.insert("X-If-Models-Match", header_value);
+    }
     headers
 }
diff --git a/codex-rs/core/src/client_common.rs b/codex-rs/core/src/client_common.rs
index 4a3bc8de235..db1a7333d18 100644
--- a/codex-rs/core/src/client_common.rs
+++ b/codex-rs/core/src/client_common.rs
@@ -1,6 +1,10 @@
 use crate::client_common::tools::ToolSpec;
+use crate::codex::Session;
+use crate::codex::TurnContext;
 use crate::error::Result;
+use crate::features::Feature;
 use crate::openai_models::model_family::ModelFamily;
+use crate::tools::ToolRouter;
 pub use codex_api::common::ResponseEvent;
 use codex_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS;
 use codex_protocol::models::ResponseItem;
@@ -44,6 +48,28 @@ pub struct Prompt {
 }
 
 impl Prompt {
+    pub(crate) async fn new(
+        sess: &Session,
+        turn_context: &TurnContext,
+        router: &ToolRouter,
+        input: &[ResponseItem],
+    ) -> Prompt {
+        let model_supports_parallel = turn_context
+            .client
+            .get_model_family()
+            .await
+            .supports_parallel_tool_calls;
+
+        Prompt {
+            input: input.to_vec(),
+            tools: router.specs(),
+            parallel_tool_calls: model_supports_parallel
+                && sess.enabled(Feature::ParallelToolCalls),
+            base_instructions_override: turn_context.base_instructions.clone(),
+            output_schema: turn_context.final_output_json_schema.clone(),
+        }
+    }
+
     pub(crate) fn get_full_instructions<'a>(&'a self, model: &'a ModelFamily) -> Cow<'a, str> {
         let base = self
             .base_instructions_override
diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs
index f0d2056587c..0c6b5739894 100644
--- a/codex-rs/core/src/codex.rs
+++ b/codex-rs/core/src/codex.rs
@@ -249,7 +249,7 @@ impl Codex {
         let config = Arc::new(config);
 
         if config.features.enabled(Feature::RemoteModels)
-            && let Err(err) = models_manager.refresh_available_models(&config).await
+            && let Err(err) = models_manager.try_refresh_available_models(&config).await
         {
             error!("failed to refresh available models: {err:?}");
         }
@@ -492,6 +492,7 @@ impl Session {
         session_configuration: &SessionConfiguration,
         per_turn_config: Config,
         model_family: ModelFamily,
+        models_etag: Option<String>,
         conversation_id: ConversationId,
         sub_id: String,
    ) -> TurnContext {
@@ -505,6 +506,7 @@ impl Session {
             per_turn_config.clone(),
             auth_manager,
             model_family.clone(),
+            models_etag,
             otel_manager,
             provider,
             session_configuration.model_reasoning_effort,
@@ -788,7 +790,7 @@ impl Session {
             }
         })
         {
-            let curr = turn_context.client.get_model();
+            let curr = turn_context.client.get_model().await;
             if prev != curr {
                 warn!(
                     "resuming session with different model: previous={prev}, current={curr}"
@@ -919,6 +921,7 @@ impl Session {
             .models_manager
             .construct_model_family(session_configuration.model.as_str(), &per_turn_config)
             .await;
+        let models_etag = self.services.models_manager.get_models_etag().await;
         let mut turn_context: TurnContext = Self::make_turn_context(
             Some(Arc::clone(&self.services.auth_manager)),
             &self.services.otel_manager,
@@ -926,6 +929,7 @@ impl Session {
             &session_configuration,
             per_turn_config,
             model_family,
+            models_etag,
             self.conversation_id,
             sub_id,
         );
@@ -1334,7 +1338,7 @@ impl Session {
             if let Some(token_usage) = token_usage {
                 state.update_token_info_from_usage(
                     token_usage,
-                    turn_context.client.get_model_context_window(),
+                    turn_context.client.get_model_context_window().await,
                 );
             }
         }
@@ -1346,6 +1350,7 @@ impl Session {
             .clone_history()
             .await
             .estimate_token_count(turn_context)
+            .await
         else {
             return;
         };
@@ -1366,7 +1371,7 @@ impl Session {
         };
 
         if info.model_context_window.is_none() {
-            info.model_context_window = turn_context.client.get_model_context_window();
+            info.model_context_window = turn_context.client.get_model_context_window().await;
         }
 
         state.set_token_info(Some(info));
@@ -1396,7 +1401,7 @@ impl Session {
     }
 
     pub(crate) async fn set_total_tokens_full(&self, turn_context: &TurnContext) {
-        let context_window = turn_context.client.get_model_context_window();
+        let context_window = turn_context.client.get_model_context_window().await;
         if let Some(context_window) = context_window {
             {
                 let mut state = self.state.lock().await;
@@ -2105,6 +2110,7 @@ async fn spawn_review_thread(
         .models_manager
         .construct_model_family(&model, &config)
         .await;
+    let models_etag = sess.services.models_manager.get_models_etag().await;
 
     // For reviews, disable web_search and view_image regardless of global settings.
     let mut review_features = sess.features.clone();
     review_features
@@ -2137,6 +2143,7 @@ async fn spawn_review_thread(
         per_turn_config.clone(),
         auth_manager,
         model_family.clone(),
+        models_etag,
         otel_manager,
         provider,
         per_turn_config.model_reasoning_effort,
@@ -2231,6 +2238,7 @@ pub(crate) async fn run_task(
     let auto_compact_limit = turn_context
         .client
         .get_model_family()
+        .await
         .auto_compact_token_limit()
         .unwrap_or(i64::MAX);
     let total_usage_tokens = sess.get_total_token_usage().await;
@@ -2238,7 +2246,7 @@ pub(crate) async fn run_task(
         run_auto_compact(&sess, &turn_context).await;
     }
     let event = EventMsg::TaskStarted(TaskStartedEvent {
-        model_context_window: turn_context.client.get_model_context_window(),
+        model_context_window: turn_context.client.get_model_context_window().await,
     });
     sess.send_event(&turn_context, event).await;
 
@@ -2303,7 +2311,7 @@ pub(crate) async fn run_task(
             .collect::<Vec<ResponseItem>>();
         match run_turn(
             Arc::clone(&sess),
-            Arc::clone(&turn_context),
+            &turn_context,
             Arc::clone(&turn_diff_tracker),
             turn_input,
             cancellation_token.child_token(),
@@ -2362,6 +2370,36 @@ pub(crate) async fn run_task(
     last_agent_message
 }
 
+pub(crate) async fn refresh_models_and_reset_turn_context(
+    sess: &Arc<Session>,
+    turn_context: &Arc<TurnContext>,
+) {
+    let config = {
+        let state = sess.state.lock().await;
+        state
+            .session_configuration
+            .original_config_do_not_use
+            .clone()
+    };
+    if let Err(err) = sess
+        .services
+        .models_manager
+        .refresh_available_models(&config)
+        .await
+    {
+        error!("failed to refresh models after outdated models error: {err}");
+    }
+    let model = turn_context.client.get_model().await;
+    let model_family = sess
+        .services
+        .models_manager
+        .construct_model_family(&model, &config)
+        .await;
+    let models_etag = sess.services.models_manager.get_models_etag().await;
+    turn_context.client.update_model_family(model_family).await;
+    turn_context.client.update_models_etag(models_etag).await;
+}
+
 async fn run_auto_compact(sess: &Arc<Session>, turn_context: &Arc<TurnContext>) {
     if should_use_remote_compact_task(sess.as_ref(), &turn_context.client.get_provider()) {
         run_inline_remote_auto_compact_task(Arc::clone(sess), Arc::clone(turn_context)).await;
@@ -2374,17 +2412,19 @@ async fn run_auto_compact(sess: &Arc<Session>, turn_context: &Arc<TurnContext>)
     skip_all,
     fields(
         turn_id = %turn_context.sub_id,
-        model = %turn_context.client.get_model(),
+        model = tracing::field::Empty,
         cwd = %turn_context.cwd.display()
     )
 )]
 async fn run_turn(
     sess: Arc<Session>,
-    turn_context: Arc<TurnContext>,
+    turn_context: &Arc<TurnContext>,
     turn_diff_tracker: SharedTurnDiffTracker,
     input: Vec<ResponseItem>,
     cancellation_token: CancellationToken,
 ) -> CodexResult<TurnRunResult> {
+    let model = turn_context.client.get_model().await;
+    tracing::Span::current().record("model", field::display(&model));
     let mcp_tools = sess
         .services
         .mcp_connection_manager
@@ -2393,37 +2433,32 @@ async fn run_turn(
         .list_all_tools()
         .or_cancel(&cancellation_token)
         .await?;
-    let router = Arc::new(ToolRouter::from_config(
-        &turn_context.tools_config,
-        Some(
-            mcp_tools
-                .into_iter()
-                .map(|(name, tool)| (name, tool.tool))
-                .collect(),
-        ),
-    ));
-
-    let model_supports_parallel = turn_context
-        .client
-        .get_model_family()
-        .supports_parallel_tool_calls;
-
-    let prompt = Prompt {
-        input,
-        tools: router.specs(),
-        parallel_tool_calls: model_supports_parallel && sess.enabled(Feature::ParallelToolCalls),
-        base_instructions_override: turn_context.base_instructions.clone(),
-        output_schema: turn_context.final_output_json_schema.clone(),
-    };
 
     let mut retries = 0;
     loop {
+        let router = Arc::new(ToolRouter::from_config(
+            &turn_context.tools_config,
+            Some(
+                mcp_tools
+                    .clone()
+                    .into_iter()
+                    .map(|(name, tool)| (name, tool.tool))
+                    .collect(),
+            ),
+        ));
+        let prompt = Prompt::new(
+            sess.as_ref(),
+            turn_context.as_ref(),
+            router.as_ref(),
+            &input,
+        );
+
         match try_run_turn(
             Arc::clone(&router),
             Arc::clone(&sess),
-            Arc::clone(&turn_context),
+            Arc::clone(turn_context),
             Arc::clone(&turn_diff_tracker),
-            &prompt,
+            &prompt.await,
             cancellation_token.child_token(),
         )
         .await
@@ -2437,13 +2472,13 @@ async fn run_turn(
             Err(CodexErr::EnvVar(var)) => return Err(CodexErr::EnvVar(var)),
             Err(e @ CodexErr::Fatal(_)) => return Err(e),
             Err(e @ CodexErr::ContextWindowExceeded) => {
-                sess.set_total_tokens_full(&turn_context).await;
+                sess.set_total_tokens_full(turn_context).await;
                 return Err(e);
             }
             Err(CodexErr::UsageLimitReached(e)) => {
                 let rate_limits = e.rate_limits.clone();
                 if let Some(rate_limits) = rate_limits {
-                    sess.update_rate_limits(&turn_context, rate_limits).await;
+                    sess.update_rate_limits(turn_context, rate_limits).await;
                 }
                 return Err(CodexErr::UsageLimitReached(e));
             }
@@ -2457,6 +2492,11 @@ async fn run_turn(
                 let max_retries = turn_context.client.get_provider().stream_max_retries();
                 if retries < max_retries {
                     retries += 1;
+                    // Refresh models if we got an outdated models error
+                    if matches!(e, CodexErr::OutdatedModels) {
+                        refresh_models_and_reset_turn_context(&sess, turn_context).await;
+                        continue;
+                    }
                     let delay = match e {
                         CodexErr::Stream(_, Some(delay)) => delay,
                         _ => backoff(retries),
@@ -2469,7 +2509,7 @@ async fn run_turn(
                     // user understands what is happening instead of staring
                     // at a seemingly frozen screen.
                     sess.notify_stream_error(
-                        &turn_context,
+                        turn_context,
                         format!("Reconnecting... {retries}/{max_retries}"),
                         e,
                     )
@@ -2514,7 +2554,7 @@ async fn drain_in_flight(
     skip_all,
     fields(
         turn_id = %turn_context.sub_id,
-        model = %turn_context.client.get_model()
+        model = tracing::field::Empty,
     )
 )]
 async fn try_run_turn(
@@ -2525,11 +2565,13 @@ async fn try_run_turn(
     prompt: &Prompt,
     cancellation_token: CancellationToken,
 ) -> CodexResult<TurnRunResult> {
+    let model = turn_context.client.get_model().await;
+    tracing::Span::current().record("model", field::display(&model));
     let rollout_item = RolloutItem::TurnContext(TurnContextItem {
         cwd: turn_context.cwd.clone(),
         approval_policy: turn_context.approval_policy,
         sandbox_policy: turn_context.sandbox_policy.clone(),
-        model: turn_context.client.get_model(),
+        model,
         effort: turn_context.client.get_reasoning_effort(),
         summary: turn_context.client.get_reasoning_summary(),
     });
@@ -2537,7 +2579,6 @@ async fn try_run_turn(
     sess.persist_rollout_items(&[rollout_item]).await;
     let mut stream = turn_context
         .client
-        .clone()
         .stream(prompt)
         .instrument(trace_span!("stream_request"))
         .or_cancel(&cancellation_token)
@@ -3163,6 +3204,7 @@ mod tests {
             &session_configuration,
             per_turn_config,
             model_family,
+            None,
             conversation_id,
             "turn_id".to_string(),
         );
@@ -3249,6 +3291,7 @@ mod tests {
             &session_configuration,
             per_turn_config,
             model_family,
+            None,
             conversation_id,
             "turn_id".to_string(),
         ));
diff --git a/codex-rs/core/src/compact.rs b/codex-rs/core/src/compact.rs
index 1a90b7b223f..a864774d69c 100644
--- a/codex-rs/core/src/compact.rs
+++ b/codex-rs/core/src/compact.rs
@@ -6,6 +6,7 @@ use crate::client_common::ResponseEvent;
 use crate::codex::Session;
 use crate::codex::TurnContext;
 use crate::codex::get_last_assistant_message_from_turn;
+use crate::codex::refresh_models_and_reset_turn_context;
 use crate::error::CodexErr;
 use crate::error::Result as CodexResult;
 use crate::features::Feature;
@@ -55,7 +56,7 @@ pub(crate) async fn run_compact_task(
     input: Vec<UserInput>,
 ) {
     let start_event = EventMsg::TaskStarted(TaskStartedEvent {
-        model_context_window: turn_context.client.get_model_context_window(),
+        model_context_window: turn_context.client.get_model_context_window().await,
     });
     sess.send_event(&turn_context, start_event).await;
     run_compact_task_inner(sess.clone(), turn_context, input).await;
@@ -83,7 +84,7 @@ async fn run_compact_task_inner(
         cwd: turn_context.cwd.clone(),
         approval_policy: turn_context.approval_policy,
         sandbox_policy: turn_context.sandbox_policy.clone(),
-        model: turn_context.client.get_model(),
+        model: turn_context.client.get_model().await,
         effort: turn_context.client.get_reasoning_effort(),
         summary: turn_context.client.get_reasoning_summary(),
     });
@@ -132,6 +133,10 @@ async fn run_compact_task_inner(
             Err(e) => {
                 if retries < max_retries {
                     retries += 1;
+                    if matches!(e, CodexErr::OutdatedModels) {
+                        refresh_models_and_reset_turn_context(&sess, &turn_context).await;
+                        continue;
+                    }
                     let delay = backoff(retries);
                     sess.notify_stream_error(
                         turn_context.as_ref(),
@@ -290,7 +295,7 @@ async fn drain_to_completed(
     turn_context: &TurnContext,
     prompt: &Prompt,
 ) -> CodexResult<()> {
-    let mut stream = turn_context.client.clone().stream(prompt).await?;
+    let mut stream = turn_context.client.stream(prompt).await?;
     loop {
         let maybe_event = stream.next().await;
         let Some(event) = maybe_event else {
diff --git a/codex-rs/core/src/compact_remote.rs b/codex-rs/core/src/compact_remote.rs
index b855f28d39d..3419b2e51fd 100644
--- a/codex-rs/core/src/compact_remote.rs
+++ b/codex-rs/core/src/compact_remote.rs
@@ -20,7 +20,7 @@ pub(crate) async fn run_inline_remote_auto_compact_task(
 
 pub(crate) async fn run_remote_compact_task(sess: Arc<Session>, turn_context: Arc<TurnContext>) {
     let start_event = EventMsg::TaskStarted(TaskStartedEvent {
-        model_context_window: turn_context.client.get_model_context_window(),
+        model_context_window: turn_context.client.get_model_context_window().await,
     });
 
     sess.send_event(&turn_context, start_event).await;
diff --git a/codex-rs/core/src/context_manager/history.rs b/codex-rs/core/src/context_manager/history.rs
index c18ad7df8ec..841be4bae58 100644
--- a/codex-rs/core/src/context_manager/history.rs
+++ b/codex-rs/core/src/context_manager/history.rs
@@ -79,8 +79,8 @@ impl ContextManager {
 
     // Estimate token usage using byte-based heuristics from the truncation helpers.
     // This is a coarse lower bound, not a tokenizer-accurate count.
-    pub(crate) fn estimate_token_count(&self, turn_context: &TurnContext) -> Option<i64> {
-        let model_family = turn_context.client.get_model_family();
+    pub(crate) async fn estimate_token_count(&self, turn_context: &TurnContext) -> Option<i64> {
+        let model_family = turn_context.client.get_model_family().await;
         let base_tokens =
             i64::try_from(approx_token_count(model_family.base_instructions.as_str()))
                 .unwrap_or(i64::MAX);
diff --git a/codex-rs/core/src/error.rs b/codex-rs/core/src/error.rs
index e8fa91d26e8..e85eaf0dbfe 100644
--- a/codex-rs/core/src/error.rs
+++ b/codex-rs/core/src/error.rs
@@ -90,6 +90,10 @@ pub enum CodexErr {
     #[error("spawn failed: child stdout/stderr not captured")]
     Spawn,
 
+    /// Returned when the models list is outdated and needs to be refreshed.
+    #[error("remote models list is outdated")]
+    OutdatedModels,
+
     /// Returned by run_command_stream when the user pressed Ctrl‑C (SIGINT). Session uses this to
     /// surface a polite FunctionCallOutput back to the model instead of crashing the CLI.
#[error("interrupted (Ctrl-C). Something went wrong? Hit `/feedback` to report the issue.")] diff --git a/codex-rs/core/src/openai_models/models_manager.rs b/codex-rs/core/src/openai_models/models_manager.rs index 7f54c4f8525..999ac20164d 100644 --- a/codex-rs/core/src/openai_models/models_manager.rs +++ b/codex-rs/core/src/openai_models/models_manager.rs @@ -77,7 +77,7 @@ impl ModelsManager { } /// Fetch the latest remote models, using the on-disk cache when still fresh. - pub async fn refresh_available_models(&self, config: &Config) -> CoreResult<()> { + pub async fn try_refresh_available_models(&self, config: &Config) -> CoreResult<()> { if !config.features.enabled(Feature::RemoteModels) || self.auth_manager.get_auth_mode() == Some(AuthMode::ApiKey) { @@ -86,7 +86,15 @@ impl ModelsManager { if self.try_load_cache().await { return Ok(()); } + self.refresh_available_models(config).await + } + pub async fn refresh_available_models(&self, config: &Config) -> CoreResult<()> { + if !config.features.enabled(Feature::RemoteModels) + || self.auth_manager.get_auth_mode() == Some(AuthMode::ApiKey) + { + return Ok(()); + } let auth = self.auth_manager.auth(); let api_provider = self.provider.to_api_provider(Some(AuthMode::ChatGPT))?; let api_auth = auth_provider_from_auth(auth.clone(), &self.provider).await?; @@ -94,12 +102,12 @@ impl ModelsManager { let client = ModelsClient::new(transport, api_provider, api_auth); let client_version = format_client_version_to_whole(); - let ModelsResponse { models, etag } = client + let (models, etag) = client .list_models(&client_version, HeaderMap::new()) .await .map_err(map_api_error)?; - let etag = (!etag.is_empty()).then_some(etag); + let etag = etag.filter(|value| !value.is_empty()); self.apply_remote_models(models.clone()).await; *self.etag.write().await = etag.clone(); @@ -108,7 +116,7 @@ impl ModelsManager { } pub async fn list_models(&self, config: &Config) -> Vec { - if let Err(err) = self.refresh_available_models(config).await { + if let Err(err) = self.try_refresh_available_models(config).await { error!("failed to refresh available models: {err}"); } let remote_models = self.remote_models(config).await; @@ -131,11 +139,15 @@ impl ModelsManager { .with_config_overrides(config) } + pub async fn get_models_etag(&self) -> Option { + self.etag.read().await.clone() + } + pub async fn get_model(&self, model: &Option, config: &Config) -> String { if let Some(model) = model.as_ref() { return model.to_string(); } - if let Err(err) = self.refresh_available_models(config).await { + if let Err(err) = self.try_refresh_available_models(config).await { error!("failed to refresh available models: {err}"); } // if codex-auto-balanced exists & signed in with chatgpt mode, return it, otherwise return the default model @@ -389,7 +401,6 @@ mod tests { &server, ModelsResponse { models: remote_models.clone(), - etag: String::new(), }, ) .await; @@ -407,7 +418,7 @@ mod tests { let manager = ModelsManager::with_provider(auth_manager, provider); manager - .refresh_available_models(&config) + .try_refresh_available_models(&config) .await .expect("refresh succeeds"); let cached_remote = manager.remote_models(&config).await; @@ -446,7 +457,6 @@ mod tests { &server, ModelsResponse { models: remote_models.clone(), - etag: String::new(), }, ) .await; @@ -467,7 +477,7 @@ mod tests { let manager = ModelsManager::with_provider(auth_manager, provider); manager - .refresh_available_models(&config) + .try_refresh_available_models(&config) .await .expect("first refresh succeeds"); 
         assert_eq!(
@@ -478,7 +488,7 @@ mod tests {
 
         // Second call should read from cache and avoid the network.
         manager
-            .refresh_available_models(&config)
+            .try_refresh_available_models(&config)
            .await
            .expect("cached refresh succeeds");
         assert_eq!(
@@ -501,7 +511,6 @@ mod tests {
             &server,
             ModelsResponse {
                 models: initial_models.clone(),
-                etag: String::new(),
             },
         )
        .await;
@@ -522,7 +531,7 @@ mod tests {
         let manager = ModelsManager::with_provider(auth_manager, provider);
 
         manager
-            .refresh_available_models(&config)
+            .try_refresh_available_models(&config)
            .await
            .expect("initial refresh succeeds");
 
@@ -542,13 +551,12 @@ mod tests {
             &server,
             ModelsResponse {
                 models: updated_models.clone(),
-                etag: String::new(),
             },
         )
        .await;
 
         manager
-            .refresh_available_models(&config)
+            .try_refresh_available_models(&config)
            .await
            .expect("second refresh succeeds");
         assert_eq!(
@@ -576,7 +584,6 @@ mod tests {
             &server,
             ModelsResponse {
                 models: initial_models,
-                etag: String::new(),
             },
         )
        .await;
@@ -595,7 +602,7 @@ mod tests {
         manager.cache_ttl = Duration::ZERO;
 
         manager
-            .refresh_available_models(&config)
+            .try_refresh_available_models(&config)
            .await
            .expect("initial refresh succeeds");
 
@@ -605,13 +612,12 @@ mod tests {
             &server,
             ModelsResponse {
                 models: refreshed_models,
-                etag: String::new(),
             },
         )
        .await;
 
         manager
-            .refresh_available_models(&config)
+            .try_refresh_available_models(&config)
            .await
            .expect("second refresh succeeds");
diff --git a/codex-rs/core/src/tasks/user_shell.rs b/codex-rs/core/src/tasks/user_shell.rs
index aec09514ca3..b053020bdb9 100644
--- a/codex-rs/core/src/tasks/user_shell.rs
+++ b/codex-rs/core/src/tasks/user_shell.rs
@@ -59,7 +59,7 @@ impl SessionTask for UserShellCommandTask {
         cancellation_token: CancellationToken,
     ) -> Option<String> {
         let event = EventMsg::TaskStarted(TaskStartedEvent {
-            model_context_window: turn_context.client.get_model_context_window(),
+            model_context_window: turn_context.client.get_model_context_window().await,
         });
         let session = session.clone_session();
         session.send_event(turn_context.as_ref(), event).await;
diff --git a/codex-rs/core/tests/chat_completions_payload.rs b/codex-rs/core/tests/chat_completions_payload.rs
index 5867935470e..b0086ef5186 100644
--- a/codex-rs/core/tests/chat_completions_payload.rs
+++ b/codex-rs/core/tests/chat_completions_payload.rs
@@ -92,6 +92,7 @@ async fn run_request(input: Vec<ResponseItem>) -> Value {
         Arc::clone(&config),
         None,
         model_family,
+        None,
         otel_manager,
         provider,
         effort,
diff --git a/codex-rs/core/tests/chat_completions_sse.rs b/codex-rs/core/tests/chat_completions_sse.rs
index f58b039220e..74f7730504c 100644
--- a/codex-rs/core/tests/chat_completions_sse.rs
+++ b/codex-rs/core/tests/chat_completions_sse.rs
@@ -93,6 +93,7 @@ async fn run_stream_with_bytes(sse_body: &[u8]) -> Vec<ResponseEvent> {
         Arc::clone(&config),
         None,
         model_family,
+        None,
         otel_manager,
         provider,
         effort,
diff --git a/codex-rs/core/tests/common/responses.rs b/codex-rs/core/tests/common/responses.rs
index b98b29625eb..76133a9fa39 100644
--- a/codex-rs/core/tests/common/responses.rs
+++ b/codex-rs/core/tests/common/responses.rs
@@ -670,6 +670,24 @@ pub async fn mount_models_once(server: &MockServer, body: ModelsResponse) -> ModelsMock {
     models_mock
 }
 
+pub async fn mount_models_once_with_etag(
+    server: &MockServer,
+    body: ModelsResponse,
+    etag: &str,
+) -> ModelsMock {
+    let (mock, models_mock) = models_mock();
+    mock.respond_with(
+        ResponseTemplate::new(200)
+            .insert_header("content-type", "application/json")
+            .insert_header("etag", etag)
+            .set_body_json(body.clone()),
+    )
+    .up_to_n_times(1)
+    .mount(server)
+    .await;
+    models_mock
+}
+
 pub async fn start_mock_server() -> MockServer {
     let server = MockServer::builder()
         .body_print_limit(BodyPrintLimit::Limited(80_000))
@@ -677,14 +695,7 @@ pub async fn start_mock_server() -> MockServer {
         .await;
 
     // Provide a default `/models` response so tests remain hermetic when the client queries it.
-    let _ = mount_models_once(
-        &server,
-        ModelsResponse {
-            models: Vec::new(),
-            etag: String::new(),
-        },
-    )
-    .await;
+    let _ = mount_models_once(&server, ModelsResponse { models: Vec::new() }).await;
 
     server
 }
diff --git a/codex-rs/core/tests/responses_headers.rs b/codex-rs/core/tests/responses_headers.rs
index 5c32685cc92..f005ec5be9e 100644
--- a/codex-rs/core/tests/responses_headers.rs
+++ b/codex-rs/core/tests/responses_headers.rs
@@ -86,6 +86,7 @@ async fn responses_stream_includes_subagent_header_on_review() {
         Arc::clone(&config),
         None,
         model_family,
+        None,
         otel_manager,
         provider,
         effort,
@@ -181,6 +182,7 @@ async fn responses_stream_includes_subagent_header_on_other() {
         Arc::clone(&config),
         None,
         model_family,
+        None,
         otel_manager,
         provider,
         effort,
@@ -275,6 +277,7 @@ async fn responses_respects_model_family_overrides_from_config() {
         Arc::clone(&config),
         None,
         model_family,
+        None,
         otel_manager,
         provider,
         effort,
diff --git a/codex-rs/core/tests/suite/client.rs b/codex-rs/core/tests/suite/client.rs
index bda232433da..65b1a7f4c8e 100644
--- a/codex-rs/core/tests/suite/client.rs
+++ b/codex-rs/core/tests/suite/client.rs
@@ -1146,6 +1146,7 @@ async fn azure_responses_request_includes_store_and_reasoning_ids() {
         Arc::clone(&config),
         None,
         model_family,
+        None,
         otel_manager,
         provider,
         effort,
diff --git a/codex-rs/core/tests/suite/remote_models.rs b/codex-rs/core/tests/suite/remote_models.rs
index 3c4d389ec05..d54405fe6dc 100644
--- a/codex-rs/core/tests/suite/remote_models.rs
+++ b/codex-rs/core/tests/suite/remote_models.rs
@@ -33,8 +33,12 @@ use core_test_support::responses::ev_assistant_message;
 use core_test_support::responses::ev_completed;
 use core_test_support::responses::ev_function_call;
 use core_test_support::responses::ev_response_created;
+use core_test_support::responses::ev_shell_command_call;
 use core_test_support::responses::mount_models_once;
+use core_test_support::responses::mount_models_once_with_etag;
+use core_test_support::responses::mount_response_once_match;
 use core_test_support::responses::mount_sse_once;
+use core_test_support::responses::mount_sse_once_match;
 use core_test_support::responses::mount_sse_sequence;
 use core_test_support::responses::sse;
 use core_test_support::skip_if_no_network;
@@ -42,6 +46,7 @@ use core_test_support::skip_if_sandbox;
 use core_test_support::wait_for_event;
 use core_test_support::wait_for_event_match;
 use pretty_assertions::assert_eq;
+use serde_json::Value;
 use serde_json::json;
 use tempfile::TempDir;
 use tokio::time::Duration;
@@ -49,9 +54,92 @@ use tokio::time::Instant;
 use tokio::time::sleep;
 use wiremock::BodyPrintLimit;
 use wiremock::MockServer;
+use wiremock::ResponseTemplate;
 
 const REMOTE_MODEL_SLUG: &str = "codex-test";
 
+#[derive(Clone, Default)]
+struct ResponsesMatch {
+    etag: Option<String>,
+    user_text: Option<String>,
+    call_id: Option<String>,
+}
+
+impl ResponsesMatch {
+    fn with_etag(mut self, etag: &str) -> Self {
+        self.etag = Some(etag.to_string());
+        self
+    }
+
+    fn with_user_text(mut self, text: &str) -> Self {
+        self.user_text = Some(text.to_string());
+        self
+    }
+
+    fn with_function_call_output(mut self, call_id: &str) -> Self {
+        self.call_id = Some(call_id.to_string());
+        self
+    }
+}
+
+impl wiremock::Match for ResponsesMatch {
+    fn matches(&self, request: &wiremock::Request) -> bool {
+        if let Some(expected_etag) = &self.etag {
+            let header = request
+                .headers
+                .get("X-If-Models-Match")
+                .and_then(|value| value.to_str().ok());
+            if header != Some(expected_etag.as_str()) {
+                return false;
+            }
+        }
+
+        let Ok(body): Result<Value, _> = request.body_json() else {
+            return false;
+        };
+        let Some(items) = body.get("input").and_then(Value::as_array) else {
+            return false;
+        };
+
+        if let Some(expected_text) = &self.user_text
+            && !input_has_user_text(items, expected_text)
+        {
+            return false;
+        }
+
+        if let Some(expected_call_id) = &self.call_id
+            && !input_has_function_call_output(items, expected_call_id)
+        {
+            return false;
+        }
+
+        true
+    }
+}
+
+fn input_has_user_text(items: &[Value], expected: &str) -> bool {
+    items.iter().any(|item| {
+        item.get("type").and_then(Value::as_str) == Some("message")
+            && item.get("role").and_then(Value::as_str) == Some("user")
+            && item
+                .get("content")
+                .and_then(Value::as_array)
+                .is_some_and(|content| {
+                    content.iter().any(|span| {
+                        span.get("type").and_then(Value::as_str) == Some("input_text")
+                            && span.get("text").and_then(Value::as_str) == Some(expected)
+                    })
+                })
+    })
+}
+
+fn input_has_function_call_output(items: &[Value], call_id: &str) -> bool {
+    items.iter().any(|item| {
+        item.get("type").and_then(Value::as_str) == Some("function_call_output")
+            && item.get("call_id").and_then(Value::as_str) == Some(call_id)
+    })
+}
+
 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
 async fn remote_models_remote_model_uses_unified_exec() -> Result<()> {
     skip_if_no_network!(Ok(()));
@@ -93,7 +181,6 @@ async fn remote_models_remote_model_uses_unified_exec() -> Result<()> {
         &server,
         ModelsResponse {
             models: vec![remote_model],
-            etag: String::new(),
         },
     )
    .await;
@@ -232,7 +319,6 @@ async fn remote_models_apply_remote_base_instructions() -> Result<()> {
         &server,
         ModelsResponse {
             models: vec![remote_model],
-            etag: String::new(),
         },
     )
    .await;
@@ -299,6 +385,208 @@ async fn remote_models_apply_remote_base_instructions() -> Result<()> {
     Ok(())
 }
 
+/// Exercises the remote-models retry flow:
+/// 1) initial `/models` fetch stores an ETag,
+/// 2) `/responses` uses that ETag for a tool call,
+/// 3) the tool-output turn receives a 412 (stale models),
+/// 4) Codex refreshes `/models` to get a new ETag and retries,
+/// 5) subsequent user turns keep sending the refreshed ETag.
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn remote_models_refresh_etag_after_outdated_models() -> Result<()> {
+    skip_if_no_network!(Ok(()));
+    skip_if_sandbox!(Ok(()));
+
+    let server = MockServer::builder()
+        .body_print_limit(BodyPrintLimit::Limited(80_000))
+        .start()
+        .await;
+
+    let remote_model = test_remote_model("remote-etag", ModelVisibility::List, 1);
+    let initial_etag = "models-etag-initial";
+    let refreshed_etag = "models-etag-refreshed";
+
+    // Phase 1a: seed the initial `/models` response with an ETag.
+    let models_mock = mount_models_once_with_etag(
+        &server,
+        ModelsResponse {
+            models: vec![remote_model.clone()],
+        },
+        initial_etag,
+    )
+    .await;
+
+    // Phase 1b: boot a Codex session configured for remote models.
+    let harness = build_remote_models_harness(&server, |config| {
+        config.features.enable(Feature::RemoteModels);
+        config.model = Some("gpt-5.1".to_string());
+    })
+    .await?;
+
+    let RemoteModelsHarness {
+        codex,
+        cwd,
+        config,
+        conversation_manager,
+        ..
+    } = harness;
+
+    let models_manager = conversation_manager.get_models_manager();
+    wait_for_model_available(&models_manager, "remote-etag", &config).await;
+
+    // Phase 1c: confirm the ETag is stored and `/models` was called.
+    assert_eq!(
+        models_manager.get_models_etag().await.as_deref(),
+        Some(initial_etag),
+    );
+    assert_eq!(
+        models_mock.requests().len(),
+        1,
+        "expected an initial /models request",
+    );
+    assert_eq!(models_mock.requests()[0].url.path(), "/v1/models");
+
+    // Phase 2a: reset mocks so the next `/models` call must be explicit.
+    server.reset().await;
+    // Phase 2b: mount a refreshed `/models` response with a new ETag.
+    let refreshed_models_mock = mount_models_once_with_etag(
+        &server,
+        ModelsResponse {
+            models: vec![remote_model],
+        },
+        refreshed_etag,
+    )
+    .await;
+
+    let call_id = "shell-command-call";
+    let first_prompt = "run a shell command";
+    let followup_prompt = "send another message";
+
+    // Phase 2c: first `/responses` turn uses the initial ETag and emits a tool call.
+    let first_response = mount_sse_once_match(
+        &server,
+        ResponsesMatch::default()
+            .with_etag(initial_etag)
+            .with_user_text(first_prompt),
+        sse(vec![
+            ev_response_created("resp-1"),
+            ev_shell_command_call(call_id, "echo refreshed"),
+            ev_completed("resp-1"),
+        ]),
+    )
+    .await;
+
+    // Phase 2d: the tool-output follow-up returns 412 (stale models).
+    let stale_response = mount_response_once_match(
+        &server,
+        ResponsesMatch::default()
+            .with_etag(initial_etag)
+            .with_function_call_output(call_id),
+        ResponseTemplate::new(412)
+            .set_body_string("Models catalog has changed. Please refresh your models list."),
+    )
+    .await;
+
+    // Phase 2e: retry tool-output follow-up should use the refreshed ETag.
+    let refreshed_response = mount_sse_once_match(
+        &server,
+        ResponsesMatch::default()
+            .with_etag(refreshed_etag)
+            .with_function_call_output(call_id),
+        sse(vec![
+            ev_response_created("resp-2"),
+            ev_assistant_message("msg-1", "done"),
+            ev_completed("resp-2"),
+        ]),
+    )
+    .await;
+
+    // Phase 3a: next user turn should also use the refreshed ETag.
+    let next_turn_response = mount_sse_once_match(
+        &server,
+        ResponsesMatch::default()
+            .with_etag(refreshed_etag)
+            .with_user_text(followup_prompt),
+        sse(vec![
+            ev_response_created("resp-3"),
+            ev_assistant_message("msg-2", "ok"),
+            ev_completed("resp-3"),
+        ]),
+    )
+    .await;
+
+    // Phase 3b: run the first user turn and let retries complete.
+    codex
+        .submit(Op::UserTurn {
+            items: vec![UserInput::Text {
+                text: first_prompt.into(),
+            }],
+            final_output_json_schema: None,
+            cwd: cwd.path().to_path_buf(),
+            approval_policy: AskForApproval::Never,
+            sandbox_policy: SandboxPolicy::DangerFullAccess,
+            model: "gpt-5.1".to_string(),
+            effort: None,
+            summary: ReasoningSummary::Auto,
+        })
+        .await?;
+
+    wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;
+
+    // Phase 3c: assert the refresh happened and the ETag was updated.
+    assert_eq!(
+        refreshed_models_mock.requests().len(),
+        1,
+        "expected a refreshed /models request",
+    );
+    assert_eq!(
+        models_manager.get_models_etag().await.as_deref(),
+        Some(refreshed_etag),
+    );
+
+    // Phase 3d: assert the ETag header progression across the retry sequence.
+    assert_eq!(
+        first_response.single_request().header("X-If-Models-Match"),
+        Some(initial_etag.to_string()),
+    );
+    assert_eq!(
+        stale_response.single_request().header("X-If-Models-Match"),
+        Some(initial_etag.to_string()),
+    );
+    assert_eq!(
+        refreshed_response
+            .single_request()
+            .header("X-If-Models-Match"),
+        Some(refreshed_etag.to_string()),
+    );
+
+    // Phase 3e: execute a new user turn and ensure the refreshed ETag persists.
+    codex
+        .submit(Op::UserTurn {
+            items: vec![UserInput::Text {
+                text: followup_prompt.into(),
+            }],
+            final_output_json_schema: None,
+            cwd: cwd.path().to_path_buf(),
+            approval_policy: AskForApproval::Never,
+            sandbox_policy: SandboxPolicy::DangerFullAccess,
+            model: "gpt-5.1".to_string(),
+            effort: None,
+            summary: ReasoningSummary::Auto,
+        })
+        .await?;
+
+    wait_for_event(&codex, |event| matches!(event, EventMsg::TaskComplete(_))).await;
+
+    assert_eq!(
+        next_turn_response
+            .single_request()
+            .header("X-If-Models-Match"),
+        Some(refreshed_etag.to_string()),
+    );
+
+    Ok(())
+}
+
 #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
 async fn remote_models_preserve_builtin_presets() -> Result<()> {
     skip_if_no_network!(Ok(()));
@@ -310,7 +598,6 @@ async fn remote_models_preserve_builtin_presets() -> Result<()> {
         &server,
         ModelsResponse {
             models: vec![remote_model.clone()],
-            etag: String::new(),
         },
     )
    .await;
@@ -330,7 +617,7 @@ async fn remote_models_preserve_builtin_presets() -> Result<()> {
     );
 
     manager
-        .refresh_available_models(&config)
+        .try_refresh_available_models(&config)
        .await
        .expect("refresh succeeds");
 
@@ -368,7 +655,6 @@ async fn remote_models_hide_picker_only_models() -> Result<()> {
         &server,
         ModelsResponse {
             models: vec![remote_model],
-            etag: String::new(),
         },
     )
    .await;
diff --git a/codex-rs/protocol/src/openai_models.rs b/codex-rs/protocol/src/openai_models.rs
index 28b25bb604e..4cdc0e4ba6f 100644
--- a/codex-rs/protocol/src/openai_models.rs
+++ b/codex-rs/protocol/src/openai_models.rs
@@ -197,8 +197,6 @@ pub struct ModelInfo {
 #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, TS, JsonSchema, Default)]
 pub struct ModelsResponse {
     pub models: Vec<ModelInfo>,
-    #[serde(default)]
-    pub etag: String,
 }
 
 // convert ModelInfo to ModelPreset