Skip to content
54 changes: 24 additions & 30 deletions codex-rs/codex-api/src/endpoint/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::provider::Provider;
use crate::telemetry::run_with_request_telemetry;
use codex_client::HttpTransport;
use codex_client::RequestTelemetry;
use codex_protocol::openai_models::ModelInfo;
use codex_protocol::openai_models::ModelsResponse;
use http::HeaderMap;
use http::Method;
Expand Down Expand Up @@ -41,7 +42,7 @@ impl<T: HttpTransport, A: AuthProvider> ModelsClient<T, A> {
&self,
client_version: &str,
extra_headers: HeaderMap,
) -> Result<ModelsResponse, ApiError> {
) -> Result<(Vec<ModelInfo>, Option<String>), ApiError> {
let builder = || {
let mut req = self.provider.build_request(Method::GET, self.path());
req.headers.extend(extra_headers.clone());
Expand All @@ -66,17 +67,15 @@ impl<T: HttpTransport, A: AuthProvider> ModelsClient<T, A> {
.and_then(|value| value.to_str().ok())
.map(ToString::to_string);

let ModelsResponse { models, etag } = serde_json::from_slice::<ModelsResponse>(&resp.body)
let ModelsResponse { models } = serde_json::from_slice::<ModelsResponse>(&resp.body)
.map_err(|e| {
ApiError::Stream(format!(
"failed to decode models response: {e}; body: {}",
String::from_utf8_lossy(&resp.body)
))
})?;

let etag = header_etag.unwrap_or(etag);

Ok(ModelsResponse { models, etag })
Ok((models, header_etag))
}
}

Expand All @@ -102,16 +101,15 @@ mod tests {
struct CapturingTransport {
last_request: Arc<Mutex<Option<Request>>>,
body: Arc<ModelsResponse>,
response_etag: Arc<Option<String>>,
}

impl Default for CapturingTransport {
fn default() -> Self {
Self {
last_request: Arc::new(Mutex::new(None)),
body: Arc::new(ModelsResponse {
models: Vec::new(),
etag: String::new(),
}),
body: Arc::new(ModelsResponse { models: Vec::new() }),
response_etag: Arc::new(None),
}
}
}
Expand All @@ -122,8 +120,8 @@ mod tests {
*self.last_request.lock().unwrap() = Some(req);
let body = serde_json::to_vec(&*self.body).unwrap();
let mut headers = HeaderMap::new();
if !self.body.etag.is_empty() {
headers.insert(ETAG, self.body.etag.parse().unwrap());
if let Some(etag) = self.response_etag.as_ref().as_deref() {
headers.insert(ETAG, etag.parse().unwrap());
}
Ok(Response {
status: StatusCode::OK,
Expand Down Expand Up @@ -166,14 +164,12 @@ mod tests {

#[tokio::test]
async fn appends_client_version_query() {
let response = ModelsResponse {
models: Vec::new(),
etag: String::new(),
};
let response = ModelsResponse { models: Vec::new() };

let transport = CapturingTransport {
last_request: Arc::new(Mutex::new(None)),
body: Arc::new(response),
response_etag: Arc::new(None),
};

let client = ModelsClient::new(
Expand All @@ -182,12 +178,12 @@ mod tests {
DummyAuth,
);

let result = client
let (models, _etag) = client
.list_models("0.99.0", HeaderMap::new())
.await
.expect("request should succeed");

assert_eq!(result.models.len(), 0);
assert_eq!(models.len(), 0);

let url = transport
.last_request
Expand Down Expand Up @@ -232,12 +228,12 @@ mod tests {
}))
.unwrap(),
],
etag: String::new(),
};

let transport = CapturingTransport {
last_request: Arc::new(Mutex::new(None)),
body: Arc::new(response),
response_etag: Arc::new(None),
};

let client = ModelsClient::new(
Expand All @@ -246,27 +242,25 @@ mod tests {
DummyAuth,
);

let result = client
let (models, _etag) = client
.list_models("0.99.0", HeaderMap::new())
.await
.expect("request should succeed");

assert_eq!(result.models.len(), 1);
assert_eq!(result.models[0].slug, "gpt-test");
assert_eq!(result.models[0].supported_in_api, true);
assert_eq!(result.models[0].priority, 1);
assert_eq!(models.len(), 1);
assert_eq!(models[0].slug, "gpt-test");
assert_eq!(models[0].supported_in_api, true);
assert_eq!(models[0].priority, 1);
}

#[tokio::test]
async fn list_models_includes_etag() {
let response = ModelsResponse {
models: Vec::new(),
etag: "\"abc\"".to_string(),
};
let response = ModelsResponse { models: Vec::new() };

let transport = CapturingTransport {
last_request: Arc::new(Mutex::new(None)),
body: Arc::new(response),
response_etag: Arc::new(Some("\"abc\"".to_string())),
};

let client = ModelsClient::new(
Expand All @@ -275,12 +269,12 @@ mod tests {
DummyAuth,
);

let result = client
let (models, etag) = client
.list_models("0.1.0", HeaderMap::new())
.await
.expect("request should succeed");

assert_eq!(result.models.len(), 0);
assert_eq!(result.etag, "\"abc\"");
assert_eq!(models.len(), 0);
assert_eq!(etag.as_deref(), Some("\"abc\""));
}
}
7 changes: 3 additions & 4 deletions codex-rs/codex-api/tests/models_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ async fn models_client_hits_models_endpoint() {
reasoning_summary_format: ReasoningSummaryFormat::None,
experimental_supported_tools: Vec::new(),
}],
etag: String::new(),
};

Mock::given(method("GET"))
Expand All @@ -106,13 +105,13 @@ async fn models_client_hits_models_endpoint() {
let transport = ReqwestTransport::new(reqwest::Client::new());
let client = ModelsClient::new(transport, provider(&base_url), DummyAuth);

let result = client
let (models, _etag) = client
.list_models("0.1.0", HeaderMap::new())
.await
.expect("models request should succeed");

assert_eq!(result.models.len(), 1);
assert_eq!(result.models[0].slug, "gpt-test");
assert_eq!(models.len(), 1);
assert_eq!(models[0].slug, "gpt-test");

let received = server
.received_requests()
Expand Down
5 changes: 5 additions & 0 deletions codex-rs/core/src/api_bridge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ pub(crate) fn map_api_error(err: ApiError) -> CodexErr {
status,
request_id: extract_request_id(headers.as_ref()),
})
} else if status == http::StatusCode::PRECONDITION_FAILED

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use a status code, don't match on strings

&& body_text
.contains("Models catalog has changed. Please refresh your models list.")
{
CodexErr::OutdatedModels
} else {
CodexErr::UnexpectedStatus(UnexpectedResponseError {
status,
Expand Down
47 changes: 41 additions & 6 deletions codex-rs/core/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::sync::Arc;
use std::sync::RwLock;

use crate::api_bridge::auth_provider_from_auth;
use crate::api_bridge::map_api_error;
Expand Down Expand Up @@ -53,11 +54,12 @@ use crate::openai_models::model_family::ModelFamily;
use crate::tools::spec::create_tools_json_for_chat_completions_api;
use crate::tools::spec::create_tools_json_for_responses_api;

#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct ModelClient {
config: Arc<Config>,
auth_manager: Option<Arc<AuthManager>>,
model_family: ModelFamily,
model_family: RwLock<ModelFamily>,
models_etag: RwLock<Option<String>>,
Comment on lines +61 to +62

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two values are connected and should be only used/updated together. Let's not spread them across two fields.

otel_manager: OtelManager,
provider: ModelProviderInfo,
conversation_id: ConversationId,
Expand All @@ -72,6 +74,7 @@ impl ModelClient {
config: Arc<Config>,
auth_manager: Option<Arc<AuthManager>>,
model_family: ModelFamily,
models_etag: Option<String>,
otel_manager: OtelManager,
provider: ModelProviderInfo,
effort: Option<ReasoningEffortConfig>,
Expand All @@ -82,7 +85,8 @@ impl ModelClient {
Self {
config,
auth_manager,
model_family,
model_family: RwLock::new(model_family),
models_etag: RwLock::new(models_etag),
otel_manager,
provider,
conversation_id,
Expand All @@ -108,6 +112,22 @@ impl ModelClient {
&self.provider
}

pub fn update_models_etag(&self, models_etag: Option<String>) {
let mut guard = self
.models_etag
.write()
.unwrap_or_else(std::sync::PoisonError::into_inner);
*guard = models_etag;
}

pub fn update_model_family(&self, model_family: ModelFamily) {
let mut guard = self
.model_family
.write()
.unwrap_or_else(std::sync::PoisonError::into_inner);
*guard = model_family;
}

/// Streams a single model turn using either the Responses or Chat
/// Completions wire API, depending on the configured provider.
///
Expand Down Expand Up @@ -262,7 +282,7 @@ impl ModelClient {
store_override: None,
conversation_id: Some(conversation_id.clone()),
session_source: Some(session_source.clone()),
extra_headers: beta_feature_headers(&self.config),
extra_headers: beta_feature_headers(&self.config, self.get_models_etag()),
};

let stream_result = client
Expand Down Expand Up @@ -303,7 +323,17 @@ impl ModelClient {

/// Returns the currently configured model family.
pub fn get_model_family(&self) -> ModelFamily {
self.model_family.clone()
self.model_family
.read()
.map(|model_family| model_family.clone())
.unwrap_or_else(|err| err.into_inner().clone())
}

fn get_models_etag(&self) -> Option<String> {
self.models_etag
.read()
.map(|models_etag| models_etag.clone())
.unwrap_or_else(|err| err.into_inner().clone())
}

/// Returns the current reasoning effort setting.
Expand Down Expand Up @@ -398,7 +428,7 @@ fn build_api_prompt(prompt: &Prompt, instructions: String, tools_json: Vec<Value
}
}

fn beta_feature_headers(config: &Config) -> ApiHeaderMap {
fn beta_feature_headers(config: &Config, models_etag: Option<String>) -> ApiHeaderMap {
let enabled = FEATURES
.iter()
.filter_map(|spec| {
Expand All @@ -416,6 +446,11 @@ fn beta_feature_headers(config: &Config) -> ApiHeaderMap {
{
headers.insert("x-codex-beta-features", header_value);
}
if let Some(etag) = models_etag
&& let Ok(header_value) = HeaderValue::from_str(&etag)
{
headers.insert("X-If-Models-Match", header_value);
Comment thread
aibrahim-oai marked this conversation as resolved.
}
headers
}

Expand Down
Loading
Loading