Skip to content

Commit c0c67c8

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Add ADK-based agent factory for Tau-bench
PiperOrigin-RevId: 825674874
1 parent 87f415a commit c0c67c8

1 file changed

Lines changed: 140 additions & 0 deletions

File tree

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Allows to run an ADK agent implementation with a Tau-bench environment.
16+
17+
Note that Tau-bench needs to be installed to run this module. To install
18+
Tau-bench you can follow the steps below:
19+
20+
```
21+
git clone https://github.com/sierra-research/tau-bench.git
22+
cd tau-bench/
23+
pip install -e . --quiet
24+
```
25+
"""
26+
from __future__ import annotations
27+
28+
from typing import Any
29+
30+
import adk_agent
31+
from google.genai import types
32+
from tau_bench import envs
33+
from tau_bench import types as tau_bench_types
34+
from tau_bench.agents import tool_calling_agent
35+
36+
37+
class _EnvWrapper:
38+
"""Wraps the Tau-bench environment to match ADK environment protocol."""
39+
40+
def __init__(self, env: envs.Env):
41+
self._env = env
42+
43+
def step(self, action: types.Part) -> adk_agent.EnvResponse:
44+
if function_call := action.function_call:
45+
return self._env.step(
46+
tau_bench_types.Action(
47+
name=function_call.name, kwargs=function_call.args
48+
)
49+
)
50+
return self._env.step(
51+
tau_bench_types.Action(
52+
name=tau_bench_types.RESPOND_ACTION_NAME,
53+
kwargs=dict(content=action.text),
54+
)
55+
)
56+
57+
def reset(self, task_index: int) -> adk_agent.EnvResponse:
58+
return self._env.reset(task_index)
59+
60+
61+
def _convert_tool(tool_def: dict[str, Any]) -> types.FunctionDeclaration:
62+
if tool_def['type'] != 'function':
63+
raise ValueError(f'Unsupported tool {tool_def}')
64+
return types.FunctionDeclaration(**tool_def['function'])
65+
66+
67+
class _ADKAgent(tool_calling_agent.ToolCallingAgent):
68+
"""ADK agent implementation for Tau Bench."""
69+
70+
def solve(
71+
self,
72+
env: envs.Env,
73+
task_index: int | None = None,
74+
max_num_steps: int = 30,
75+
) -> tau_bench_types.SolveResult:
76+
"""Solves the task using ADK agent.
77+
78+
Args:
79+
env: The environment to solve the task in.
80+
task_index: The index of the task to solve.
81+
max_num_steps: The maximum number of steps to run the agent.
82+
83+
Returns:
84+
The result of the solve.
85+
"""
86+
# Thought-signature is excluded from the message serialization for the
87+
# following reasons:
88+
# - it is not serializable out of the box
89+
# - it is not relevant for trajectory validation as agent inputs / outputs
90+
# are.
91+
content_exclusion = {'parts': {'__all__': 'thought_signature'}}
92+
messages = [
93+
types.Content(
94+
role='system', parts=[types.Part(text=self.wiki)]
95+
).model_dump(exclude=content_exclusion),
96+
]
97+
reward = 0.0
98+
for event in adk_agent.run_environment_loop(
99+
instruction=self.wiki,
100+
env=_EnvWrapper(env),
101+
temperature=self.temperature,
102+
tools=[_convert_tool(t) for t in env.tools_info],
103+
task_index=task_index,
104+
max_num_steps=max_num_steps,
105+
):
106+
if not event.content:
107+
continue
108+
messages.append(event.content.model_dump(exclude=content_exclusion))
109+
reward = event.actions.state_delta.get('reward', reward)
110+
return tau_bench_types.SolveResult(
111+
reward=reward,
112+
info={},
113+
messages=messages,
114+
)
115+
116+
117+
# Equivalent of default `agent_factory` from Tau-bench in
118+
# https://github.com/sierra-research/tau-bench/blob/4754e6b406507dbcbce8e8b3855dcf80aaec18ac/tau_bench/run.py#L124
119+
def adk_agent_factory(
120+
tools_info: list[dict[str, Any]],
121+
wiki: str,
122+
config: tau_bench_types.RunConfig,
123+
) -> tool_calling_agent.ToolCallingAgent:
124+
"""Factory for creating a Tau-bench agent implemented with the ADK.
125+
126+
Args:
127+
tools_info: A list of tool definitions.
128+
wiki: The instructions for the agent.
129+
config: The run configuration.
130+
131+
Returns:
132+
An ADK agent.
133+
"""
134+
return _ADKAgent(
135+
tools_info=tools_info,
136+
wiki=wiki,
137+
model=config.model,
138+
provider=config.model_provider,
139+
temperature=config.temperature,
140+
)

0 commit comments

Comments
 (0)