Skip to content

Commit 0e72efb

Browse files
calvingilescopybara-github
authored andcommitted
fix: Call all tools in parallel calls during partial authentication
Copybara import of the project: -- ffd6184 by Calvin Giles <calvin.giles@trademe.co.nz>: fix: Call all tools in parallel calls during partial authentication -- c71782a by seanzhou1023 <seanzhou1023@gmail.com>: Update auth_preprocessor.py -- 843af6b by seanzhou1023 <seanzhou1023@gmail.com>: Update test_functions_request_euc.py -- 955e3fa by seanzhou1023 <seanzhou1023@gmail.com>: Update test_functions_request_euc.py COPYBARA_INTEGRATE_REVIEW=#853 from calvingiles:fix-parallel-auth-tool-calls f44671e PiperOrigin-RevId: 765639904
1 parent 036f954 commit 0e72efb

File tree

3 files changed

+234
-21
lines changed

3 files changed

+234
-21
lines changed

src/google/adk/auth/auth_preprocessor.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -100,23 +100,24 @@ async def run_async(
100100
function_calls = event.get_function_calls()
101101
if not function_calls:
102102
continue
103-
for function_call in function_calls:
104-
function_response_event = None
105-
if function_call.id in tools_to_resume:
106-
function_response_event = await functions.handle_function_calls_async(
107-
invocation_context,
108-
event,
109-
{
110-
tool.name: tool
111-
for tool in await agent.canonical_tools(
112-
ReadonlyContext(invocation_context)
113-
)
114-
},
115-
# there could be parallel function calls that require auth
116-
# auth response would be a dict keyed by function call id
117-
tools_to_resume,
118-
)
119-
if function_response_event:
103+
104+
if any([
105+
function_call.id in tools_to_resume
106+
for function_call in function_calls
107+
]):
108+
if function_response_event := await functions.handle_function_calls_async(
109+
invocation_context,
110+
event,
111+
{
112+
tool.name: tool
113+
for tool in await agent.canonical_tools(
114+
ReadonlyContext(invocation_context)
115+
)
116+
},
117+
# there could be parallel function calls that require auth
118+
# auth response would be a dict keyed by function call id
119+
tools_to_resume,
120+
):
120121
yield function_response_event
121122
return
122123
return

src/google/adk/flows/llm_flows/contents.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -170,10 +170,10 @@ def _rearrange_events_for_latest_function_response(
170170
for idx in range(function_call_event_idx + 1, len(events) - 1):
171171
event = events[idx]
172172
function_responses = event.get_function_responses()
173-
if (
174-
function_responses
175-
and function_responses[0].id in function_responses_ids
176-
):
173+
if function_responses and any([
174+
function_response.id in function_responses_ids
175+
for function_response in function_responses
176+
]):
177177
function_response_events.append(event)
178178
function_response_events.append(events[-1])
179179

tests/unittests/flows/llm_flows/test_functions_request_euc.py

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,3 +344,215 @@ def call_external_api2(tool_context: ToolContext) -> int:
344344
assert parts[0].function_response.response == {'result': 1}
345345
assert parts[1].function_response.name == 'call_external_api2'
346346
assert parts[1].function_response.response == {'result': 2}
347+
348+
349+
def test_function_get_auth_response_partial():
350+
id_1 = 'id_1'
351+
id_2 = 'id_2'
352+
responses = [
353+
[
354+
function_call(id_1, 'call_external_api1', {}),
355+
function_call(id_2, 'call_external_api2', {}),
356+
],
357+
[
358+
types.Part.from_text(text='response1'),
359+
],
360+
[
361+
types.Part.from_text(text='response2'),
362+
],
363+
]
364+
365+
mock_model = testing_utils.MockModel.create(responses=responses)
366+
function_invoked = 0
367+
368+
auth_config1 = AuthConfig(
369+
auth_scheme=OAuth2(
370+
flows=OAuthFlows(
371+
authorizationCode=OAuthFlowAuthorizationCode(
372+
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
373+
tokenUrl='https://oauth2.googleapis.com/token',
374+
scopes={
375+
'https://www.googleapis.com/auth/calendar': (
376+
'See, edit, share, and permanently delete all the'
377+
' calendars you can access using Google Calendar'
378+
)
379+
},
380+
)
381+
)
382+
),
383+
raw_auth_credential=AuthCredential(
384+
auth_type=AuthCredentialTypes.OAUTH2,
385+
oauth2=OAuth2Auth(
386+
client_id='oauth_client_id_1',
387+
client_secret='oauth_client_secret1',
388+
),
389+
),
390+
)
391+
auth_config2 = AuthConfig(
392+
auth_scheme=OAuth2(
393+
flows=OAuthFlows(
394+
authorizationCode=OAuthFlowAuthorizationCode(
395+
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
396+
tokenUrl='https://oauth2.googleapis.com/token',
397+
scopes={
398+
'https://www.googleapis.com/auth/calendar': (
399+
'See, edit, share, and permanently delete all the'
400+
' calendars you can access using Google Calendar'
401+
)
402+
},
403+
)
404+
)
405+
),
406+
raw_auth_credential=AuthCredential(
407+
auth_type=AuthCredentialTypes.OAUTH2,
408+
oauth2=OAuth2Auth(
409+
client_id='oauth_client_id_2',
410+
client_secret='oauth_client_secret2',
411+
),
412+
),
413+
)
414+
415+
auth_response1 = AuthConfig(
416+
auth_scheme=OAuth2(
417+
flows=OAuthFlows(
418+
authorizationCode=OAuthFlowAuthorizationCode(
419+
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
420+
tokenUrl='https://oauth2.googleapis.com/token',
421+
scopes={
422+
'https://www.googleapis.com/auth/calendar': (
423+
'See, edit, share, and permanently delete all the'
424+
' calendars you can access using Google Calendar'
425+
)
426+
},
427+
)
428+
)
429+
),
430+
raw_auth_credential=AuthCredential(
431+
auth_type=AuthCredentialTypes.OAUTH2,
432+
oauth2=OAuth2Auth(
433+
client_id='oauth_client_id_1',
434+
client_secret='oauth_client_secret1',
435+
),
436+
),
437+
exchanged_auth_credential=AuthCredential(
438+
auth_type=AuthCredentialTypes.OAUTH2,
439+
oauth2=OAuth2Auth(
440+
client_id='oauth_client_id_1',
441+
client_secret='oauth_client_secret1',
442+
access_token='token1',
443+
),
444+
),
445+
)
446+
auth_response2 = AuthConfig(
447+
auth_scheme=OAuth2(
448+
flows=OAuthFlows(
449+
authorizationCode=OAuthFlowAuthorizationCode(
450+
authorizationUrl='https://accounts.google.com/o/oauth2/auth',
451+
tokenUrl='https://oauth2.googleapis.com/token',
452+
scopes={
453+
'https://www.googleapis.com/auth/calendar': (
454+
'See, edit, share, and permanently delete all the'
455+
' calendars you can access using Google Calendar'
456+
)
457+
},
458+
)
459+
)
460+
),
461+
raw_auth_credential=AuthCredential(
462+
auth_type=AuthCredentialTypes.OAUTH2,
463+
oauth2=OAuth2Auth(
464+
client_id='oauth_client_id_2',
465+
client_secret='oauth_client_secret2',
466+
),
467+
),
468+
exchanged_auth_credential=AuthCredential(
469+
auth_type=AuthCredentialTypes.OAUTH2,
470+
oauth2=OAuth2Auth(
471+
client_id='oauth_client_id_2',
472+
client_secret='oauth_client_secret2',
473+
access_token='token2',
474+
),
475+
),
476+
)
477+
478+
def call_external_api1(tool_context: ToolContext) -> int:
479+
nonlocal function_invoked
480+
function_invoked += 1
481+
auth_response = tool_context.get_auth_response(auth_config1)
482+
if not auth_response:
483+
tool_context.request_credential(auth_config1)
484+
return
485+
assert auth_response == auth_response1.exchanged_auth_credential
486+
return 1
487+
488+
def call_external_api2(tool_context: ToolContext) -> int:
489+
nonlocal function_invoked
490+
function_invoked += 1
491+
auth_response = tool_context.get_auth_response(auth_config2)
492+
if not auth_response:
493+
tool_context.request_credential(auth_config2)
494+
return
495+
assert auth_response == auth_response2.exchanged_auth_credential
496+
return 2
497+
498+
agent = Agent(
499+
name='root_agent',
500+
model=mock_model,
501+
tools=[call_external_api1, call_external_api2],
502+
)
503+
runner = testing_utils.InMemoryRunner(agent)
504+
runner.run('test')
505+
request_euc_function_call_event = runner.session.events[-3]
506+
function_response1 = types.FunctionResponse(
507+
name=request_euc_function_call_event.content.parts[0].function_call.name,
508+
response=auth_response1.model_dump(),
509+
)
510+
function_response1.id = request_euc_function_call_event.content.parts[
511+
0
512+
].function_call.id
513+
514+
function_response2 = types.FunctionResponse(
515+
name=request_euc_function_call_event.content.parts[1].function_call.name,
516+
response=auth_response2.model_dump(),
517+
)
518+
function_response2.id = request_euc_function_call_event.content.parts[
519+
1
520+
].function_call.id
521+
runner.run(
522+
new_message=types.Content(
523+
role='user',
524+
parts=[
525+
types.Part(function_response=function_response1),
526+
],
527+
),
528+
)
529+
530+
assert function_invoked == 3
531+
assert len(mock_model.requests) == 3
532+
request = mock_model.requests[-1]
533+
content = request.contents[-1]
534+
parts = content.parts
535+
assert len(parts) == 2
536+
assert parts[0].function_response.name == 'call_external_api1'
537+
assert parts[0].function_response.response == {'result': 1}
538+
assert parts[1].function_response.name == 'call_external_api2'
539+
assert parts[1].function_response.response == {'result': None}
540+
541+
runner.run(
542+
new_message=types.Content(
543+
role='user',
544+
parts=[
545+
types.Part(function_response=function_response2),
546+
],
547+
),
548+
)
549+
# assert function_invoked == 4
550+
assert len(mock_model.requests) == 4
551+
request = mock_model.requests[-1]
552+
content = request.contents[-1]
553+
parts = content.parts
554+
assert len(parts) == 2
555+
assert parts[0].function_response.name == 'call_external_api1'
556+
assert parts[0].function_response.response == {'result': None}
557+
assert parts[1].function_response.name == 'call_external_api2'
558+
assert parts[1].function_response.response == {'result': 2}

0 commit comments

Comments
 (0)