54 changes: 24 additions & 30 deletions codex-rs/codex-api/src/endpoint/models.rs
@@ -5,6 +5,7 @@ use crate::provider::Provider;
use crate::telemetry::run_with_request_telemetry;
use codex_client::HttpTransport;
use codex_client::RequestTelemetry;
use codex_protocol::openai_models::ModelInfo;
use codex_protocol::openai_models::ModelsResponse;
use http::HeaderMap;
use http::Method;
@@ -41,7 +42,7 @@ impl<T: HttpTransport, A: AuthProvider> ModelsClient<T, A> {
&self,
client_version: &str,
extra_headers: HeaderMap,
) -> Result<ModelsResponse, ApiError> {
) -> Result<(Vec<ModelInfo>, Option<String>), ApiError> {
let builder = || {
let mut req = self.provider.build_request(Method::GET, self.path());
req.headers.extend(extra_headers.clone());
@@ -66,17 +67,15 @@ impl<T: HttpTransport, A: AuthProvider> ModelsClient<T, A> {
.and_then(|value| value.to_str().ok())
.map(ToString::to_string);

let ModelsResponse { models, etag } = serde_json::from_slice::<ModelsResponse>(&resp.body)
let ModelsResponse { models } = serde_json::from_slice::<ModelsResponse>(&resp.body)
.map_err(|e| {
ApiError::Stream(format!(
"failed to decode models response: {e}; body: {}",
String::from_utf8_lossy(&resp.body)
))
})?;

let etag = header_etag.unwrap_or(etag);

Ok(ModelsResponse { models, etag })
Ok((models, header_etag))
}
}

@@ -102,16 +101,15 @@ mod tests {
struct CapturingTransport {
last_request: Arc<Mutex<Option<Request>>>,
body: Arc<ModelsResponse>,
response_etag: Arc<Option<String>>,
}

impl Default for CapturingTransport {
fn default() -> Self {
Self {
last_request: Arc::new(Mutex::new(None)),
body: Arc::new(ModelsResponse {
models: Vec::new(),
etag: String::new(),
}),
body: Arc::new(ModelsResponse { models: Vec::new() }),
response_etag: Arc::new(None),
}
}
}
@@ -122,8 +120,8 @@ mod tests {
*self.last_request.lock().unwrap() = Some(req);
let body = serde_json::to_vec(&*self.body).unwrap();
let mut headers = HeaderMap::new();
if !self.body.etag.is_empty() {
headers.insert(ETAG, self.body.etag.parse().unwrap());
if let Some(etag) = self.response_etag.as_ref().as_deref() {
headers.insert(ETAG, etag.parse().unwrap());
}
Ok(Response {
status: StatusCode::OK,
@@ -166,14 +164,12 @@ mod tests {

#[tokio::test]
async fn appends_client_version_query() {
let response = ModelsResponse {
models: Vec::new(),
etag: String::new(),
};
let response = ModelsResponse { models: Vec::new() };

let transport = CapturingTransport {
last_request: Arc::new(Mutex::new(None)),
body: Arc::new(response),
response_etag: Arc::new(None),
};

let client = ModelsClient::new(
@@ -182,12 +178,12 @@ mod tests {
DummyAuth,
);

let result = client
let (models, _etag) = client
.list_models("0.99.0", HeaderMap::new())
.await
.expect("request should succeed");

assert_eq!(result.models.len(), 0);
assert_eq!(models.len(), 0);

let url = transport
.last_request
@@ -232,12 +228,12 @@ mod tests {
}))
.unwrap(),
],
etag: String::new(),
};

let transport = CapturingTransport {
last_request: Arc::new(Mutex::new(None)),
body: Arc::new(response),
response_etag: Arc::new(None),
};

let client = ModelsClient::new(
@@ -246,27 +242,25 @@ mod tests {
DummyAuth,
);

let result = client
let (models, _etag) = client
.list_models("0.99.0", HeaderMap::new())
.await
.expect("request should succeed");

assert_eq!(result.models.len(), 1);
assert_eq!(result.models[0].slug, "gpt-test");
assert_eq!(result.models[0].supported_in_api, true);
assert_eq!(result.models[0].priority, 1);
assert_eq!(models.len(), 1);
assert_eq!(models[0].slug, "gpt-test");
assert_eq!(models[0].supported_in_api, true);
assert_eq!(models[0].priority, 1);
}

#[tokio::test]
async fn list_models_includes_etag() {
let response = ModelsResponse {
models: Vec::new(),
etag: "\"abc\"".to_string(),
};
let response = ModelsResponse { models: Vec::new() };

let transport = CapturingTransport {
last_request: Arc::new(Mutex::new(None)),
body: Arc::new(response),
response_etag: Arc::new(Some("\"abc\"".to_string())),
};

let client = ModelsClient::new(
@@ -275,12 +269,12 @@ mod tests {
DummyAuth,
);

let result = client
let (models, etag) = client
.list_models("0.1.0", HeaderMap::new())
.await
.expect("request should succeed");

assert_eq!(result.models.len(), 0);
assert_eq!(result.etag, "\"abc\"");
assert_eq!(models.len(), 0);
assert_eq!(etag.as_deref(), Some("\"abc\""));
}
}
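
With this change, `list_models` no longer returns a `ModelsResponse` carrying the etag; callers get the parsed models plus the etag taken from the response's `ETag` header. A minimal caller sketch under the new signature (crate imports elided; `ModelsCache` and `refresh_models_cache` are hypothetical names, not part of this PR):

```rust
use codex_protocol::openai_models::ModelInfo;
use http::HeaderMap;

/// Hypothetical cache type, for illustration only.
struct ModelsCache {
    models: Vec<ModelInfo>,
    etag: Option<String>,
}

async fn refresh_models_cache<T: HttpTransport, A: AuthProvider>(
    client: &ModelsClient<T, A>,
    cache: &mut ModelsCache,
) -> Result<(), ApiError> {
    let (models, etag) = client.list_models("0.99.0", HeaderMap::new()).await?;
    // Keep the header-derived ETag so later requests can send X-If-Models-Match.
    cache.models = models;
    cache.etag = etag;
    Ok(())
}
```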
7 changes: 3 additions & 4 deletions codex-rs/codex-api/tests/models_integration.rs
@@ -90,7 +90,6 @@ async fn models_client_hits_models_endpoint() {
reasoning_summary_format: ReasoningSummaryFormat::None,
experimental_supported_tools: Vec::new(),
}],
etag: String::new(),
};

Mock::given(method("GET"))
@@ -106,13 +105,13 @@ async fn models_client_hits_models_endpoint() {
let transport = ReqwestTransport::new(reqwest::Client::new());
let client = ModelsClient::new(transport, provider(&base_url), DummyAuth);

let result = client
let (models, _etag) = client
.list_models("0.1.0", HeaderMap::new())
.await
.expect("models request should succeed");

assert_eq!(result.models.len(), 1);
assert_eq!(result.models[0].slug, "gpt-test");
assert_eq!(models.len(), 1);
assert_eq!(models[0].slug, "gpt-test");

let received = server
.received_requests()
5 changes: 5 additions & 0 deletions codex-rs/core/src/api_bridge.rs
@@ -67,6 +67,11 @@ pub(crate) fn map_api_error(err: ApiError) -> CodexErr {
status,
request_id: extract_request_id(headers.as_ref()),
})
} else if status == http::StatusCode::PRECONDITION_FAILED
[Review comment, Collaborator] use a status code, don't match on strings (see the sketch after this file's diff)

&& body_text
.contains("Models catalog has changed. Please refresh your models list.")
{
CodexErr::OutdatedModels
} else {
CodexErr::UnexpectedStatus(UnexpectedResponseError {
status,
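The review above asks for the detection to key off structured data rather than a message substring. A possible shape, assuming the server returned a machine-readable `code` field in the error body (the field name and value here are hypothetical, not part of this PR):

```rust
use http::StatusCode;

/// Hypothetical structured error body; `code` is an assumed field name.
#[derive(serde::Deserialize)]
struct ErrorBody {
    code: Option<String>,
}

fn is_outdated_models(status: StatusCode, body_text: &str) -> bool {
    status == StatusCode::PRECONDITION_FAILED
        && serde_json::from_str::<ErrorBody>(body_text)
            .ok()
            .and_then(|body| body.code)
            .is_some_and(|code| code == "models_catalog_changed")
}
```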
58 changes: 41 additions & 17 deletions codex-rs/core/src/client.rs
@@ -33,6 +33,7 @@ use http::StatusCode as HttpStatusCode;
use reqwest::StatusCode;
use serde_json::Value;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::sync::mpsc;
use tracing::warn;

@@ -53,11 +54,12 @@ use crate::openai_models::model_family::ModelFamily;
use crate::tools::spec::create_tools_json_for_chat_completions_api;
use crate::tools::spec::create_tools_json_for_responses_api;

#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct ModelClient {
config: Arc<Config>,
auth_manager: Option<Arc<AuthManager>>,
model_family: ModelFamily,
model_family: RwLock<ModelFamily>,
models_etag: RwLock<Option<String>>,
[Review comment, Collaborator, on lines +61 to +62] These two values are connected and should only be used/updated together. Let's not spread them across two fields. (See the catalog sketch at the end of this diff.)

otel_manager: OtelManager,
provider: ModelProviderInfo,
conversation_id: ConversationId,
@@ -72,6 +74,7 @@ impl ModelClient {
config: Arc<Config>,
auth_manager: Option<Arc<AuthManager>>,
model_family: ModelFamily,
models_etag: Option<String>,
otel_manager: OtelManager,
provider: ModelProviderInfo,
effort: Option<ReasoningEffortConfig>,
@@ -82,7 +85,8 @@ Self {
Self {
config,
auth_manager,
model_family,
model_family: RwLock::new(model_family),
models_etag: RwLock::new(models_etag),
otel_manager,
provider,
conversation_id,
@@ -92,8 +96,8 @@ }
}
}

pub fn get_model_context_window(&self) -> Option<i64> {
let model_family = self.get_model_family();
pub async fn get_model_context_window(&self) -> Option<i64> {
let model_family = self.get_model_family().await;
let effective_context_window_percent = model_family.effective_context_window_percent;
model_family
.context_window
@@ -146,7 +150,7 @@ impl ModelClient {
}

let auth_manager = self.auth_manager.clone();
let model_family = self.get_model_family();
let model_family = self.get_model_family().await;
let instructions = prompt.get_full_instructions(&model_family).into_owned();
let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?;
let api_prompt = build_api_prompt(prompt, instructions, tools_json);
@@ -167,7 +171,7 @@

let stream_result = client
.stream_prompt(
&self.get_model(),
&self.get_model().await,
&api_prompt,
Some(conversation_id.clone()),
Some(session_source.clone()),
@@ -200,7 +204,7 @@ impl ModelClient {
}

let auth_manager = self.auth_manager.clone();
let model_family = self.get_model_family();
let model_family = self.get_model_family().await;
let instructions = prompt.get_full_instructions(&model_family).into_owned();
let tools_json: Vec<Value> = create_tools_json_for_responses_api(&prompt.tools)?;

@@ -262,11 +266,14 @@ impl ModelClient {
store_override: None,
conversation_id: Some(conversation_id.clone()),
session_source: Some(session_source.clone()),
extra_headers: beta_feature_headers(&self.config),
extra_headers: beta_feature_headers(
&self.config,
self.get_models_etag().await.clone(),
),
};

let stream_result = client
.stream_prompt(&self.get_model(), &api_prompt, options)
.stream_prompt(&self.get_model().await, &api_prompt, options)
.await;

match stream_result {
@@ -297,13 +304,25 @@ impl ModelClient {
}

/// Returns the currently configured model slug.
pub fn get_model(&self) -> String {
self.get_model_family().get_model_slug().to_string()
pub async fn get_model(&self) -> String {
self.get_model_family().await.get_model_slug().to_string()
}

/// Returns the currently configured model family.
pub fn get_model_family(&self) -> ModelFamily {
self.model_family.clone()
pub async fn get_model_family(&self) -> ModelFamily {
self.model_family.read().await.clone()
}

pub async fn get_models_etag(&self) -> Option<String> {
[Review comment, Collaborator] We shouldn't make the client an even larger model-storage God-object. (See the catalog sketch at the end of this diff.)

self.models_etag.read().await.clone()
}

pub async fn update_models_etag(&self, etag: Option<String>) {
*self.models_etag.write().await = etag;
}

pub async fn update_model_family(&self, model_family: ModelFamily) {
*self.model_family.write().await = model_family;
}

/// Returns the current reasoning effort setting.
@@ -340,10 +359,10 @@ impl ModelClient {
.with_telemetry(Some(request_telemetry));

let instructions = prompt
.get_full_instructions(&self.get_model_family())
.get_full_instructions(&self.get_model_family().await)
.into_owned();
let payload = ApiCompactionInput {
model: &self.get_model(),
model: &self.get_model().await,
input: &prompt.input,
instructions: &instructions,
};
@@ -398,7 +417,7 @@ fn build_api_prompt(prompt: &Prompt, instructions: String, tools_json: Vec<Value
}
}

fn beta_feature_headers(config: &Config) -> ApiHeaderMap {
fn beta_feature_headers(config: &Config, models_etag: Option<String>) -> ApiHeaderMap {
let enabled = FEATURES
.iter()
.filter_map(|spec| {
@@ -416,6 +435,11 @@ fn beta_feature_headers(config: &Config) -> ApiHeaderMap {
{
headers.insert("x-codex-beta-features", header_value);
}
if let Some(etag) = models_etag
&& let Ok(header_value) = HeaderValue::from_str(&etag)
{
headers.insert("X-If-Models-Match", header_value);
}
headers
}
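
Taken together, the pieces of this PR form a conditional-request loop: the etag captured from `GET /models` is echoed back as `X-If-Models-Match`, and a `412 Precondition Failed` surfaces as `CodexErr::OutdatedModels` so the caller knows to refetch. A rough sketch of the calling side, using the `update_*` methods added above (`run_turn` and `fetch_models` are stand-ins, not functions from this PR):

```rust
// Stand-in driver loop showing the intended protocol; retry limits and
// error handling are elided for brevity.
async fn run_with_fresh_models(client: &ModelClient) -> Result<(), CodexErr> {
    loop {
        match run_turn(client).await {
            // api_bridge.rs maps 412 + the catalog-changed body to this variant.
            Err(CodexErr::OutdatedModels) => {
                let (family, etag) = fetch_models(client).await?; // hypothetical
                client.update_model_family(family).await;
                client.update_models_etag(etag).await;
            }
            other => return other,
        }
    }
}
```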

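Both review threads above point in the same direction: the family/etag pair should live together behind one lock, and preferably outside `ModelClient`. One hedged way to satisfy both at once, assuming `ModelFamily` is `Clone` as it is in this diff (all other names here are hypothetical, not from this PR):

```rust
use tokio::sync::RwLock;

/// The two connected values behind one lock, so they can only change together.
#[derive(Debug, Clone)]
struct ModelsSnapshot {
    family: ModelFamily,
    etag: Option<String>,
}

/// A dedicated holder the client can share via `Arc<ModelsCatalog>`,
/// instead of growing more fields of its own.
#[derive(Debug)]
pub struct ModelsCatalog {
    snapshot: RwLock<ModelsSnapshot>,
}

impl ModelsCatalog {
    /// Reads a consistent (family, etag) pair in one lock acquisition.
    pub async fn snapshot(&self) -> ModelsSnapshot {
        self.snapshot.read().await.clone()
    }

    /// Replaces both values atomically; callers can never update one
    /// without the other.
    pub async fn replace(&self, family: ModelFamily, etag: Option<String>) {
        *self.snapshot.write().await = ModelsSnapshot { family, etag };
    }
}
```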