Skip to content

Commit e2518dc

Browse files
hangfeicopybara-github
authored andcommitted
fix: Ignore AsyncGenerator return types in function declarations
For Vertex model backend, we send response back. This doesn't work for streaming tools that the return type is AsyncGenerator. So the fix here is to ignore the return type when it's AsyncGenerator. We can't distinguish streaming vs non-streaming tool with AsyncGenerator though as LiveRequestQueue is optional in streaming tool. Adds an `ignore_response` option to `build_function_declaration` to skip including the return type in the function declaration. This is enabled for tools that return `AsyncGenerator`, as the model does not yet support understanding these return types, while streaming tools can still handle them. Also, removes redundant return statements in `_get_mandatory_params`. PiperOrigin-RevId: 794392846
1 parent 8c65967 commit e2518dc

22 files changed

Lines changed: 134 additions & 54 deletions

File tree

contributing/samples/langchain_structured_tool_agent/agent.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
"""
1616
This agent aims to test the Langchain tool with Langchain's StructuredTool
1717
"""
18+
from __future__ import annotations
19+
1820
from google.adk.agents.llm_agent import Agent
1921
from google.adk.tools.langchain_tool import LangchainTool
2022
from langchain.tools import tool
@@ -23,11 +25,13 @@
2325

2426

2527
async def add(x, y) -> int:
28+
"""Adds two numbers."""
2629
return x + y
2730

2831

2932
@tool
3033
def minus(x, y) -> int:
34+
"""Minus two numbers."""
3135
return x - y
3236

3337

src/google/adk/agents/llm_agent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@
109109

110110

111111
async def _convert_tool_union_to_tools(
112-
tool_union: ToolUnion, ctx: ReadonlyContext
112+
tool_union: ToolUnion,
113+
ctx: ReadonlyContext,
113114
) -> list[BaseTool]:
114115
if isinstance(tool_union, BaseTool):
115116
return [tool_union]

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ async def run_live(
7373
invocation_context: InvocationContext,
7474
) -> AsyncGenerator[Event, None]:
7575
"""Runs the flow using live api."""
76-
llm_request = LlmRequest()
76+
llm_request = LlmRequest(live_connect_config=types.LiveConnectConfig())
7777
event_id = Event.new_id()
7878

7979
# Preprocess before calling the LLM.
@@ -373,7 +373,9 @@ async def _run_one_step_async(
373373
yield event
374374

375375
async def _preprocess_async(
376-
self, invocation_context: InvocationContext, llm_request: LlmRequest
376+
self,
377+
invocation_context: InvocationContext,
378+
llm_request: LlmRequest,
377379
) -> AsyncGenerator[Event, None]:
378380
from ...agents.llm_agent import LlmAgent
379381

src/google/adk/flows/llm_flows/basic.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -57,30 +57,31 @@ async def run_async(
5757
if agent.output_schema and not agent.tools:
5858
llm_request.set_output_schema(agent.output_schema)
5959

60-
llm_request.live_connect_config.response_modalities = (
61-
invocation_context.run_config.response_modalities
62-
)
63-
llm_request.live_connect_config.speech_config = (
64-
invocation_context.run_config.speech_config
65-
)
66-
llm_request.live_connect_config.output_audio_transcription = (
67-
invocation_context.run_config.output_audio_transcription
68-
)
69-
llm_request.live_connect_config.input_audio_transcription = (
70-
invocation_context.run_config.input_audio_transcription
71-
)
72-
llm_request.live_connect_config.realtime_input_config = (
73-
invocation_context.run_config.realtime_input_config
74-
)
75-
llm_request.live_connect_config.enable_affective_dialog = (
76-
invocation_context.run_config.enable_affective_dialog
77-
)
78-
llm_request.live_connect_config.proactivity = (
79-
invocation_context.run_config.proactivity
80-
)
81-
llm_request.live_connect_config.session_resumption = (
82-
invocation_context.run_config.session_resumption
83-
)
60+
if llm_request.live_connect_config:
61+
llm_request.live_connect_config.response_modalities = (
62+
invocation_context.run_config.response_modalities
63+
)
64+
llm_request.live_connect_config.speech_config = (
65+
invocation_context.run_config.speech_config
66+
)
67+
llm_request.live_connect_config.output_audio_transcription = (
68+
invocation_context.run_config.output_audio_transcription
69+
)
70+
llm_request.live_connect_config.input_audio_transcription = (
71+
invocation_context.run_config.input_audio_transcription
72+
)
73+
llm_request.live_connect_config.realtime_input_config = (
74+
invocation_context.run_config.realtime_input_config
75+
)
76+
llm_request.live_connect_config.enable_affective_dialog = (
77+
invocation_context.run_config.enable_affective_dialog
78+
)
79+
llm_request.live_connect_config.proactivity = (
80+
invocation_context.run_config.proactivity
81+
)
82+
llm_request.live_connect_config.session_resumption = (
83+
invocation_context.run_config.session_resumption
84+
)
8485

8586
# TODO: handle tool append here, instead of in BaseTool.process_llm_request.
8687

src/google/adk/models/llm_request.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
from __future__ import annotations
1616

17+
from collections.abc import AsyncGenerator as ABCAsyncGenerator
18+
import inspect
19+
from typing import get_origin
1720
from typing import Optional
1821

1922
from google.genai import types
@@ -22,6 +25,7 @@
2225
from pydantic import Field
2326

2427
from ..tools.base_tool import BaseTool
28+
from ..tools.function_tool import FunctionTool
2529

2630

2731
def _find_tool_with_function_declarations(
@@ -66,13 +70,13 @@ class LlmRequest(BaseModel):
6670
config: types.GenerateContentConfig = Field(
6771
default_factory=types.GenerateContentConfig
6872
)
69-
live_connect_config: types.LiveConnectConfig = Field(
70-
default_factory=types.LiveConnectConfig
71-
)
7273
"""Additional config for the generate content request.
7374
7475
tools in generate_content_config should not be set.
7576
"""
77+
live_connect_config: Optional[types.LiveConnectConfig] = None
78+
"""Live connection config.
79+
"""
7680
tools_dict: dict[str, BaseTool] = Field(default_factory=dict, exclude=True)
7781
"""The tools dictionary."""
7882

@@ -99,7 +103,23 @@ def append_tools(self, tools: list[BaseTool]) -> None:
99103
return
100104
declarations = []
101105
for tool in tools:
102-
declaration = tool._get_declaration()
106+
if self.live_connect_config is not None:
107+
# ignore response for tools that returns AsyncGenerator that the model
108+
# can't understand yet even though the model can't handle it, streaming
109+
# tools can handle it.
110+
# to check type, use typing.collections.abc.AsyncGenerator and not
111+
# typing.AsyncGenerator
112+
is_async_generator_return = False
113+
if isinstance(tool, FunctionTool):
114+
signature = inspect.signature(tool.func)
115+
is_async_generator_return = (
116+
get_origin(signature.return_annotation) is ABCAsyncGenerator
117+
)
118+
declaration = tool._get_declaration(
119+
ignore_return_declaration=is_async_generator_return
120+
)
121+
else:
122+
declaration = tool._get_declaration()
103123
if declaration:
104124
declarations.append(declaration)
105125
self.tools_dict[tool.name] = tool

src/google/adk/tools/_automatic_function_calling_util.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ def build_function_declaration(
195195
func: Union[Callable, BaseModel],
196196
ignore_params: Optional[list[str]] = None,
197197
variant: GoogleLLMVariant = GoogleLLMVariant.GEMINI_API,
198+
ignore_return_declaration: bool = False,
198199
) -> types.FunctionDeclaration:
199200
signature = inspect.signature(func)
200201
should_update_signature = False
@@ -232,9 +233,11 @@ def build_function_declaration(
232233
new_func.__annotations__ = func.__annotations__
233234

234235
return (
235-
from_function_with_options(func, variant)
236+
from_function_with_options(func, variant, ignore_return_declaration)
236237
if not should_update_signature
237-
else from_function_with_options(new_func, variant)
238+
else from_function_with_options(
239+
new_func, variant, ignore_return_declaration
240+
)
238241
)
239242

240243

@@ -293,6 +296,7 @@ def build_function_declaration_util(
293296
def from_function_with_options(
294297
func: Callable,
295298
variant: GoogleLLMVariant = GoogleLLMVariant.GEMINI_API,
299+
ignore_return_declaration: bool = False,
296300
) -> 'types.FunctionDeclaration':
297301

298302
parameters_properties = {}
@@ -324,7 +328,8 @@ def from_function_with_options(
324328
declaration.parameters
325329
)
326330
)
327-
if variant == GoogleLLMVariant.GEMINI_API:
331+
332+
if variant == GoogleLLMVariant.GEMINI_API or ignore_return_declaration:
328333
return declaration
329334

330335
return_annotation = inspect.signature(func).return_annotation

src/google/adk/tools/agent_tool.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
from typing import Any
18+
from typing import Optional
1819
from typing import TYPE_CHECKING
1920

2021
from google.genai import types
@@ -61,7 +62,9 @@ def populate_name(cls, data: Any) -> Any:
6162
return data
6263

6364
@override
64-
def _get_declaration(self) -> types.FunctionDeclaration:
65+
def _get_declaration(
66+
self, ignore_return_declaration: bool = False
67+
) -> Optional[types.FunctionDeclaration]:
6568
from ..agents.llm_agent import LlmAgent
6669
from ..utils.variant_utils import GoogleLLMVariant
6770

src/google/adk/tools/application_integration_tool/integration_connector_tool.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from typing import Optional
2121
from typing import Union
2222

23-
from google.genai.types import FunctionDeclaration
23+
from google.genai import types
2424
from typing_extensions import override
2525

2626
from ...auth.auth_credential import AuthCredential
@@ -115,7 +115,9 @@ def __init__(
115115
self._auth_credential = auth_credential
116116

117117
@override
118-
def _get_declaration(self) -> FunctionDeclaration:
118+
def _get_declaration(
119+
self, ignore_return_declaration: bool = False
120+
) -> Optional[types.FunctionDeclaration]:
119121
"""Returns the function declaration in the Gemini Schema format."""
120122
schema_dict = self._rest_api_tool._operation_parser.get_json_schema()
121123
for field in self.EXCLUDE_FIELDS:
@@ -126,7 +128,7 @@ def _get_declaration(self) -> FunctionDeclaration:
126128
schema_dict['required'].remove(field)
127129

128130
parameters = _to_gemini_schema(schema_dict)
129-
function_decl = FunctionDeclaration(
131+
function_decl = types.FunctionDeclaration(
130132
name=self.name, description=self.description, parameters=parameters
131133
)
132134
return function_decl

src/google/adk/tools/base_tool.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,9 @@ def __init__(
7878
self.is_long_running = is_long_running
7979
self.custom_metadata = custom_metadata
8080

81-
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
81+
def _get_declaration(
82+
self, ignore_return_declaration: bool = False
83+
) -> Optional[types.FunctionDeclaration]:
8284
"""Gets the OpenAPI specification of this tool in the form of a FunctionDeclaration.
8385
8486
NOTE:

src/google/adk/tools/crewai_tool.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
from __future__ import annotations
1616

17+
from typing import Optional
18+
1719
from google.genai import types
1820
from typing_extensions import override
1921

@@ -62,7 +64,9 @@ def __init__(self, tool: CrewaiBaseTool, *, name: str, description: str):
6264
self.description = tool.description
6365

6466
@override
65-
def _get_declaration(self) -> types.FunctionDeclaration:
67+
def _get_declaration(
68+
self, ignore_return_declaration: bool = False
69+
) -> Optional[types.FunctionDeclaration]:
6670
"""Build the function declaration for the tool."""
6771
function_declaration = _automatic_function_calling_util.build_function_declaration_for_params_for_crewai(
6872
False,

0 commit comments

Comments
 (0)