54 changes: 24 additions & 30 deletions codex-rs/codex-api/src/endpoint/models.rs
@@ -5,6 +5,7 @@ use crate::provider::Provider;
use crate::telemetry::run_with_request_telemetry;
use codex_client::HttpTransport;
use codex_client::RequestTelemetry;
+use codex_protocol::openai_models::ModelInfo;
use codex_protocol::openai_models::ModelsResponse;
use http::HeaderMap;
use http::Method;
@@ -41,7 +42,7 @@ impl<T: HttpTransport, A: AuthProvider> ModelsClient<T, A> {
&self,
client_version: &str,
extra_headers: HeaderMap,
-) -> Result<ModelsResponse, ApiError> {
+) -> Result<(Vec<ModelInfo>, Option<String>), ApiError> {
let builder = || {
let mut req = self.provider.build_request(Method::GET, self.path());
req.headers.extend(extra_headers.clone());
@@ -66,17 +67,15 @@ impl<T: HttpTransport, A: AuthProvider> ModelsClient<T, A> {
.and_then(|value| value.to_str().ok())
.map(ToString::to_string);

-let ModelsResponse { models, etag } = serde_json::from_slice::<ModelsResponse>(&resp.body)
+let ModelsResponse { models } = serde_json::from_slice::<ModelsResponse>(&resp.body)
.map_err(|e| {
ApiError::Stream(format!(
"failed to decode models response: {e}; body: {}",
String::from_utf8_lossy(&resp.body)
))
})?;

-let etag = header_etag.unwrap_or(etag);
-
-Ok(ModelsResponse { models, etag })
+Ok((models, header_etag))
}
}

@@ -102,16 +101,15 @@ mod tests {
struct CapturingTransport {
last_request: Arc<Mutex<Option<Request>>>,
body: Arc<ModelsResponse>,
+response_etag: Arc<Option<String>>,
}

impl Default for CapturingTransport {
fn default() -> Self {
Self {
last_request: Arc::new(Mutex::new(None)),
-body: Arc::new(ModelsResponse {
-models: Vec::new(),
-etag: String::new(),
-}),
+body: Arc::new(ModelsResponse { models: Vec::new() }),
+response_etag: Arc::new(None),
}
}
}
@@ -122,8 +120,8 @@ mod tests {
*self.last_request.lock().unwrap() = Some(req);
let body = serde_json::to_vec(&*self.body).unwrap();
let mut headers = HeaderMap::new();
-if !self.body.etag.is_empty() {
-headers.insert(ETAG, self.body.etag.parse().unwrap());
+if let Some(etag) = self.response_etag.as_ref().as_deref() {
+headers.insert(ETAG, etag.parse().unwrap());
}
Ok(Response {
status: StatusCode::OK,
@@ -166,14 +164,12 @@ mod tests {

#[tokio::test]
async fn appends_client_version_query() {
-let response = ModelsResponse {
-models: Vec::new(),
-etag: String::new(),
-};
+let response = ModelsResponse { models: Vec::new() };

let transport = CapturingTransport {
last_request: Arc::new(Mutex::new(None)),
body: Arc::new(response),
+response_etag: Arc::new(None),
};

let client = ModelsClient::new(
@@ -182,12 +178,12 @@ mod tests {
DummyAuth,
);

-let result = client
+let (models, _etag) = client
.list_models("0.99.0", HeaderMap::new())
.await
.expect("request should succeed");

-assert_eq!(result.models.len(), 0);
+assert_eq!(models.len(), 0);

let url = transport
.last_request
@@ -232,12 +228,12 @@ mod tests {
}))
.unwrap(),
],
-etag: String::new(),
};

let transport = CapturingTransport {
last_request: Arc::new(Mutex::new(None)),
body: Arc::new(response),
+response_etag: Arc::new(None),
};

let client = ModelsClient::new(
@@ -246,27 +242,25 @@ mod tests {
DummyAuth,
);

-let result = client
+let (models, _etag) = client
.list_models("0.99.0", HeaderMap::new())
.await
.expect("request should succeed");

-assert_eq!(result.models.len(), 1);
-assert_eq!(result.models[0].slug, "gpt-test");
-assert_eq!(result.models[0].supported_in_api, true);
-assert_eq!(result.models[0].priority, 1);
+assert_eq!(models.len(), 1);
+assert_eq!(models[0].slug, "gpt-test");
+assert_eq!(models[0].supported_in_api, true);
+assert_eq!(models[0].priority, 1);
}

#[tokio::test]
async fn list_models_includes_etag() {
-let response = ModelsResponse {
-models: Vec::new(),
-etag: "\"abc\"".to_string(),
-};
+let response = ModelsResponse { models: Vec::new() };

let transport = CapturingTransport {
last_request: Arc::new(Mutex::new(None)),
body: Arc::new(response),
+response_etag: Arc::new(Some("\"abc\"".to_string())),
};

let client = ModelsClient::new(
@@ -275,12 +269,12 @@ mod tests {
DummyAuth,
);

-let result = client
+let (models, etag) = client
.list_models("0.1.0", HeaderMap::new())
.await
.expect("request should succeed");

-assert_eq!(result.models.len(), 0);
-assert_eq!(result.etag, "\"abc\"");
+assert_eq!(models.len(), 0);
+assert_eq!(etag.as_deref(), Some("\"abc\""));
}
}
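
For orientation (not part of the diff): a minimal sketch of how a caller consumes the new return shape, assuming a `client: ModelsClient<_, _>` built as in the tests above. `list_models` now yields the parsed models plus the optional ETag taken from the response header rather than the JSON body. `store_models_etag` is a hypothetical persistence helper.

```rust
// Sketch only, under the assumptions stated above.
let (models, etag) = client.list_models("0.99.0", HeaderMap::new()).await?;
for model in &models {
    println!("{} (priority {})", model.slug, model.priority);
}
if let Some(etag) = etag.as_deref() {
    // Persist the catalog ETag so a later request can send X-If-Models-Match.
    store_models_etag(etag); // hypothetical helper, not in this PR
}
```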
7 changes: 3 additions & 4 deletions codex-rs/codex-api/tests/models_integration.rs
@@ -90,7 +90,6 @@ async fn models_client_hits_models_endpoint() {
reasoning_summary_format: ReasoningSummaryFormat::None,
experimental_supported_tools: Vec::new(),
}],
-etag: String::new(),
};

Mock::given(method("GET"))
@@ -106,13 +105,13 @@
let transport = ReqwestTransport::new(reqwest::Client::new());
let client = ModelsClient::new(transport, provider(&base_url), DummyAuth);

-let result = client
+let (models, _etag) = client
.list_models("0.1.0", HeaderMap::new())
.await
.expect("models request should succeed");

-assert_eq!(result.models.len(), 1);
-assert_eq!(result.models[0].slug, "gpt-test");
+assert_eq!(models.len(), 1);
+assert_eq!(models[0].slug, "gpt-test");

let received = server
.received_requests()
5 changes: 5 additions & 0 deletions codex-rs/core/src/api_bridge.rs
@@ -67,6 +67,11 @@ pub(crate) fn map_api_error(err: ApiError) -> CodexErr {
status,
request_id: extract_request_id(headers.as_ref()),
})
+} else if status == http::StatusCode::PRECONDITION_FAILED
[Collaborator review comment: use a status code, don't match on strings]

+&& body_text
+.contains("Models catalog has changed. Please refresh your models list.")
+{
+CodexErr::OutdatedModels
} else {
CodexErr::UnexpectedStatus(UnexpectedResponseError {
status,
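
On the review comment above: one possible shape for matching on a machine-readable code instead of prose. The `CodexErr::OutdatedModels` variant is real; the `code` field and the `"models_catalog_changed"` value are assumptions about a wire format the backend would have to provide.

```rust
use serde::Deserialize;

// Hypothetical error body: the server tags the 412 with a stable code so the
// client no longer string-matches on a human-readable message.
#[derive(Deserialize)]
struct ApiErrorBody {
    code: Option<String>,
}

fn is_outdated_models(status: http::StatusCode, body_text: &str) -> bool {
    status == http::StatusCode::PRECONDITION_FAILED
        && serde_json::from_str::<ApiErrorBody>(body_text)
            .ok()
            .and_then(|body| body.code)
            .as_deref()
            == Some("models_catalog_changed") // assumed code, not in the diff
}
```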
28 changes: 20 additions & 8 deletions codex-rs/core/src/client.rs
Expand Up @@ -53,11 +53,12 @@ use crate::openai_models::model_family::ModelFamily;
use crate::tools::spec::create_tools_json_for_chat_completions_api;
use crate::tools::spec::create_tools_json_for_responses_api;

-#[derive(Debug, Clone)]
+#[derive(Debug)]
pub struct ModelClient {
config: Arc<Config>,
auth_manager: Option<Arc<AuthManager>>,
model_family: ModelFamily,
+models_etag: Option<String>,
otel_manager: OtelManager,
provider: ModelProviderInfo,
conversation_id: ConversationId,
@@ -72,6 +73,7 @@ impl ModelClient {
config: Arc<Config>,
auth_manager: Option<Arc<AuthManager>>,
model_family: ModelFamily,
+models_etag: Option<String>,
otel_manager: OtelManager,
provider: ModelProviderInfo,
effort: Option<ReasoningEffortConfig>,
@@ -83,6 +85,7 @@
config,
auth_manager,
model_family,
+models_etag,
otel_manager,
provider,
conversation_id,
@@ -147,7 +150,7 @@ impl ModelClient {

let auth_manager = self.auth_manager.clone();
let model_family = self.get_model_family();
-let instructions = prompt.get_full_instructions(&model_family).into_owned();
+let instructions = prompt.get_full_instructions(model_family).into_owned();
let tools_json = create_tools_json_for_chat_completions_api(&prompt.tools)?;
let api_prompt = build_api_prompt(prompt, instructions, tools_json);
let conversation_id = self.conversation_id.to_string();
@@ -201,7 +204,7 @@ impl ModelClient {

let auth_manager = self.auth_manager.clone();
let model_family = self.get_model_family();
-let instructions = prompt.get_full_instructions(&model_family).into_owned();
+let instructions = prompt.get_full_instructions(model_family).into_owned();
let tools_json: Vec<Value> = create_tools_json_for_responses_api(&prompt.tools)?;

let reasoning = if model_family.supports_reasoning_summaries {
@@ -262,7 +265,7 @@
store_override: None,
conversation_id: Some(conversation_id.clone()),
session_source: Some(session_source.clone()),
-extra_headers: beta_feature_headers(&self.config),
+extra_headers: beta_feature_headers(&self.config, self.get_models_etag().clone()),
};

let stream_result = client
@@ -302,8 +305,12 @@
}

/// Returns the currently configured model family.
-pub fn get_model_family(&self) -> ModelFamily {
-self.model_family.clone()
+pub fn get_model_family(&self) -> &ModelFamily {
+&self.model_family
}

+fn get_models_etag(&self) -> &Option<String> {
+&self.models_etag
+}

/// Returns the current reasoning effort setting.
@@ -340,7 +347,7 @@
.with_telemetry(Some(request_telemetry));

let instructions = prompt
-.get_full_instructions(&self.get_model_family())
+.get_full_instructions(self.get_model_family())
.into_owned();
let payload = ApiCompactionInput {
model: &self.get_model(),
@@ -398,7 +405,7 @@ fn build_api_prompt(prompt: &Prompt, instructions: String, tools_json: Vec<Value
}
}

-fn beta_feature_headers(config: &Config) -> ApiHeaderMap {
+fn beta_feature_headers(config: &Config, models_etag: Option<String>) -> ApiHeaderMap {
let enabled = FEATURES
.iter()
.filter_map(|spec| {
@@ -416,6 +423,11 @@ fn beta_feature_headers(config: &Config) -> ApiHeaderMap {
{
headers.insert("x-codex-beta-features", header_value);
}
+if let Some(etag) = models_etag
+&& let Ok(header_value) = HeaderValue::from_str(&etag)
+{
+headers.insert("X-If-Models-Match", header_value);
+}
headers
}

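
The ETag plumbing above, isolated as a sketch (names mirror the diff; the `Config`/`FEATURES` handling is omitted): given a previously stored models-catalog ETag, attach it as `X-If-Models-Match` so the backend can answer 412 PRECONDITION_FAILED when the catalog has changed since the client last listed models.

```rust
use http::{HeaderMap, HeaderValue};

// Minimal sketch of the new header logic in beta_feature_headers.
fn models_match_header(models_etag: Option<String>) -> HeaderMap {
    let mut headers = HeaderMap::new();
    if let Some(etag) = models_etag
        && let Ok(value) = HeaderValue::from_str(&etag)
    {
        headers.insert("X-If-Models-Match", value);
    }
    headers
}
```

Note that, as in the diff, an ETag that is not a valid header value is silently dropped rather than surfaced as an error.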
25 changes: 25 additions & 0 deletions codex-rs/core/src/client_common.rs
@@ -1,6 +1,10 @@
use crate::client_common::tools::ToolSpec;
+use crate::codex::Session;
+use crate::codex::TurnContext;
use crate::error::Result;
+use crate::features::Feature;
use crate::openai_models::model_family::ModelFamily;
+use crate::tools::ToolRouter;
pub use codex_api::common::ResponseEvent;
use codex_apply_patch::APPLY_PATCH_TOOL_INSTRUCTIONS;
use codex_protocol::models::ResponseItem;
@@ -44,6 +48,27 @@ pub struct Prompt {
}

impl Prompt {
+pub(crate) fn new(
+sess: &Session,
+turn_context: &TurnContext,
+router: &ToolRouter,
+input: &[ResponseItem],
+) -> Prompt {
+let model_supports_parallel = turn_context
+.client
+.get_model_family()
+.supports_parallel_tool_calls;
+
+Prompt {
+input: input.to_vec(),
+tools: router.specs(),
+parallel_tool_calls: model_supports_parallel
+&& sess.enabled(Feature::ParallelToolCalls),
+base_instructions_override: turn_context.base_instructions.clone(),
+output_schema: turn_context.final_output_json_schema.clone(),
+}
+}

pub(crate) fn get_full_instructions<'a>(&'a self, model: &'a ModelFamily) -> Cow<'a, str> {
let base = self
.base_instructions_override
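
A hypothetical call site for the new constructor, assuming `sess`, `turn_context`, `router`, and the accumulated `items: Vec<ResponseItem>` are in scope as elsewhere in codex-rs:

```rust
// Sketch only: Prompt::new now centralizes prompt assembly, including the
// parallel-tool-calls gate (model support AND the ParallelToolCalls feature).
let prompt = Prompt::new(&sess, &turn_context, &router, &items);
let instructions =
    prompt.get_full_instructions(turn_context.client.get_model_family());
```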