From 3ad2850ffc7d8a1da19c65a92425637a59098f1b Mon Sep 17 00:00:00 2001 From: jif-oai Date: Mon, 18 May 2026 16:27:17 +0200 Subject: [PATCH 1/3] feat: add `ToolLifecycleContributor` --- codex-rs/core/src/tools/lifecycle.rs | 98 +++++++++ codex-rs/core/src/tools/mod.rs | 1 + codex-rs/core/src/tools/parallel.rs | 15 +- codex-rs/core/src/tools/registry.rs | 43 +++- codex-rs/core/src/tools/registry_tests.rs | 194 ++++++++++++++++++ .../ext/extension-api/src/contributors.rs | 23 +++ .../src/contributors/tool_lifecycle.rs | 82 ++++++++ codex-rs/ext/extension-api/src/lib.rs | 6 + codex-rs/ext/extension-api/src/registry.rs | 15 ++ codex-rs/ext/goal/src/extension.rs | 47 +++++ 10 files changed, 520 insertions(+), 4 deletions(-) create mode 100644 codex-rs/core/src/tools/lifecycle.rs create mode 100644 codex-rs/ext/extension-api/src/contributors/tool_lifecycle.rs diff --git a/codex-rs/core/src/tools/lifecycle.rs b/codex-rs/core/src/tools/lifecycle.rs new file mode 100644 index 000000000000..ad8b492cce78 --- /dev/null +++ b/codex-rs/core/src/tools/lifecycle.rs @@ -0,0 +1,98 @@ +use codex_extension_api::ToolCallOutcome; +use codex_extension_api::ToolCallSource as ExtensionToolCallSource; +use codex_extension_api::ToolFinishInput; +use codex_extension_api::ToolStartInput; +use codex_tools::ToolName; + +use crate::session::session::Session; +use crate::session::turn_context::TurnContext; +use crate::tools::context::ToolCallSource; +use crate::tools::context::ToolInvocation; + +pub(crate) async fn notify_tool_start(invocation: &ToolInvocation) { + for contributor in invocation + .session + .services + .extensions + .tool_lifecycle_contributors() + { + contributor + .on_tool_start(ToolStartInput { + session_store: &invocation.session.services.session_extension_data, + thread_store: &invocation.session.services.thread_extension_data, + turn_store: invocation.turn.extension_data.as_ref(), + turn_id: invocation.turn.sub_id.as_str(), + call_id: invocation.call_id.as_str(), + tool_name: &invocation.tool_name, + source: extension_tool_call_source(invocation.source.clone()), + }) + .await; + } +} + +pub(crate) async fn notify_tool_finish(invocation: &ToolInvocation, outcome: ToolCallOutcome) { + notify_tool_finish_parts( + invocation.session.as_ref(), + invocation.turn.as_ref(), + invocation.call_id.as_str(), + &invocation.tool_name, + invocation.source.clone(), + outcome, + ) + .await; +} + +pub(crate) async fn notify_tool_aborted( + session: &Session, + turn: &TurnContext, + call_id: &str, + tool_name: &ToolName, + source: ToolCallSource, +) { + notify_tool_finish_parts( + session, + turn, + call_id, + tool_name, + source, + ToolCallOutcome::Aborted, + ) + .await; +} + +async fn notify_tool_finish_parts( + session: &Session, + turn: &TurnContext, + call_id: &str, + tool_name: &ToolName, + source: ToolCallSource, + outcome: ToolCallOutcome, +) { + for contributor in session.services.extensions.tool_lifecycle_contributors() { + contributor + .on_tool_finish(ToolFinishInput { + session_store: &session.services.session_extension_data, + thread_store: &session.services.thread_extension_data, + turn_store: turn.extension_data.as_ref(), + turn_id: turn.sub_id.as_str(), + call_id, + tool_name, + source: extension_tool_call_source(source.clone()), + outcome, + }) + .await; + } +} + +fn extension_tool_call_source(source: ToolCallSource) -> ExtensionToolCallSource { + match source { + ToolCallSource::Direct => ExtensionToolCallSource::Direct, + ToolCallSource::CodeMode { + cell_id, + runtime_tool_call_id, + } => ExtensionToolCallSource::CodeMode { + cell_id, + runtime_tool_call_id, + }, + } +} diff --git a/codex-rs/core/src/tools/mod.rs b/codex-rs/core/src/tools/mod.rs index 5cd943e199c3..f6c176262023 100644 --- a/codex-rs/core/src/tools/mod.rs +++ b/codex-rs/core/src/tools/mod.rs @@ -4,6 +4,7 @@ pub(crate) mod events; pub(crate) mod handlers; pub(crate) mod hook_names; pub(crate) mod hosted_spec; +pub(crate) mod lifecycle; pub(crate) mod network_approval; pub(crate) mod orchestrator; pub(crate) mod parallel; diff --git a/codex-rs/core/src/tools/parallel.rs b/codex-rs/core/src/tools/parallel.rs index 4c79e4b1686b..a985784c8c18 100644 --- a/codex-rs/core/src/tools/parallel.rs +++ b/codex-rs/core/src/tools/parallel.rs @@ -15,6 +15,7 @@ use crate::session::turn_context::TurnContext; use crate::tools::context::AbortedToolOutput; use crate::tools::context::SharedTurnDiffTracker; use crate::tools::context::ToolPayload; +use crate::tools::lifecycle::notify_tool_aborted; use crate::tools::registry::AnyToolResult; use crate::tools::registry::ToolArgumentDiffConsumer; use crate::tools::router::ToolCall; @@ -89,6 +90,9 @@ impl ToolCallRuntime { let lock = Arc::clone(&self.parallel_execution); let invocation_cancellation_token = cancellation_token.clone(); let started = Instant::now(); + let abort_session = Arc::clone(&session); + let abort_source = source.clone(); + let abort_turn = Arc::clone(&turn); let dispatch_span = trace_span!( "dispatch_tool_call_with_code_mode_result", @@ -104,7 +108,16 @@ impl ToolCallRuntime { _ = cancellation_token.cancelled() => { let secs = started.elapsed().as_secs_f32().max(0.1); dispatch_span.record("aborted", true); - Ok(Self::aborted_response(&call, secs)) + let response = Self::aborted_response(&call, secs); + notify_tool_aborted( + abort_session.as_ref(), + abort_turn.as_ref(), + call.call_id.as_str(), + &call.tool_name, + abort_source, + ) + .await; + Ok(response) }, res = async { let _guard = if supports_parallel { diff --git a/codex-rs/core/src/tools/registry.rs b/codex-rs/core/src/tools/registry.rs index 96b686a83a73..620c3c0d1180 100644 --- a/codex-rs/core/src/tools/registry.rs +++ b/codex-rs/core/src/tools/registry.rs @@ -18,9 +18,12 @@ use crate::tools::context::ToolOutput; use crate::tools::context::ToolPayload; use crate::tools::flat_tool_name; use crate::tools::hook_names::HookToolName; +use crate::tools::lifecycle::notify_tool_finish; +use crate::tools::lifecycle::notify_tool_start; use crate::tools::tool_dispatch_trace::ToolDispatchTrace; use crate::tools::tool_search_entry::ToolSearchInfo; use crate::util::error_or_panic; +use codex_extension_api::ToolCallOutcome; use codex_protocol::models::ResponseInputItem; use codex_protocol::protocol::EventMsg; use codex_tools::ToolName; @@ -389,6 +392,8 @@ impl ToolRegistry { return Err(err); } + notify_tool_start(&invocation).await; + if let Some(pre_tool_use_payload) = tool.pre_tool_use_payload(&invocation) { match run_pre_tool_use_hooks( &invocation.session, @@ -402,13 +407,27 @@ impl ToolRegistry { PreToolUseHookResult::Blocked(message) => { let err = FunctionCallError::RespondToModel(message); dispatch_trace.record_failed(&err); + notify_tool_finish(&invocation, ToolCallOutcome::Blocked).await; return Err(err); } PreToolUseHookResult::Continue { updated_input: Some(updated_input), - } => { - invocation = tool.with_updated_hook_input(invocation, updated_input)?; - } + } => match tool.with_updated_hook_input(invocation.clone(), updated_input) { + Ok(updated_invocation) => { + invocation = updated_invocation; + } + Err(err) => { + dispatch_trace.record_failed(&err); + notify_tool_finish( + &invocation, + ToolCallOutcome::Failed { + handler_executed: false, + }, + ) + .await; + return Err(err); + } + }, PreToolUseHookResult::Continue { updated_input: None, } => {} @@ -503,6 +522,24 @@ impl ToolRegistry { } } + let lifecycle_outcome = match &result { + Ok(_) => { + let guard = response_cell.lock().await; + match guard.as_ref() { + Some(result) => ToolCallOutcome::Completed { + success: result.result.success_for_logging(), + }, + None => ToolCallOutcome::Failed { + handler_executed: true, + }, + } + } + Err(_) => ToolCallOutcome::Failed { + handler_executed: true, + }, + }; + notify_tool_finish(&invocation, lifecycle_outcome).await; + if let Err(err) = invocation .session .goal_runtime_apply(GoalRuntimeEvent::ToolCompleted { diff --git a/codex-rs/core/src/tools/registry_tests.rs b/codex-rs/core/src/tools/registry_tests.rs index defacf33c01d..e3ecfc8f9890 100644 --- a/codex-rs/core/src/tools/registry_tests.rs +++ b/codex-rs/core/src/tools/registry_tests.rs @@ -23,6 +23,97 @@ impl ToolExecutor for TestHandler { impl CoreToolRuntime for TestHandler {} +#[derive(Clone)] +enum LifecycleTestResult { + Ok { success: bool }, + Err, +} + +struct LifecycleTestHandler { + tool_name: codex_tools::ToolName, + result: LifecycleTestResult, +} + +#[async_trait::async_trait] +impl ToolExecutor for LifecycleTestHandler { + fn tool_name(&self) -> codex_tools::ToolName { + self.tool_name.clone() + } + + async fn handle( + &self, + _invocation: ToolInvocation, + ) -> Result, FunctionCallError> { + match self.result.clone() { + LifecycleTestResult::Ok { success } => Ok(Box::new( + crate::tools::context::FunctionToolOutput::from_text( + "ok".to_string(), + Some(success), + ), + )), + LifecycleTestResult::Err => Err(FunctionCallError::RespondToModel( + "handler failed".to_string(), + )), + } + } +} + +impl CoreToolRuntime for LifecycleTestHandler {} + +#[derive(Debug, PartialEq, Eq)] +enum RecordedToolLifecycle { + Start { + call_id: String, + tool_name: codex_tools::ToolName, + }, + Finish { + call_id: String, + tool_name: codex_tools::ToolName, + outcome: codex_extension_api::ToolCallOutcome, + }, +} + +struct ToolLifecycleRecorder { + records: Arc>>, +} + +impl codex_extension_api::ToolLifecycleContributor for ToolLifecycleRecorder { + fn on_tool_start<'a>( + &'a self, + input: codex_extension_api::ToolStartInput<'a>, + ) -> codex_extension_api::ToolLifecycleFuture<'a> { + let records = Arc::clone(&self.records); + let record = RecordedToolLifecycle::Start { + call_id: input.call_id.to_string(), + tool_name: input.tool_name.clone(), + }; + Box::pin(async move { + records + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .push(record); + }) + } + + fn on_tool_finish<'a>( + &'a self, + input: codex_extension_api::ToolFinishInput<'a>, + ) -> codex_extension_api::ToolLifecycleFuture<'a> { + let records = Arc::clone(&self.records); + let record = RecordedToolLifecycle::Finish { + call_id: input.call_id.to_string(), + tool_name: input.tool_name.clone(), + outcome: input.outcome, + }; + Box::pin(async move { + records + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .push(record); + }) + } +} + #[test] fn handler_looks_up_namespaced_aliases_explicitly() { let namespace = "mcp__codex_apps__gmail"; @@ -61,3 +152,106 @@ fn handler_looks_up_namespaced_aliases_explicitly() { .is_some_and(|handler| Arc::ptr_eq(handler, &namespaced_handler)) ); } + +#[tokio::test] +async fn dispatch_notifies_tool_lifecycle_contributors() -> anyhow::Result<()> { + let (mut session, turn) = crate::session::tests::make_session_and_context().await; + let records = Arc::new(std::sync::Mutex::new(Vec::new())); + let mut builder = codex_extension_api::ExtensionRegistryBuilder::::new(); + builder.tool_lifecycle_contributor(Arc::new(ToolLifecycleRecorder { + records: Arc::clone(&records), + })); + session.services.extensions = Arc::new(builder.build()); + + let ok_tool = codex_tools::ToolName::plain("ok_tool"); + let failing_tool = codex_tools::ToolName::plain("failing_tool"); + let ok_handler = Arc::new(LifecycleTestHandler { + tool_name: ok_tool.clone(), + result: LifecycleTestResult::Ok { success: false }, + }) as Arc; + let failing_handler = Arc::new(LifecycleTestHandler { + tool_name: failing_tool.clone(), + result: LifecycleTestResult::Err, + }) as Arc; + let registry = ToolRegistry::new(HashMap::from([ + (ok_tool.clone(), ok_handler), + (failing_tool.clone(), failing_handler), + ])); + let session = Arc::new(session); + let turn = Arc::new(turn); + + registry + .dispatch_any(test_invocation( + Arc::clone(&session), + Arc::clone(&turn), + "ok-call", + ok_tool.clone(), + )) + .await?; + let err = match registry + .dispatch_any(test_invocation( + Arc::clone(&session), + Arc::clone(&turn), + "failing-call", + failing_tool.clone(), + )) + .await + { + Ok(_) => panic!("failing handler should return an error"), + Err(err) => err, + }; + assert_eq!(err.to_string(), "handler failed"); + + let expected = vec![ + RecordedToolLifecycle::Start { + call_id: "ok-call".to_string(), + tool_name: ok_tool.clone(), + }, + RecordedToolLifecycle::Finish { + call_id: "ok-call".to_string(), + tool_name: ok_tool, + outcome: codex_extension_api::ToolCallOutcome::Completed { success: false }, + }, + RecordedToolLifecycle::Start { + call_id: "failing-call".to_string(), + tool_name: failing_tool.clone(), + }, + RecordedToolLifecycle::Finish { + call_id: "failing-call".to_string(), + tool_name: failing_tool, + outcome: codex_extension_api::ToolCallOutcome::Failed { + handler_executed: true, + }, + }, + ]; + let actual = records + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .drain(..) + .collect::>(); + assert_eq!(expected, actual); + + Ok(()) +} + +fn test_invocation( + session: Arc, + turn: Arc, + call_id: &str, + tool_name: codex_tools::ToolName, +) -> ToolInvocation { + ToolInvocation { + session, + turn, + cancellation_token: tokio_util::sync::CancellationToken::new(), + tracker: Arc::new(tokio::sync::Mutex::new( + crate::turn_diff_tracker::TurnDiffTracker::new(), + )), + call_id: call_id.to_string(), + tool_name, + source: crate::tools::context::ToolCallSource::Direct, + payload: ToolPayload::Function { + arguments: "{}".to_string(), + }, + } +} diff --git a/codex-rs/ext/extension-api/src/contributors.rs b/codex-rs/ext/extension-api/src/contributors.rs index b03846fd3148..1125cc94c5e1 100644 --- a/codex-rs/ext/extension-api/src/contributors.rs +++ b/codex-rs/ext/extension-api/src/contributors.rs @@ -11,6 +11,7 @@ use crate::ExtensionData; mod prompt; mod thread_lifecycle; +mod tool_lifecycle; mod turn_lifecycle; pub use prompt::PromptFragment; @@ -18,6 +19,11 @@ pub use prompt::PromptSlot; pub use thread_lifecycle::ThreadResumeInput; pub use thread_lifecycle::ThreadStartInput; pub use thread_lifecycle::ThreadStopInput; +pub use tool_lifecycle::ToolCallOutcome; +pub use tool_lifecycle::ToolCallSource; +pub use tool_lifecycle::ToolFinishInput; +pub use tool_lifecycle::ToolLifecycleFuture; +pub use tool_lifecycle::ToolStartInput; pub use turn_lifecycle::TurnAbortInput; pub use turn_lifecycle::TurnStartInput; pub use turn_lifecycle::TurnStopInput; @@ -110,6 +116,23 @@ pub trait ToolContributor: Send + Sync { ) -> Vec>>; } +/// Contributor for host-owned tool lifecycle gates. +/// +/// Implementations should use these callbacks to observe tool execution without +/// inspecting or rewriting tool input/output. Use `ToolContributor` for owning a +/// tool implementation and hooks for policy that needs tool payloads. +pub trait ToolLifecycleContributor: Send + Sync { + /// Called once the host has accepted a tool call for execution. + fn on_tool_start<'a>(&'a self, _input: ToolStartInput<'a>) -> ToolLifecycleFuture<'a> { + Box::pin(std::future::ready(())) + } + + /// Called after the tool call returns, is blocked, fails, or is cancelled. + fn on_tool_finish<'a>(&'a self, _input: ToolFinishInput<'a>) -> ToolLifecycleFuture<'a> { + Box::pin(std::future::ready(())) + } +} + /// Future returned by one claimed approval-review contribution. pub type ApprovalReviewFuture<'a> = std::pin::Pin + Send + 'a>>; diff --git a/codex-rs/ext/extension-api/src/contributors/tool_lifecycle.rs b/codex-rs/ext/extension-api/src/contributors/tool_lifecycle.rs new file mode 100644 index 000000000000..486bca2643aa --- /dev/null +++ b/codex-rs/ext/extension-api/src/contributors/tool_lifecycle.rs @@ -0,0 +1,82 @@ +use std::future::Future; +use std::pin::Pin; + +use codex_tools::ToolName; + +use crate::ExtensionData; + +/// Future returned by one tool-lifecycle callback. +pub type ToolLifecycleFuture<'a> = Pin + Send + 'a>>; + +/// Host-visible source for a model tool call. +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum ToolCallSource { + /// The model invoked the tool directly. + Direct, + /// Code mode invoked the tool while executing a runtime cell. + CodeMode { + /// Runtime cell that issued the nested tool request. + cell_id: String, + /// Code-mode's per-cell tool invocation id. + runtime_tool_call_id: String, + }, +} + +/// Extension-facing outcome for a finished tool call. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum ToolCallOutcome { + /// The tool returned a normal output. + Completed { + /// The tool output's own success marker for telemetry/logging. + success: bool, + }, + /// The tool was blocked by host policy before the handler ran. + Blocked, + /// The tool did not produce a normal output. + Failed { + /// Whether the host reached the tool handler before the failure. + handler_executed: bool, + }, + /// The host cancelled the tool before normal completion. Cancellation can + /// win before the dispatch path accepts the call, so contributors should not + /// assume a matching start callback exists. + Aborted, +} + +/// Input supplied when the host starts executing one tool call. +pub struct ToolStartInput<'a> { + /// Store scoped to the host session runtime. + pub session_store: &'a ExtensionData, + /// Store scoped to this thread runtime. + pub thread_store: &'a ExtensionData, + /// Store scoped to this turn runtime. + pub turn_store: &'a ExtensionData, + /// Current turn submission id. + pub turn_id: &'a str, + /// Model-visible tool call id. + pub call_id: &'a str, + /// Tool name as routed by the host. + pub tool_name: &'a ToolName, + /// Source that issued the tool call. + pub source: ToolCallSource, +} + +/// Input supplied when the host finishes executing one tool call. +pub struct ToolFinishInput<'a> { + /// Store scoped to the host session runtime. + pub session_store: &'a ExtensionData, + /// Store scoped to this thread runtime. + pub thread_store: &'a ExtensionData, + /// Store scoped to this turn runtime. + pub turn_store: &'a ExtensionData, + /// Current turn submission id. + pub turn_id: &'a str, + /// Model-visible tool call id. + pub call_id: &'a str, + /// Tool name as routed by the host. + pub tool_name: &'a ToolName, + /// Source that issued the tool call. + pub source: ToolCallSource, + /// Host-observed result of the tool call. + pub outcome: ToolCallOutcome, +} diff --git a/codex-rs/ext/extension-api/src/lib.rs b/codex-rs/ext/extension-api/src/lib.rs index fe33d421287f..373f3735a465 100644 --- a/codex-rs/ext/extension-api/src/lib.rs +++ b/codex-rs/ext/extension-api/src/lib.rs @@ -28,7 +28,13 @@ pub use contributors::ThreadResumeInput; pub use contributors::ThreadStartInput; pub use contributors::ThreadStopInput; pub use contributors::TokenUsageContributor; +pub use contributors::ToolCallOutcome; +pub use contributors::ToolCallSource; pub use contributors::ToolContributor; +pub use contributors::ToolFinishInput; +pub use contributors::ToolLifecycleContributor; +pub use contributors::ToolLifecycleFuture; +pub use contributors::ToolStartInput; pub use contributors::TurnAbortInput; pub use contributors::TurnItemContributionFuture; pub use contributors::TurnItemContributor; diff --git a/codex-rs/ext/extension-api/src/registry.rs b/codex-rs/ext/extension-api/src/registry.rs index 41d0967126d0..4577ddc048f1 100644 --- a/codex-rs/ext/extension-api/src/registry.rs +++ b/codex-rs/ext/extension-api/src/registry.rs @@ -10,6 +10,7 @@ use crate::NoopExtensionEventSink; use crate::ThreadLifecycleContributor; use crate::TokenUsageContributor; use crate::ToolContributor; +use crate::ToolLifecycleContributor; use crate::TurnItemContributor; use crate::TurnLifecycleContributor; @@ -22,6 +23,7 @@ pub struct ExtensionRegistryBuilder { token_usage_contributors: Vec>, context_contributors: Vec>, tool_contributors: Vec>, + tool_lifecycle_contributors: Vec>, turn_item_contributors: Vec>, approval_review_contributors: Vec>, } @@ -37,6 +39,7 @@ impl Default for ExtensionRegistryBuilder { approval_review_contributors: Vec::new(), context_contributors: Vec::new(), tool_contributors: Vec::new(), + tool_lifecycle_contributors: Vec::new(), turn_item_contributors: Vec::new(), } } @@ -99,6 +102,11 @@ impl ExtensionRegistryBuilder { self.tool_contributors.push(contributor); } + /// Registers one tool-lifecycle contributor. + pub fn tool_lifecycle_contributor(&mut self, contributor: Arc) { + self.tool_lifecycle_contributors.push(contributor); + } + /// Registers one ordered turn-item contributor. pub fn turn_item_contributor(&mut self, contributor: Arc) { self.turn_item_contributors.push(contributor); @@ -115,6 +123,7 @@ impl ExtensionRegistryBuilder { approval_review_contributors: self.approval_review_contributors, context_contributors: self.context_contributors, tool_contributors: self.tool_contributors, + tool_lifecycle_contributors: self.tool_lifecycle_contributors, turn_item_contributors: self.turn_item_contributors, } } @@ -129,6 +138,7 @@ pub struct ExtensionRegistry { token_usage_contributors: Vec>, context_contributors: Vec>, tool_contributors: Vec>, + tool_lifecycle_contributors: Vec>, turn_item_contributors: Vec>, approval_review_contributors: Vec>, } @@ -182,6 +192,11 @@ impl ExtensionRegistry { &self.tool_contributors } + /// Returns the registered tool-lifecycle contributors. + pub fn tool_lifecycle_contributors(&self) -> &[Arc] { + &self.tool_lifecycle_contributors + } + /// Returns the registered ordered turn-item contributors. pub fn turn_item_contributors(&self) -> &[Arc] { &self.turn_item_contributors diff --git a/codex-rs/ext/goal/src/extension.rs b/codex-rs/ext/goal/src/extension.rs index 97d7157298ee..c14829c2fab4 100644 --- a/codex-rs/ext/goal/src/extension.rs +++ b/codex-rs/ext/goal/src/extension.rs @@ -7,7 +7,11 @@ use codex_extension_api::ExtensionRegistryBuilder; use codex_extension_api::ThreadLifecycleContributor; use codex_extension_api::ThreadStartInput; use codex_extension_api::TokenUsageContributor; +use codex_extension_api::ToolCallOutcome; use codex_extension_api::ToolContributor; +use codex_extension_api::ToolFinishInput; +use codex_extension_api::ToolLifecycleContributor; +use codex_extension_api::ToolLifecycleFuture; use codex_extension_api::TurnAbortInput; use codex_extension_api::TurnLifecycleContributor; use codex_extension_api::TurnStartInput; @@ -18,6 +22,7 @@ use codex_protocol::protocol::TokenUsageInfo; use codex_protocol::protocol::TurnAbortReason; use crate::accounting::GoalAccountingState; +use crate::spec::UPDATE_GOAL_TOOL_NAME; use crate::tool::CreateGoalRequest; use crate::tool::GoalToolExecutor; @@ -209,6 +214,33 @@ where } } +impl ToolLifecycleContributor for GoalExtension +where + C: Send + Sync + 'static, +{ + fn on_tool_finish<'a>(&'a self, input: ToolFinishInput<'a>) -> ToolLifecycleFuture<'a> { + Box::pin(async move { + if !goal_enabled(input.thread_store) { + return; + } + + if !tool_attempt_counts_for_goal_progress(input.outcome) { + return; + } + + if input.tool_name.namespace.is_none() && input.tool_name.name == UPDATE_GOAL_TOOL_NAME + { + return; + } + + // TODO: commit active goal progress through host goal storage and emit + // ThreadGoalUpdated when the persisted goal changes. This replaces + // GoalRuntimeEvent::ToolCompleted once the goal extension owns runtime + // accounting. + }) + } +} + // TODO: app-server initiated goal set/clear operations need a contributor or // backend callback here. They currently happen outside thread/turn/token // lifecycle, but the goal extension must observe them to account before @@ -266,6 +298,7 @@ pub fn install_with_backend( registry.config_contributor(extension.clone()); registry.turn_lifecycle_contributor(extension.clone()); registry.token_usage_contributor(extension.clone()); + registry.tool_lifecycle_contributor(extension.clone()); registry.tool_contributor(extension); } @@ -278,3 +311,17 @@ fn goal_enabled(thread_store: &ExtensionData) -> bool { fn accounting_state(thread_store: &ExtensionData) -> Arc { thread_store.get_or_init::(GoalAccountingState::default) } + +fn tool_attempt_counts_for_goal_progress(outcome: ToolCallOutcome) -> bool { + match outcome { + ToolCallOutcome::Completed { .. } => true, + ToolCallOutcome::Failed { + handler_executed: true, + } => true, + ToolCallOutcome::Blocked + | ToolCallOutcome::Failed { + handler_executed: false, + } + | ToolCallOutcome::Aborted => false, + } +} From d0c284ab433ee623826726d6c107323079e0455b Mon Sep 17 00:00:00 2001 From: jif-oai Date: Mon, 18 May 2026 17:59:02 +0200 Subject: [PATCH 2/3] Fix tool lifecycle cancellation race --- codex-rs/core/src/tools/parallel.rs | 220 ++++++++++++++++++++++++---- codex-rs/core/src/tools/registry.rs | 25 +++- codex-rs/core/src/tools/router.rs | 53 ++++++- codex-rs/ext/goal/src/extension.rs | 16 +- 4 files changed, 267 insertions(+), 47 deletions(-) diff --git a/codex-rs/core/src/tools/parallel.rs b/codex-rs/core/src/tools/parallel.rs index a985784c8c18..0c90fb3b56e7 100644 --- a/codex-rs/core/src/tools/parallel.rs +++ b/codex-rs/core/src/tools/parallel.rs @@ -1,7 +1,10 @@ use std::sync::Arc; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; use std::time::Instant; use tokio::sync::RwLock; +use tokio::task::JoinError; use tokio_util::either::Either; use tokio_util::sync::CancellationToken; use tokio_util::task::AbortOnDropHandle; @@ -93,6 +96,9 @@ impl ToolCallRuntime { let abort_session = Arc::clone(&session); let abort_source = source.clone(); let abort_turn = Arc::clone(&turn); + let handler_finished = Arc::new(AtomicBool::new(false)); + let dispatch_handler_finished = Arc::clone(&handler_finished); + let dispatch_call = call.clone(); let dispatch_span = trace_span!( "dispatch_tool_call_with_code_mode_result", @@ -101,13 +107,41 @@ impl ToolCallRuntime { call_id = call.call_id.as_str(), aborted = false, ); + let abort_dispatch_span = dispatch_span.clone(); - let handle: AbortOnDropHandle> = + let mut handle: AbortOnDropHandle> = AbortOnDropHandle::new(tokio::spawn(async move { - tokio::select! { - _ = cancellation_token.cancelled() => { + let _guard = if supports_parallel { + Either::Left(lock.read().await) + } else { + Either::Right(lock.write().await) + }; + + router + .dispatch_tool_call_with_handler_finished( + session, + turn, + invocation_cancellation_token, + tracker, + dispatch_call, + source, + dispatch_handler_finished, + ) + .instrument(dispatch_span.clone()) + .await + })); + + async move { + tokio::select! { + res = &mut handle => res.map_err(Self::tool_task_join_error)?, + _ = cancellation_token.cancelled() => { + if handler_finished.load(Ordering::Acquire) || handle.is_finished() { + handle.await.map_err(Self::tool_task_join_error)? + } else { let secs = started.elapsed().as_secs_f32().max(0.1); - dispatch_span.record("aborted", true); + abort_dispatch_span.record("aborted", true); + handle.abort(); + let _ = handle.await; let response = Self::aborted_response(&call, secs); notify_tool_aborted( abort_session.as_ref(), @@ -118,39 +152,19 @@ impl ToolCallRuntime { ) .await; Ok(response) - }, - res = async { - let _guard = if supports_parallel { - Either::Left(lock.read().await) - } else { - Either::Right(lock.write().await) - }; - - router - .dispatch_tool_call_with_code_mode_result( - session, - turn, - invocation_cancellation_token, - tracker, - call.clone(), - source, - ) - .instrument(dispatch_span.clone()) - .await - } => res, - } - })); - - async move { - handle.await.map_err(|err| { - FunctionCallError::Fatal(format!("tool task failed to receive: {err:?}")) - })? + } + }, + } } .in_current_span() } } impl ToolCallRuntime { + fn tool_task_join_error(err: JoinError) -> FunctionCallError { + FunctionCallError::Fatal(format!("tool task failed to receive: {err:?}")) + } + fn failure_response(call: ToolCall, err: FunctionCallError) -> ResponseInputItem { let message = err.to_string(); match call.payload { @@ -202,3 +216,147 @@ impl ToolCallRuntime { } } } + +#[cfg(test)] +mod tests { + use super::*; + use std::time::Duration; + + use crate::tools::context::FunctionToolOutput; + use crate::tools::context::ToolInvocation; + use crate::tools::registry::CoreToolRuntime; + use crate::tools::registry::ToolExecutor; + use crate::tools::registry::ToolRegistry; + use crate::turn_diff_tracker::TurnDiffTracker; + use codex_extension_api::ToolCallOutcome; + use codex_protocol::models::FunctionCallOutputBody; + use codex_protocol::models::FunctionCallOutputPayload; + use pretty_assertions::assert_eq; + use tokio::sync::Notify; + use tokio::sync::oneshot; + + struct ImmediateHandler { + tool_name: codex_tools::ToolName, + } + + #[async_trait::async_trait] + impl ToolExecutor for ImmediateHandler { + fn tool_name(&self) -> codex_tools::ToolName { + self.tool_name.clone() + } + + async fn handle( + &self, + _invocation: ToolInvocation, + ) -> Result, FunctionCallError> { + Ok(Box::new(FunctionToolOutput::from_text( + "ok".to_string(), + Some(true), + ))) + } + } + + impl CoreToolRuntime for ImmediateHandler {} + + struct BlockingFinishContributor { + records: Arc>>, + finish_started: std::sync::Mutex>>, + allow_finish: Arc, + } + + impl codex_extension_api::ToolLifecycleContributor for BlockingFinishContributor { + fn on_tool_finish<'a>( + &'a self, + input: codex_extension_api::ToolFinishInput<'a>, + ) -> codex_extension_api::ToolLifecycleFuture<'a> { + let records = Arc::clone(&self.records); + let allow_finish = Arc::clone(&self.allow_finish); + let finish_started = self + .finish_started + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .take(); + let outcome = input.outcome; + Box::pin(async move { + if let Some(finish_started) = finish_started { + let _ = finish_started.send(()); + } + allow_finish.notified().await; + records + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .push(outcome); + }) + } + } + + #[tokio::test] + async fn cancellation_after_handler_finishes_preserves_completed_lifecycle() + -> anyhow::Result<()> { + let (mut session, turn_context) = crate::session::tests::make_session_and_context().await; + let records = Arc::new(std::sync::Mutex::new(Vec::new())); + let (finish_started_tx, finish_started_rx) = oneshot::channel(); + let allow_finish = Arc::new(Notify::new()); + let mut builder = + codex_extension_api::ExtensionRegistryBuilder::::new(); + builder.tool_lifecycle_contributor(Arc::new(BlockingFinishContributor { + records: Arc::clone(&records), + finish_started: std::sync::Mutex::new(Some(finish_started_tx)), + allow_finish: Arc::clone(&allow_finish), + })); + session.services.extensions = Arc::new(builder.build()); + + let session = Arc::new(session); + let turn_context = Arc::new(turn_context); + let tool_name = codex_tools::ToolName::plain("test_tool"); + let handler = Arc::new(ImmediateHandler { + tool_name: tool_name.clone(), + }) as Arc; + let router = Arc::new(ToolRouter::from_parts( + ToolRegistry::from_tools([handler]), + Vec::new(), + )); + let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new())); + let runtime = ToolCallRuntime::new(router, session, turn_context, tracker); + let cancellation_token = CancellationToken::new(); + let call = ToolCall { + tool_name, + call_id: "call-1".to_string(), + payload: ToolPayload::Function { + arguments: "{}".to_string(), + }, + }; + + let response_task = + tokio::spawn(runtime.handle_tool_call(call, cancellation_token.clone())); + tokio::time::timeout(Duration::from_secs(1), finish_started_rx) + .await + .expect("timed out waiting for lifecycle notification to start") + .expect("lifecycle notification should start"); + cancellation_token.cancel(); + tokio::time::sleep(Duration::from_millis(10)).await; + allow_finish.notify_waiters(); + + let response = tokio::time::timeout(Duration::from_secs(1), response_task) + .await + .expect("timed out waiting for tool response") + .expect("tool response task should join")?; + let expected_response = ResponseInputItem::FunctionCallOutput { + call_id: "call-1".to_string(), + output: FunctionCallOutputPayload { + body: FunctionCallOutputBody::Text("ok".to_string()), + success: Some(true), + }, + }; + assert_eq!(expected_response, response); + + let actual = records + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .drain(..) + .collect::>(); + assert_eq!(vec![ToolCallOutcome::Completed { success: true }], actual); + + Ok(()) + } +} diff --git a/codex-rs/core/src/tools/registry.rs b/codex-rs/core/src/tools/registry.rs index 620c3c0d1180..6dc48ba95d64 100644 --- a/codex-rs/core/src/tools/registry.rs +++ b/codex-rs/core/src/tools/registry.rs @@ -1,5 +1,7 @@ use std::collections::HashMap; use std::sync::Arc; +use std::sync::atomic::AtomicBool; +use std::sync::atomic::Ordering; use std::time::Duration; use crate::function_tool::FunctionCallError; @@ -301,13 +303,23 @@ impl ToolRegistry { Some(tool.supports_parallel_tool_calls()) } + #[allow(dead_code)] + pub(crate) async fn dispatch_any( + &self, + invocation: ToolInvocation, + ) -> Result { + self.dispatch_any_with_handler_finished(invocation, /*handler_finished*/ None) + .await + } + #[expect( clippy::await_holding_invalid_type, reason = "tool dispatch must keep active-turn accounting atomic" )] - pub(crate) async fn dispatch_any( + pub(crate) async fn dispatch_any_with_handler_finished( &self, mut invocation: ToolInvocation, + handler_finished: Option>, ) -> Result { let tool_name = invocation.tool_name.clone(); let tool_name_flat = flat_tool_name(&tool_name); @@ -449,7 +461,9 @@ impl ToolRegistry { let tool = tool.clone(); let response_cell = &response_cell; async move { - match handle_any_tool(tool.as_ref(), invocation_for_tool).await { + match handle_any_tool(tool.as_ref(), invocation_for_tool, handler_finished) + .await + { Ok(result) => { let preview = result.result.log_preview(); let success = result.result.success_for_logging(); @@ -576,10 +590,15 @@ impl ToolRegistry { async fn handle_any_tool( tool: &dyn CoreToolRuntime, invocation: ToolInvocation, + handler_finished: Option>, ) -> Result { let call_id = invocation.call_id.clone(); let payload = invocation.payload.clone(); - let output = tool.handle(invocation.clone()).await?; + let output = tool.handle(invocation.clone()).await; + if let Some(handler_finished) = handler_finished { + handler_finished.store(true, Ordering::Release); + } + let output = output?; let post_tool_use_payload = CoreToolRuntime::post_tool_use_payload(tool, &invocation, output.as_ref()); Ok(AnyToolResult { diff --git a/codex-rs/core/src/tools/router.rs b/codex-rs/core/src/tools/router.rs index 2477ba347c01..9190d75a6a2d 100644 --- a/codex-rs/core/src/tools/router.rs +++ b/codex-rs/core/src/tools/router.rs @@ -19,6 +19,7 @@ use codex_tools::ToolName; use codex_tools::ToolSpec; use codex_tools::ToolsConfig; use std::sync::Arc; +use std::sync::atomic::AtomicBool; use tokio_util::sync::CancellationToken; use tracing::instrument; @@ -123,6 +124,7 @@ impl ToolRouter { } } + #[allow(dead_code)] #[instrument(level = "trace", skip_all, err)] pub async fn dispatch_tool_call_with_code_mode_result( &self, @@ -132,6 +134,53 @@ impl ToolRouter { tracker: SharedTurnDiffTracker, call: ToolCall, source: ToolCallSource, + ) -> Result { + self.dispatch_tool_call_with_code_mode_result_inner( + session, + turn, + cancellation_token, + tracker, + call, + source, + /*handler_finished*/ None, + ) + .await + } + + #[instrument(level = "trace", skip_all, err)] + #[allow(clippy::too_many_arguments)] + pub(crate) async fn dispatch_tool_call_with_handler_finished( + &self, + session: Arc, + turn: Arc, + cancellation_token: CancellationToken, + tracker: SharedTurnDiffTracker, + call: ToolCall, + source: ToolCallSource, + handler_finished: Arc, + ) -> Result { + self.dispatch_tool_call_with_code_mode_result_inner( + session, + turn, + cancellation_token, + tracker, + call, + source, + Some(handler_finished), + ) + .await + } + + #[allow(clippy::too_many_arguments)] + async fn dispatch_tool_call_with_code_mode_result_inner( + &self, + session: Arc, + turn: Arc, + cancellation_token: CancellationToken, + tracker: SharedTurnDiffTracker, + call: ToolCall, + source: ToolCallSource, + handler_finished: Option>, ) -> Result { let ToolCall { tool_name, @@ -150,7 +199,9 @@ impl ToolRouter { payload, }; - self.registry.dispatch_any(invocation).await + self.registry + .dispatch_any_with_handler_finished(invocation, handler_finished) + .await } } diff --git a/codex-rs/ext/goal/src/extension.rs b/codex-rs/ext/goal/src/extension.rs index ba928ad74783..a8d4f5c289f9 100644 --- a/codex-rs/ext/goal/src/extension.rs +++ b/codex-rs/ext/goal/src/extension.rs @@ -232,18 +232,10 @@ where { fn on_tool_finish<'a>(&'a self, input: ToolFinishInput<'a>) -> ToolLifecycleFuture<'a> { Box::pin(async move { - if !goal_enabled(input.thread_store) { - return; - } - - if !tool_attempt_counts_for_goal_progress(input.outcome) { - return; - } - - if input.tool_name.namespace.is_none() && input.tool_name.name == UPDATE_GOAL_TOOL_NAME - { - return; - } + let _should_count_for_goal_progress = goal_enabled(input.thread_store) + && tool_attempt_counts_for_goal_progress(input.outcome) + && !(input.tool_name.namespace.is_none() + && input.tool_name.name == UPDATE_GOAL_TOOL_NAME); // TODO: commit active goal progress through host goal storage and emit // ThreadGoalUpdated when the persisted goal changes. This replaces From 7335966788cf2a6c49229b06b6502fe73895c15d Mon Sep 17 00:00:00 2001 From: jif-oai Date: Mon, 18 May 2026 19:04:48 +0200 Subject: [PATCH 3/3] Fix tool lifecycle cancellation boundary --- codex-rs/core/src/tools/parallel.rs | 37 ++++++++++++++++------------- codex-rs/core/src/tools/registry.rs | 26 ++++++++++---------- codex-rs/core/src/tools/router.rs | 12 +++++----- 3 files changed, 41 insertions(+), 34 deletions(-) diff --git a/codex-rs/core/src/tools/parallel.rs b/codex-rs/core/src/tools/parallel.rs index 0c90fb3b56e7..15954869e2b0 100644 --- a/codex-rs/core/src/tools/parallel.rs +++ b/codex-rs/core/src/tools/parallel.rs @@ -96,8 +96,8 @@ impl ToolCallRuntime { let abort_session = Arc::clone(&session); let abort_source = source.clone(); let abort_turn = Arc::clone(&turn); - let handler_finished = Arc::new(AtomicBool::new(false)); - let dispatch_handler_finished = Arc::clone(&handler_finished); + let terminal_outcome_reached = Arc::new(AtomicBool::new(false)); + let dispatch_terminal_outcome_reached = Arc::clone(&terminal_outcome_reached); let dispatch_call = call.clone(); let dispatch_span = trace_span!( @@ -118,14 +118,14 @@ impl ToolCallRuntime { }; router - .dispatch_tool_call_with_handler_finished( + .dispatch_tool_call_with_terminal_outcome( session, turn, invocation_cancellation_token, tracker, dispatch_call, source, - dispatch_handler_finished, + dispatch_terminal_outcome_reached, ) .instrument(dispatch_span.clone()) .await @@ -135,23 +135,28 @@ impl ToolCallRuntime { tokio::select! { res = &mut handle => res.map_err(Self::tool_task_join_error)?, _ = cancellation_token.cancelled() => { - if handler_finished.load(Ordering::Acquire) || handle.is_finished() { + if terminal_outcome_reached.load(Ordering::Acquire) || handle.is_finished() { handle.await.map_err(Self::tool_task_join_error)? } else { let secs = started.elapsed().as_secs_f32().max(0.1); abort_dispatch_span.record("aborted", true); handle.abort(); - let _ = handle.await; - let response = Self::aborted_response(&call, secs); - notify_tool_aborted( - abort_session.as_ref(), - abort_turn.as_ref(), - call.call_id.as_str(), - &call.tool_name, - abort_source, - ) - .await; - Ok(response) + match handle.await { + Ok(result) => result, + Err(err) if err.is_cancelled() => { + let response = Self::aborted_response(&call, secs); + notify_tool_aborted( + abort_session.as_ref(), + abort_turn.as_ref(), + call.call_id.as_str(), + &call.tool_name, + abort_source, + ) + .await; + Ok(response) + } + Err(err) => Err(Self::tool_task_join_error(err)), + } } }, } diff --git a/codex-rs/core/src/tools/registry.rs b/codex-rs/core/src/tools/registry.rs index 6dc48ba95d64..363c3f2a0146 100644 --- a/codex-rs/core/src/tools/registry.rs +++ b/codex-rs/core/src/tools/registry.rs @@ -308,7 +308,7 @@ impl ToolRegistry { &self, invocation: ToolInvocation, ) -> Result { - self.dispatch_any_with_handler_finished(invocation, /*handler_finished*/ None) + self.dispatch_any_with_terminal_outcome(invocation, /*terminal_outcome_reached*/ None) .await } @@ -316,10 +316,10 @@ impl ToolRegistry { clippy::await_holding_invalid_type, reason = "tool dispatch must keep active-turn accounting atomic" )] - pub(crate) async fn dispatch_any_with_handler_finished( + pub(crate) async fn dispatch_any_with_terminal_outcome( &self, mut invocation: ToolInvocation, - handler_finished: Option>, + terminal_outcome_reached: Option>, ) -> Result { let tool_name = invocation.tool_name.clone(); let tool_name_flat = flat_tool_name(&tool_name); @@ -419,6 +419,9 @@ impl ToolRegistry { PreToolUseHookResult::Blocked(message) => { let err = FunctionCallError::RespondToModel(message); dispatch_trace.record_failed(&err); + if let Some(terminal_outcome_reached) = &terminal_outcome_reached { + terminal_outcome_reached.store(true, Ordering::Release); + } notify_tool_finish(&invocation, ToolCallOutcome::Blocked).await; return Err(err); } @@ -430,6 +433,9 @@ impl ToolRegistry { } Err(err) => { dispatch_trace.record_failed(&err); + if let Some(terminal_outcome_reached) = &terminal_outcome_reached { + terminal_outcome_reached.store(true, Ordering::Release); + } notify_tool_finish( &invocation, ToolCallOutcome::Failed { @@ -461,9 +467,7 @@ impl ToolRegistry { let tool = tool.clone(); let response_cell = &response_cell; async move { - match handle_any_tool(tool.as_ref(), invocation_for_tool, handler_finished) - .await - { + match handle_any_tool(tool.as_ref(), invocation_for_tool).await { Ok(result) => { let preview = result.result.log_preview(); let success = result.result.success_for_logging(); @@ -552,6 +556,9 @@ impl ToolRegistry { handler_executed: true, }, }; + if let Some(terminal_outcome_reached) = &terminal_outcome_reached { + terminal_outcome_reached.store(true, Ordering::Release); + } notify_tool_finish(&invocation, lifecycle_outcome).await; if let Err(err) = invocation @@ -590,15 +597,10 @@ impl ToolRegistry { async fn handle_any_tool( tool: &dyn CoreToolRuntime, invocation: ToolInvocation, - handler_finished: Option>, ) -> Result { let call_id = invocation.call_id.clone(); let payload = invocation.payload.clone(); - let output = tool.handle(invocation.clone()).await; - if let Some(handler_finished) = handler_finished { - handler_finished.store(true, Ordering::Release); - } - let output = output?; + let output = tool.handle(invocation.clone()).await?; let post_tool_use_payload = CoreToolRuntime::post_tool_use_payload(tool, &invocation, output.as_ref()); Ok(AnyToolResult { diff --git a/codex-rs/core/src/tools/router.rs b/codex-rs/core/src/tools/router.rs index 9190d75a6a2d..a279ec88d9b2 100644 --- a/codex-rs/core/src/tools/router.rs +++ b/codex-rs/core/src/tools/router.rs @@ -142,14 +142,14 @@ impl ToolRouter { tracker, call, source, - /*handler_finished*/ None, + /*terminal_outcome_reached*/ None, ) .await } #[instrument(level = "trace", skip_all, err)] #[allow(clippy::too_many_arguments)] - pub(crate) async fn dispatch_tool_call_with_handler_finished( + pub(crate) async fn dispatch_tool_call_with_terminal_outcome( &self, session: Arc, turn: Arc, @@ -157,7 +157,7 @@ impl ToolRouter { tracker: SharedTurnDiffTracker, call: ToolCall, source: ToolCallSource, - handler_finished: Arc, + terminal_outcome_reached: Arc, ) -> Result { self.dispatch_tool_call_with_code_mode_result_inner( session, @@ -166,7 +166,7 @@ impl ToolRouter { tracker, call, source, - Some(handler_finished), + Some(terminal_outcome_reached), ) .await } @@ -180,7 +180,7 @@ impl ToolRouter { tracker: SharedTurnDiffTracker, call: ToolCall, source: ToolCallSource, - handler_finished: Option>, + terminal_outcome_reached: Option>, ) -> Result { let ToolCall { tool_name, @@ -200,7 +200,7 @@ impl ToolRouter { }; self.registry - .dispatch_any_with_handler_finished(invocation, handler_finished) + .dispatch_any_with_terminal_outcome(invocation, terminal_outcome_reached) .await } }