Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions codex-rs/core/src/tools/lifecycle.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
use codex_extension_api::ToolCallOutcome;
use codex_extension_api::ToolCallSource as ExtensionToolCallSource;
use codex_extension_api::ToolFinishInput;
use codex_extension_api::ToolStartInput;
use codex_tools::ToolName;

use crate::session::session::Session;
use crate::session::turn_context::TurnContext;
use crate::tools::context::ToolCallSource;
use crate::tools::context::ToolInvocation;

pub(crate) async fn notify_tool_start(invocation: &ToolInvocation) {
for contributor in invocation
.session
.services
.extensions
.tool_lifecycle_contributors()
{
contributor
.on_tool_start(ToolStartInput {
session_store: &invocation.session.services.session_extension_data,
thread_store: &invocation.session.services.thread_extension_data,
turn_store: invocation.turn.extension_data.as_ref(),
turn_id: invocation.turn.sub_id.as_str(),
call_id: invocation.call_id.as_str(),
tool_name: &invocation.tool_name,
source: extension_tool_call_source(invocation.source.clone()),
})
.await;
}
}

pub(crate) async fn notify_tool_finish(invocation: &ToolInvocation, outcome: ToolCallOutcome) {
notify_tool_finish_parts(
invocation.session.as_ref(),
invocation.turn.as_ref(),
invocation.call_id.as_str(),
&invocation.tool_name,
invocation.source.clone(),
outcome,
)
.await;
}

pub(crate) async fn notify_tool_aborted(
session: &Session,
turn: &TurnContext,
call_id: &str,
tool_name: &ToolName,
source: ToolCallSource,
) {
notify_tool_finish_parts(
session,
turn,
call_id,
tool_name,
source,
ToolCallOutcome::Aborted,
)
.await;
}

async fn notify_tool_finish_parts(
session: &Session,
turn: &TurnContext,
call_id: &str,
tool_name: &ToolName,
source: ToolCallSource,
outcome: ToolCallOutcome,
) {
for contributor in session.services.extensions.tool_lifecycle_contributors() {
contributor
.on_tool_finish(ToolFinishInput {
session_store: &session.services.session_extension_data,
thread_store: &session.services.thread_extension_data,
turn_store: turn.extension_data.as_ref(),
turn_id: turn.sub_id.as_str(),
call_id,
tool_name,
source: extension_tool_call_source(source.clone()),
outcome,
})
.await;
}
}

fn extension_tool_call_source(source: ToolCallSource) -> ExtensionToolCallSource {
match source {
ToolCallSource::Direct => ExtensionToolCallSource::Direct,
ToolCallSource::CodeMode {
cell_id,
runtime_tool_call_id,
} => ExtensionToolCallSource::CodeMode {
cell_id,
runtime_tool_call_id,
},
}
}
1 change: 1 addition & 0 deletions codex-rs/core/src/tools/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub(crate) mod events;
pub(crate) mod handlers;
pub(crate) mod hook_names;
pub(crate) mod hosted_spec;
pub(crate) mod lifecycle;
pub(crate) mod network_approval;
pub(crate) mod orchestrator;
pub(crate) mod parallel;
Expand Down
236 changes: 206 additions & 30 deletions codex-rs/core/src/tools/parallel.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::time::Instant;

use tokio::sync::RwLock;
use tokio::task::JoinError;
use tokio_util::either::Either;
use tokio_util::sync::CancellationToken;
use tokio_util::task::AbortOnDropHandle;
Expand All @@ -15,6 +18,7 @@ use crate::session::turn_context::TurnContext;
use crate::tools::context::AbortedToolOutput;
use crate::tools::context::SharedTurnDiffTracker;
use crate::tools::context::ToolPayload;
use crate::tools::lifecycle::notify_tool_aborted;
use crate::tools::registry::AnyToolResult;
use crate::tools::registry::ToolArgumentDiffConsumer;
use crate::tools::router::ToolCall;
Expand Down Expand Up @@ -89,6 +93,12 @@ impl ToolCallRuntime {
let lock = Arc::clone(&self.parallel_execution);
let invocation_cancellation_token = cancellation_token.clone();
let started = Instant::now();
let abort_session = Arc::clone(&session);
let abort_source = source.clone();
let abort_turn = Arc::clone(&turn);
let terminal_outcome_reached = Arc::new(AtomicBool::new(false));
let dispatch_terminal_outcome_reached = Arc::clone(&terminal_outcome_reached);
let dispatch_call = call.clone();

let dispatch_span = trace_span!(
"dispatch_tool_call_with_code_mode_result",
Expand All @@ -97,47 +107,69 @@ impl ToolCallRuntime {
call_id = call.call_id.as_str(),
aborted = false,
);
let abort_dispatch_span = dispatch_span.clone();

let handle: AbortOnDropHandle<Result<AnyToolResult, FunctionCallError>> =
let mut handle: AbortOnDropHandle<Result<AnyToolResult, FunctionCallError>> =
AbortOnDropHandle::new(tokio::spawn(async move {
tokio::select! {
_ = cancellation_token.cancelled() => {
let secs = started.elapsed().as_secs_f32().max(0.1);
dispatch_span.record("aborted", true);
Ok(Self::aborted_response(&call, secs))
},
res = async {
let _guard = if supports_parallel {
Either::Left(lock.read().await)
} else {
Either::Right(lock.write().await)
};

router
.dispatch_tool_call_with_code_mode_result(
session,
turn,
invocation_cancellation_token,
tracker,
call.clone(),
source,
)
.instrument(dispatch_span.clone())
.await
} => res,
}
let _guard = if supports_parallel {
Either::Left(lock.read().await)
} else {
Either::Right(lock.write().await)
};

router
.dispatch_tool_call_with_terminal_outcome(
session,
turn,
invocation_cancellation_token,
tracker,
dispatch_call,
source,
dispatch_terminal_outcome_reached,
)
.instrument(dispatch_span.clone())
.await
}));

async move {
handle.await.map_err(|err| {
FunctionCallError::Fatal(format!("tool task failed to receive: {err:?}"))
})?
tokio::select! {
res = &mut handle => res.map_err(Self::tool_task_join_error)?,
_ = cancellation_token.cancelled() => {
if terminal_outcome_reached.load(Ordering::Acquire) || handle.is_finished() {
handle.await.map_err(Self::tool_task_join_error)?
} else {
let secs = started.elapsed().as_secs_f32().max(0.1);
abort_dispatch_span.record("aborted", true);
handle.abort();
match handle.await {
Ok(result) => result,
Err(err) if err.is_cancelled() => {
let response = Self::aborted_response(&call, secs);
notify_tool_aborted(
abort_session.as_ref(),
abort_turn.as_ref(),
call.call_id.as_str(),
&call.tool_name,
abort_source,
)
.await;
Ok(response)
}
Err(err) => Err(Self::tool_task_join_error(err)),
}
}
},
}
}
.in_current_span()
}
}

impl ToolCallRuntime {
fn tool_task_join_error(err: JoinError) -> FunctionCallError {
FunctionCallError::Fatal(format!("tool task failed to receive: {err:?}"))
}

fn failure_response(call: ToolCall, err: FunctionCallError) -> ResponseInputItem {
let message = err.to_string();
match call.payload {
Expand Down Expand Up @@ -189,3 +221,147 @@ impl ToolCallRuntime {
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;

use crate::tools::context::FunctionToolOutput;
use crate::tools::context::ToolInvocation;
use crate::tools::registry::CoreToolRuntime;
use crate::tools::registry::ToolExecutor;
use crate::tools::registry::ToolRegistry;
use crate::turn_diff_tracker::TurnDiffTracker;
use codex_extension_api::ToolCallOutcome;
use codex_protocol::models::FunctionCallOutputBody;
use codex_protocol::models::FunctionCallOutputPayload;
use pretty_assertions::assert_eq;
use tokio::sync::Notify;
use tokio::sync::oneshot;

struct ImmediateHandler {
tool_name: codex_tools::ToolName,
}

#[async_trait::async_trait]
impl ToolExecutor<ToolInvocation> for ImmediateHandler {
fn tool_name(&self) -> codex_tools::ToolName {
self.tool_name.clone()
}

async fn handle(
&self,
_invocation: ToolInvocation,
) -> Result<Box<dyn crate::tools::context::ToolOutput>, FunctionCallError> {
Ok(Box::new(FunctionToolOutput::from_text(
"ok".to_string(),
Some(true),
)))
}
}

impl CoreToolRuntime for ImmediateHandler {}

struct BlockingFinishContributor {
records: Arc<std::sync::Mutex<Vec<ToolCallOutcome>>>,
finish_started: std::sync::Mutex<Option<oneshot::Sender<()>>>,
allow_finish: Arc<Notify>,
}

impl codex_extension_api::ToolLifecycleContributor for BlockingFinishContributor {
fn on_tool_finish<'a>(
&'a self,
input: codex_extension_api::ToolFinishInput<'a>,
) -> codex_extension_api::ToolLifecycleFuture<'a> {
let records = Arc::clone(&self.records);
let allow_finish = Arc::clone(&self.allow_finish);
let finish_started = self
.finish_started
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.take();
let outcome = input.outcome;
Box::pin(async move {
if let Some(finish_started) = finish_started {
let _ = finish_started.send(());
}
allow_finish.notified().await;
records
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.push(outcome);
})
}
}

#[tokio::test]
async fn cancellation_after_handler_finishes_preserves_completed_lifecycle()
-> anyhow::Result<()> {
let (mut session, turn_context) = crate::session::tests::make_session_and_context().await;
let records = Arc::new(std::sync::Mutex::new(Vec::new()));
let (finish_started_tx, finish_started_rx) = oneshot::channel();
let allow_finish = Arc::new(Notify::new());
let mut builder =
codex_extension_api::ExtensionRegistryBuilder::<crate::config::Config>::new();
builder.tool_lifecycle_contributor(Arc::new(BlockingFinishContributor {
records: Arc::clone(&records),
finish_started: std::sync::Mutex::new(Some(finish_started_tx)),
allow_finish: Arc::clone(&allow_finish),
}));
session.services.extensions = Arc::new(builder.build());

let session = Arc::new(session);
let turn_context = Arc::new(turn_context);
let tool_name = codex_tools::ToolName::plain("test_tool");
let handler = Arc::new(ImmediateHandler {
tool_name: tool_name.clone(),
}) as Arc<dyn CoreToolRuntime>;
let router = Arc::new(ToolRouter::from_parts(
ToolRegistry::from_tools([handler]),
Vec::new(),
));
let tracker = Arc::new(tokio::sync::Mutex::new(TurnDiffTracker::new()));
let runtime = ToolCallRuntime::new(router, session, turn_context, tracker);
let cancellation_token = CancellationToken::new();
let call = ToolCall {
tool_name,
call_id: "call-1".to_string(),
payload: ToolPayload::Function {
arguments: "{}".to_string(),
},
};

let response_task =
tokio::spawn(runtime.handle_tool_call(call, cancellation_token.clone()));
tokio::time::timeout(Duration::from_secs(1), finish_started_rx)
.await
.expect("timed out waiting for lifecycle notification to start")
.expect("lifecycle notification should start");
cancellation_token.cancel();
tokio::time::sleep(Duration::from_millis(10)).await;
allow_finish.notify_waiters();

let response = tokio::time::timeout(Duration::from_secs(1), response_task)
.await
.expect("timed out waiting for tool response")
.expect("tool response task should join")?;
let expected_response = ResponseInputItem::FunctionCallOutput {
call_id: "call-1".to_string(),
output: FunctionCallOutputPayload {
body: FunctionCallOutputBody::Text("ok".to_string()),
success: Some(true),
},
};
assert_eq!(expected_response, response);

let actual = records
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.drain(..)
.collect::<Vec<_>>();
assert_eq!(vec![ToolCallOutcome::Completed { success: true }], actual);

Ok(())
}
}
Loading
Loading