@@ -344,3 +344,215 @@ def call_external_api2(tool_context: ToolContext) -> int:
344344 assert parts [0 ].function_response .response == {'result' : 1 }
345345 assert parts [1 ].function_response .name == 'call_external_api2'
346346 assert parts [1 ].function_response .response == {'result' : 2 }
347+
348+
349+ def test_function_get_auth_response_partial ():
350+ id_1 = 'id_1'
351+ id_2 = 'id_2'
352+ responses = [
353+ [
354+ function_call (id_1 , 'call_external_api1' , {}),
355+ function_call (id_2 , 'call_external_api2' , {}),
356+ ],
357+ [
358+ types .Part .from_text (text = 'response1' ),
359+ ],
360+ [
361+ types .Part .from_text (text = 'response2' ),
362+ ],
363+ ]
364+
365+ mock_model = testing_utils .MockModel .create (responses = responses )
366+ function_invoked = 0
367+
368+ auth_config1 = AuthConfig (
369+ auth_scheme = OAuth2 (
370+ flows = OAuthFlows (
371+ authorizationCode = OAuthFlowAuthorizationCode (
372+ authorizationUrl = 'https://accounts.google.com/o/oauth2/auth' ,
373+ tokenUrl = 'https://oauth2.googleapis.com/token' ,
374+ scopes = {
375+ 'https://www.googleapis.com/auth/calendar' : (
376+ 'See, edit, share, and permanently delete all the'
377+ ' calendars you can access using Google Calendar'
378+ )
379+ },
380+ )
381+ )
382+ ),
383+ raw_auth_credential = AuthCredential (
384+ auth_type = AuthCredentialTypes .OAUTH2 ,
385+ oauth2 = OAuth2Auth (
386+ client_id = 'oauth_client_id_1' ,
387+ client_secret = 'oauth_client_secret1' ,
388+ ),
389+ ),
390+ )
391+ auth_config2 = AuthConfig (
392+ auth_scheme = OAuth2 (
393+ flows = OAuthFlows (
394+ authorizationCode = OAuthFlowAuthorizationCode (
395+ authorizationUrl = 'https://accounts.google.com/o/oauth2/auth' ,
396+ tokenUrl = 'https://oauth2.googleapis.com/token' ,
397+ scopes = {
398+ 'https://www.googleapis.com/auth/calendar' : (
399+ 'See, edit, share, and permanently delete all the'
400+ ' calendars you can access using Google Calendar'
401+ )
402+ },
403+ )
404+ )
405+ ),
406+ raw_auth_credential = AuthCredential (
407+ auth_type = AuthCredentialTypes .OAUTH2 ,
408+ oauth2 = OAuth2Auth (
409+ client_id = 'oauth_client_id_2' ,
410+ client_secret = 'oauth_client_secret2' ,
411+ ),
412+ ),
413+ )
414+
415+ auth_response1 = AuthConfig (
416+ auth_scheme = OAuth2 (
417+ flows = OAuthFlows (
418+ authorizationCode = OAuthFlowAuthorizationCode (
419+ authorizationUrl = 'https://accounts.google.com/o/oauth2/auth' ,
420+ tokenUrl = 'https://oauth2.googleapis.com/token' ,
421+ scopes = {
422+ 'https://www.googleapis.com/auth/calendar' : (
423+ 'See, edit, share, and permanently delete all the'
424+ ' calendars you can access using Google Calendar'
425+ )
426+ },
427+ )
428+ )
429+ ),
430+ raw_auth_credential = AuthCredential (
431+ auth_type = AuthCredentialTypes .OAUTH2 ,
432+ oauth2 = OAuth2Auth (
433+ client_id = 'oauth_client_id_1' ,
434+ client_secret = 'oauth_client_secret1' ,
435+ ),
436+ ),
437+ exchanged_auth_credential = AuthCredential (
438+ auth_type = AuthCredentialTypes .OAUTH2 ,
439+ oauth2 = OAuth2Auth (
440+ client_id = 'oauth_client_id_1' ,
441+ client_secret = 'oauth_client_secret1' ,
442+ access_token = 'token1' ,
443+ ),
444+ ),
445+ )
446+ auth_response2 = AuthConfig (
447+ auth_scheme = OAuth2 (
448+ flows = OAuthFlows (
449+ authorizationCode = OAuthFlowAuthorizationCode (
450+ authorizationUrl = 'https://accounts.google.com/o/oauth2/auth' ,
451+ tokenUrl = 'https://oauth2.googleapis.com/token' ,
452+ scopes = {
453+ 'https://www.googleapis.com/auth/calendar' : (
454+ 'See, edit, share, and permanently delete all the'
455+ ' calendars you can access using Google Calendar'
456+ )
457+ },
458+ )
459+ )
460+ ),
461+ raw_auth_credential = AuthCredential (
462+ auth_type = AuthCredentialTypes .OAUTH2 ,
463+ oauth2 = OAuth2Auth (
464+ client_id = 'oauth_client_id_2' ,
465+ client_secret = 'oauth_client_secret2' ,
466+ ),
467+ ),
468+ exchanged_auth_credential = AuthCredential (
469+ auth_type = AuthCredentialTypes .OAUTH2 ,
470+ oauth2 = OAuth2Auth (
471+ client_id = 'oauth_client_id_2' ,
472+ client_secret = 'oauth_client_secret2' ,
473+ access_token = 'token2' ,
474+ ),
475+ ),
476+ )
477+
478+ def call_external_api1 (tool_context : ToolContext ) -> int :
479+ nonlocal function_invoked
480+ function_invoked += 1
481+ auth_response = tool_context .get_auth_response (auth_config1 )
482+ if not auth_response :
483+ tool_context .request_credential (auth_config1 )
484+ return
485+ assert auth_response == auth_response1 .exchanged_auth_credential
486+ return 1
487+
488+ def call_external_api2 (tool_context : ToolContext ) -> int :
489+ nonlocal function_invoked
490+ function_invoked += 1
491+ auth_response = tool_context .get_auth_response (auth_config2 )
492+ if not auth_response :
493+ tool_context .request_credential (auth_config2 )
494+ return
495+ assert auth_response == auth_response2 .exchanged_auth_credential
496+ return 2
497+
498+ agent = Agent (
499+ name = 'root_agent' ,
500+ model = mock_model ,
501+ tools = [call_external_api1 , call_external_api2 ],
502+ )
503+ runner = testing_utils .InMemoryRunner (agent )
504+ runner .run ('test' )
505+ request_euc_function_call_event = runner .session .events [- 3 ]
506+ function_response1 = types .FunctionResponse (
507+ name = request_euc_function_call_event .content .parts [0 ].function_call .name ,
508+ response = auth_response1 .model_dump (),
509+ )
510+ function_response1 .id = request_euc_function_call_event .content .parts [
511+ 0
512+ ].function_call .id
513+
514+ function_response2 = types .FunctionResponse (
515+ name = request_euc_function_call_event .content .parts [1 ].function_call .name ,
516+ response = auth_response2 .model_dump (),
517+ )
518+ function_response2 .id = request_euc_function_call_event .content .parts [
519+ 1
520+ ].function_call .id
521+ runner .run (
522+ new_message = types .Content (
523+ role = 'user' ,
524+ parts = [
525+ types .Part (function_response = function_response1 ),
526+ ],
527+ ),
528+ )
529+
530+ assert function_invoked == 3
531+ assert len (mock_model .requests ) == 3
532+ request = mock_model .requests [- 1 ]
533+ content = request .contents [- 1 ]
534+ parts = content .parts
535+ assert len (parts ) == 2
536+ assert parts [0 ].function_response .name == 'call_external_api1'
537+ assert parts [0 ].function_response .response == {'result' : 1 }
538+ assert parts [1 ].function_response .name == 'call_external_api2'
539+ assert parts [1 ].function_response .response == {'result' : None }
540+
541+ runner .run (
542+ new_message = types .Content (
543+ role = 'user' ,
544+ parts = [
545+ types .Part (function_response = function_response2 ),
546+ ],
547+ ),
548+ )
549+ # assert function_invoked == 4
550+ assert len (mock_model .requests ) == 4
551+ request = mock_model .requests [- 1 ]
552+ content = request .contents [- 1 ]
553+ parts = content .parts
554+ assert len (parts ) == 2
555+ assert parts [0 ].function_response .name == 'call_external_api1'
556+ assert parts [0 ].function_response .response == {'result' : None }
557+ assert parts [1 ].function_response .name == 'call_external_api2'
558+ assert parts [1 ].function_response .response == {'result' : 2 }
0 commit comments