diff --git a/.changeset/async-tool-calls.md b/.changeset/async-tool-calls.md new file mode 100644 index 000000000..2fce489a2 --- /dev/null +++ b/.changeset/async-tool-calls.md @@ -0,0 +1,5 @@ +--- +"@livekit/agents": patch +--- + +Add AsyncToolset support for background tool calls with progress updates, duplicate handling, and cancellation helpers. diff --git a/agents/src/llm/async_toolset.test.ts b/agents/src/llm/async_toolset.test.ts new file mode 100644 index 000000000..a2549512f --- /dev/null +++ b/agents/src/llm/async_toolset.test.ts @@ -0,0 +1,190 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import { describe, expect, it, vi } from 'vitest'; +import { z } from 'zod'; +import { delay } from '../utils.js'; +import type { AgentSession } from '../voice/agent_session.js'; +import type { RunContext } from '../voice/run_context.js'; +import { AsyncRunContext, type AsyncToolOptions, AsyncToolset } from './async_toolset.js'; +import { ChatContext, FunctionCall, FunctionCallOutput } from './chat_context.js'; +import { tool } from './tool_context.js'; + +type TestAgent = { + readonly chatCtx: ChatContext; + updateChatCtx: (nextChatCtx: ChatContext) => Promise; +}; + +type TestSession = { + agentState: 'listening' | 'thinking'; + currentAgent: TestAgent; + generateReply: ReturnType; + _toolItemsAdded: (items: (FunctionCall | FunctionCallOutput)[]) => void; +}; + +function createRunContext(callId: string, name: string, session: TestSession): RunContext { + return { + session: session as unknown as AgentSession, + speechHandle: { id: 'speech_test', allowInterruptions: true }, + functionCall: FunctionCall.create({ + callId, + name, + args: '{}', + }), + } as unknown as RunContext; +} + +function createSession(agentState: TestSession['agentState'] = 'listening'): TestSession { + let chatCtx = ChatContext.empty(); + const generateReply = vi.fn(); + const agent = { + get chatCtx() { + return chatCtx; + }, + updateChatCtx: async (nextChatCtx: ChatContext) => { + chatCtx = nextChatCtx; + }, + }; + + return { + agentState, + currentAgent: agent, + generateReply, + _toolItemsAdded: () => {}, + }; +} + +describe('AsyncToolset', () => { + it('returns the first update immediately and delivers the final output later', async () => { + const session = createSession(); + const asyncToolset = new AsyncToolset({ + tools: { + long_task: tool({ + description: 'Long task', + parameters: z.object({}), + execute: async (_, { ctx }: AsyncToolOptions) => { + await ctx.update('started'); + await delay(10); + return 'finished'; + }, + }), + }, + }); + + const result = await asyncToolset.tools.long_task!.execute( + {}, + { + ctx: createRunContext('call_async', 'long_task', session), + toolCallId: 'call_async', + }, + ); + + expect(result).toContain('started'); + + await vi.waitFor(() => { + expect(session.currentAgent.chatCtx.items).toHaveLength(2); + expect(session.generateReply).toHaveBeenCalledOnce(); + }); + + expect(session.currentAgent.chatCtx.items[0]?.type).toBe('function_call'); + expect(session.currentAgent.chatCtx.items[1]?.type).toBe('function_call_output'); + }); + + it('rejects duplicate calls when configured', async () => { + const session = createSession(); + const asyncToolset = new AsyncToolset({ + onDuplicateCall: 'reject', + tools: { + long_task: tool({ + description: 'Long task', + parameters: z.object({}), + execute: async (_, { ctx }: AsyncToolOptions) => { + await ctx.update('running'); + await delay(50); + return 'done'; + }, + }), + }, + }); + + await asyncToolset.tools.long_task!.execute( + {}, + { + ctx: createRunContext('call_one', 'long_task', session), + toolCallId: 'call_one', + }, + ); + + const duplicate = await asyncToolset.tools.long_task!.execute( + {}, + { + ctx: createRunContext('call_two', 'long_task', session), + toolCallId: 'call_two', + }, + ); + + expect(duplicate).toContain('Same tool `long_task` is already running'); + }); + + it('exposes running task cancellation', async () => { + const session = createSession(); + const asyncToolset = new AsyncToolset({ + tools: { + long_task: tool({ + description: 'Long task', + parameters: z.object({}), + execute: async (_, { ctx, abortSignal }: AsyncToolOptions) => { + await ctx.update('running'); + await delay(1000, { signal: abortSignal }); + return 'done'; + }, + }), + }, + }); + + await asyncToolset.tools.long_task!.execute( + {}, + { + ctx: createRunContext('call_cancel', 'long_task', session), + toolCallId: 'call_cancel', + }, + ); + + const result = await asyncToolset.tools.cancel_task!.execute( + { call_id: 'call_cancel' }, + { + ctx: createRunContext('call_cancel_tool', 'cancel_task', session), + toolCallId: 'call_cancel_tool', + }, + ); + + expect(result).toBe('Task call_cancel cancelled successfully.'); + }); + + it('closes while waiting to deliver a reply', async () => { + const session = createSession('thinking'); + const asyncToolset = new AsyncToolset({}); + const ctx = new AsyncRunContext({ + runCtx: createRunContext('call_waiting', 'long_task', session), + toolset: asyncToolset, + }); + + await asyncToolset.enqueueReply(ctx, [ + FunctionCall.create({ + callId: 'call_waiting_finished', + name: 'long_task', + args: '{}', + }), + FunctionCallOutput.create({ + callId: 'call_waiting_finished', + name: 'long_task', + output: '"done"', + isError: false, + }), + ]); + + await expect( + Promise.race([asyncToolset.close().then(() => 'closed'), delay(200).then(() => 'timeout')]), + ).resolves.toBe('closed'); + }); +}); diff --git a/agents/src/llm/async_toolset.ts b/agents/src/llm/async_toolset.ts new file mode 100644 index 000000000..62cb37f19 --- /dev/null +++ b/agents/src/llm/async_toolset.ts @@ -0,0 +1,543 @@ +// SPDX-FileCopyrightText: 2026 LiveKit, Inc. +// +// SPDX-License-Identifier: Apache-2.0 +import type { JSONSchema7 } from 'json-schema'; +import { z } from 'zod'; +import { type JobContext, getJobContext } from '../job.js'; +import { log } from '../log.js'; +import { Task, delay } from '../utils.js'; +import { RunContext, type UnknownUserData } from '../voice/run_context.js'; +import { FunctionCall, FunctionCallOutput } from './chat_context.js'; +import { + type FunctionTool, + type JSONObject, + type ToolContext, + ToolError, + type ToolOptions, + isToolError, + tool, +} from './tool_context.js'; +import { isZodObjectSchema, isZodSchema } from './zod-utils.js'; + +const logger = log(); + +const CONFIRM_DUPLICATE_PARAM = '_lk_agents_confirm_duplicate'; + +const UPDATE_TEMPLATE = `The tool \`{functionName}\` has updated, message: {message} +The task is still running, so DON'T make up or give information not included in the message above.`; + +const DUPLICATE_REJECT = `Same tool \`{functionName}\` is already running: +{runningFunctionCalls} +If you want to cancel the existing one, call \`cancel_task\` with call_id.`; + +const DUPLICATE_CONFIRM = `Same tool \`{functionName}\` is already running: +{runningFunctionCalls} +Re-call with confirm duplicate True to run a duplicate if needed, +or if you want to cancel the existing one, call \`cancel_task\` with call_id.`; + +const REPLY_INSTRUCTIONS = `New results arrived from background tool calls (call_ids: {pendingCallIds}). +Summarize these results to the user naturally. Do NOT repeat information you have already told the user.`; + +export type AsyncToolDuplicateMode = 'allow' | 'replace' | 'reject' | 'confirm'; + +type RunningTask = { + ctx: AsyncRunContext; + task: Task; +}; + +type PendingUpdate = { + ctx: AsyncRunContext; + items: ToolItem[]; +}; + +type ToolItem = FunctionCall | FunctionCallOutput; + +type AsyncToolSession = RunContext['session'] & { + agentState: 'initializing' | 'idle' | 'listening' | 'thinking' | 'speaking'; + _toolItemsAdded: (items: ToolItem[]) => void; +}; + +const runningTasks = new Map< + string, + { jobCtx: JobContext | undefined; task: RunningTask } +>(); + +function formatTemplate(template: string, values: Record): string { + return template.replace(/\{(\w+)\}/g, (_, key: string) => values[key] ?? ''); +} + +function stringifyOutput(output: unknown): string { + if (typeof output === 'string') return output; + return JSON.stringify(output); +} + +function cloneFunctionCall(functionCall: FunctionCall, callId: string): FunctionCall { + return FunctionCall.create({ + callId, + name: functionCall.name, + args: functionCall.args, + extra: functionCall.extra, + groupId: functionCall.groupId, + thoughtSignature: functionCall.thoughtSignature, + }); +} + +function makeToolItems( + functionCall: FunctionCall, + output: unknown, + callId: string, +): ToolItem[] | undefined { + const toolCall = cloneFunctionCall(functionCall, callId); + + if (output instanceof Error) { + return [ + toolCall, + FunctionCallOutput.create({ + name: toolCall.name, + callId: toolCall.callId, + output: isToolError(output) ? output.message : 'An internal error occurred', + isError: true, + }), + ]; + } + + const serializedOutput = JSON.stringify(output); + if (serializedOutput === undefined) return undefined; + + return [ + toolCall, + FunctionCallOutput.create({ + name: toolCall.name, + callId: toolCall.callId, + output: serializedOutput, + isError: false, + }), + ]; +} + +function addConfirmDuplicateParameter( + parameters: FunctionTool['parameters'], +): FunctionTool['parameters'] { + if (isZodSchema(parameters) && isZodObjectSchema(parameters)) { + const zodObject = parameters as typeof parameters & { + extend?: (shape: Record) => typeof parameters; + }; + try { + return ( + zodObject.extend?.({ + [CONFIRM_DUPLICATE_PARAM]: z + .boolean() + .optional() + .describe( + 'Set this to true to confirm you want to run a duplicate. Only do this when user confirms the duplication is needed.', + ), + }) ?? parameters + ); + } catch { + return parameters; + } + } + + const rawSchema = parameters as JSONSchema7; + if (typeof rawSchema === 'object' && rawSchema !== null && !isZodSchema(rawSchema)) { + return { + ...rawSchema, + properties: { + ...(typeof rawSchema.properties === 'object' ? rawSchema.properties : {}), + [CONFIRM_DUPLICATE_PARAM]: { + type: 'boolean', + description: + 'Set this to true to confirm you want to run a duplicate. Only do this when user confirms the duplication is needed.', + default: false, + }, + }, + } as FunctionTool['parameters']; + } + + return parameters; +} + +export class AsyncRunContext extends RunContext { + private readonly pendingUpdateFuture: Promise; + private resolvePendingUpdate!: (value: unknown) => void; + private rejectPendingUpdate!: (error: Error) => void; + private pendingUpdateDone = false; + private stepIdx = 0; + + constructor({ + runCtx, + toolset, + }: { + runCtx: RunContext; + toolset: AsyncToolset; + }) { + super(runCtx.session, runCtx.speechHandle, runCtx.functionCall); + this._toolset = toolset; + this.pendingUpdateFuture = new Promise((resolve, reject) => { + this.resolvePendingUpdate = resolve; + this.rejectPendingUpdate = reject; + }); + } + + /** @internal */ + readonly _toolset: AsyncToolset; + + /** + * Push an intermediate progress update into the conversation. + * + * The first update completes the original tool call immediately. Later updates + * are inserted as background tool outputs and trigger a follow-up reply when the + * agent is idle. + */ + async update(message: string | unknown, options?: { template?: string }): Promise { + const output = + typeof message === 'string' + ? formatTemplate(options?.template ?? UPDATE_TEMPLATE, { + functionName: this.functionCall.name, + callId: this.functionCall.callId, + message, + }) + : message; + + if (!this.pendingUpdateDone) { + this.pendingUpdateDone = true; + this.functionCall.extra.__livekit_agents_tool_pending = true; + this.resolvePendingUpdate(output); + return; + } + + this.stepIdx += 1; + const items = makeToolItems( + this.functionCall, + output, + `${this.functionCall.callId}_update_${this.stepIdx}`, + ); + if (items) await this._toolset.enqueueReply(this, items); + } + + /** @internal */ + _resolvePending(value: unknown): void { + if (this.pendingUpdateDone) return; + this.pendingUpdateDone = true; + this.resolvePendingUpdate(value); + } + + /** @internal */ + _rejectPending(error: Error): void { + if (this.pendingUpdateDone) return; + this.pendingUpdateDone = true; + this.rejectPendingUpdate(error); + } + + /** @internal */ + get _pending(): Promise { + return this.pendingUpdateFuture; + } + + /** @internal */ + get _hasPendingUpdate(): boolean { + return this.pendingUpdateDone; + } +} + +export interface AsyncToolOptions extends ToolOptions { + ctx: AsyncRunContext; +} + +export interface AsyncToolsetOptions { + id?: string; + tools?: ToolContext; + onDuplicateCall?: AsyncToolDuplicateMode; +} + +export class AsyncToolset { + readonly id: string; + readonly tools: ToolContext; + + private readonly onDuplicateCall: AsyncToolDuplicateMode; + private readonly localRunningTasks = new Map>(); + private pendingUpdates: PendingUpdate[] = []; + private replyTask?: Task; + private closed = false; + + constructor({ + id = 'async_tools', + tools = {}, + onDuplicateCall = 'confirm', + }: AsyncToolsetOptions) { + this.id = id; + this.onDuplicateCall = onDuplicateCall; + this.tools = { + ...Object.fromEntries( + Object.entries(tools).map(([name, functionTool]) => [ + name, + this.wrapTool(name, functionTool), + ]), + ), + get_running_tasks: this.getRunningTasksTool(), + cancel_task: this.cancelTaskTool(), + } as ToolContext; + } + + async cancel(callId: string): Promise { + const running = this.localRunningTasks.get(callId); + if (!running) return false; + + if (!running.ctx.speechHandle.allowInterruptions) { + throw new ToolError( + `Tool call ${callId} is not cancellable because interruptions are disallowed`, + ); + } + + running.task.cancel(); + return true; + } + + async close(): Promise { + this.closed = true; + this.replyTask?.cancel(); + for (const running of this.localRunningTasks.values()) { + running.task.cancel(); + } + await Promise.allSettled([ + ...(this.replyTask ? [this.replyTask.result] : []), + ...Array.from(this.localRunningTasks.values()).map((running) => running.task.result), + ]); + this.localRunningTasks.clear(); + } + + async aclose(): Promise { + await this.close(); + } + + /** @internal */ + async enqueueReply(ctx: AsyncRunContext, items: ToolItem[]): Promise { + if (this.closed) return; + + const agent = ctx.session.currentAgent; + const chatCtx = agent.chatCtx.copy(); + chatCtx.insert(items); + await agent.updateChatCtx(chatCtx); + ctx.session._toolItemsAdded(items); + + this.pendingUpdates.push({ ctx, items }); + + if (!this.replyTask || this.replyTask.done) { + this.replyTask = Task.from( + (controller) => this.deliverReply(ctx.session, controller.signal), + undefined, + 'asyncToolsetReply', + ); + } + } + + private wrapTool( + name: string, + functionTool: FunctionTool, + ): FunctionTool { + return tool({ + description: functionTool.description, + parameters: + this.onDuplicateCall === 'confirm' + ? addConfirmDuplicateParameter(functionTool.parameters) + : functionTool.parameters, + flags: functionTool.flags, + execute: async (rawArgs: Parameters, opts: ToolOptions) => { + const args = { ...rawArgs } as Parameters & Record; + const confirmDuplicate = Boolean(args[CONFIRM_DUPLICATE_PARAM]); + delete args[CONFIRM_DUPLICATE_PARAM]; + + const duplicateResult = await this.checkDuplicate(name, confirmDuplicate); + if (duplicateResult !== undefined) return duplicateResult; + + if (this.localRunningTasks.has(opts.toolCallId)) { + throw new Error(`Task already running for call_id: ${opts.toolCallId}`); + } + + const asyncCtx = new AsyncRunContext({ runCtx: opts.ctx, toolset: this }); + const controller = new AbortController(); + + const task = Task.from( + async () => { + let output: unknown; + try { + output = await functionTool.execute(args, { + ...opts, + ctx: asyncCtx, + abortSignal: controller.signal, + } as AsyncToolOptions); + } catch (error) { + if (controller.signal.aborted) { + logger.debug({ callId: opts.toolCallId, function: name }, 'async tool cancelled'); + asyncCtx._resolvePending(undefined); + return; + } + + output = error instanceof Error ? error : new Error(String(error)); + logger.error( + { callId: opts.toolCallId, function: name, error }, + 'error in async tool', + ); + } + + if (!asyncCtx._hasPendingUpdate) { + if (output instanceof Error) asyncCtx._rejectPending(output); + else asyncCtx._resolvePending(output); + return; + } + + if (output === undefined || output === null) return; + + const items = makeToolItems( + asyncCtx.functionCall, + output, + `${opts.toolCallId}_finished`, + ); + if (items) await this.enqueueReply(asyncCtx, items); + }, + controller, + `asyncTool:${name}`, + ); + + const runningTask = { ctx: asyncCtx, task }; + this.localRunningTasks.set(opts.toolCallId, runningTask); + const jobCtx = getJobContext(false); + runningTasks.set(opts.toolCallId, { + jobCtx, + task: runningTask as unknown as RunningTask, + }); + + task.addDoneCallback(() => { + this.localRunningTasks.delete(opts.toolCallId); + const registered = runningTasks.get(opts.toolCallId); + if (registered?.task === runningTask) runningTasks.delete(opts.toolCallId); + }); + + return await asyncCtx._pending; + }, + }); + } + + private getRunningTasksTool(): FunctionTool, UserData, unknown[]> { + return tool({ + description: 'Get the list of running async tool calls across all async toolsets.', + execute: async () => { + const jobCtx = getJobContext(false); + return Array.from(runningTasks.values()) + .filter((running) => running.jobCtx === jobCtx) + .map((running) => running.task.ctx.functionCall.toJSON()); + }, + }); + } + + private cancelTaskTool(): FunctionTool<{ call_id: string }, UserData, string> { + return tool({ + description: 'Cancel a running async tool call by call_id.', + parameters: z.object({ call_id: z.string() }), + execute: async ({ call_id }) => { + const jobCtx = getJobContext(false); + const running = runningTasks.get(call_id); + if ( + running && + running.jobCtx === jobCtx && + (await running.task.ctx._toolset.cancel(call_id)) + ) { + return `Task ${call_id} cancelled successfully.`; + } + return `Task ${call_id} not found or already completed.`; + }, + }); + } + + private async deliverReply( + session: AsyncToolSession, + abortSignal: AbortSignal, + ): Promise { + await waitForInactive(session, abortSignal); + + const updates = this.pendingUpdates; + this.pendingUpdates = []; + + const pendingItems = updates.flatMap((update) => update.items); + if (pendingItems.length === 0) return; + + const agentChatItems = session.currentAgent.chatCtx.items; + const latestPendingItem = pendingItems[pendingItems.length - 1]; + const latestAgentItem = agentChatItems[agentChatItems.length - 1]; + if ( + latestPendingItem && + latestAgentItem && + latestAgentItem.createdAt > latestPendingItem.createdAt + ) { + logger.debug('skipping async toolset reply because agent already spoke after updates'); + return; + } + + const pendingCallIds = pendingItems + .filter((item) => item.type === 'function_call_output') + .map((item) => item.callId); + + session.generateReply({ + instructions: formatTemplate(REPLY_INSTRUCTIONS, { + pendingCallIds: pendingCallIds.join(', '), + }), + toolChoice: 'none', + }); + } + + private async checkDuplicate( + functionName: string, + confirmDuplicate: boolean, + ): Promise { + if (this.onDuplicateCall === 'allow') return undefined; + + const runningFunctionCalls = Array.from(this.localRunningTasks.values()) + .map((running) => running.ctx.functionCall) + .filter((functionCall) => functionCall.name === functionName); + + if (runningFunctionCalls.length === 0) return undefined; + + if (this.onDuplicateCall === 'replace') { + const results = await Promise.allSettled( + runningFunctionCalls.map((functionCall) => this.cancel(functionCall.callId)), + ); + const errors = results.filter( + (result): result is PromiseRejectedResult => result.status === 'rejected', + ); + if (errors.length > 0) { + throw new ToolError( + `Failed to cancel duplicate tool calls: ${errors + .map((error) => String(error.reason)) + .join('\n')}`, + ); + } + return undefined; + } + + const runningFunctionCallsText = runningFunctionCalls + .map((functionCall) => stringifyOutput(functionCall.toJSON())) + .join('\n'); + + if (this.onDuplicateCall === 'reject') { + return formatTemplate(DUPLICATE_REJECT, { + functionName, + runningFunctionCalls: runningFunctionCallsText, + }); + } + + if (this.onDuplicateCall === 'confirm' && !confirmDuplicate) { + return formatTemplate(DUPLICATE_CONFIRM, { + functionName, + runningFunctionCalls: runningFunctionCallsText, + }); + } + + return undefined; + } +} + +async function waitForInactive(session: AsyncToolSession, abortSignal: AbortSignal): Promise { + while (session.agentState === 'speaking' || session.agentState === 'thinking') { + await delay(50, { signal: abortSignal }); + } +} diff --git a/agents/src/llm/index.ts b/agents/src/llm/index.ts index 6d9bad4f4..8cf030d14 100644 --- a/agents/src/llm/index.ts +++ b/agents/src/llm/index.ts @@ -17,6 +17,14 @@ export { type ToolType, } from './tool_context.js'; +export { + AsyncRunContext, + AsyncToolset, + type AsyncToolDuplicateMode, + type AsyncToolOptions, + type AsyncToolsetOptions, +} from './async_toolset.js'; + export { AgentHandoffItem, ChatContext,