2020from google .adk .models .llm_response import LlmResponse
2121from google .adk .plugins .base_plugin import BasePlugin
2222from google .genai import types
23+ from google .genai .errors import ClientError
2324import pytest
2425
2526from ... import testing_utils
2627
28+ mock_error = ClientError (
29+ code = 429 ,
30+ response_json = {
31+ 'error' : {
32+ 'code' : 429 ,
33+ 'message' : 'Quota exceeded.' ,
34+ 'status' : 'RESOURCE_EXHAUSTED' ,
35+ }
36+ },
37+ )
38+
2739
2840class MockPlugin (BasePlugin ):
2941 before_model_text = 'before_model_text from MockPlugin'
3042 after_model_text = 'after_model_text from MockPlugin'
43+ on_model_error_text = 'on_model_error_text from MockPlugin'
3144
3245 def __init__ (self , name = 'mock_plugin' ):
3346 self .name = name
3447 self .enable_before_model_callback = False
3548 self .enable_after_model_callback = False
49+ self .enable_on_model_error_callback = False
3650 self .before_model_response = LlmResponse (
3751 content = testing_utils .ModelContent (
3852 [types .Part .from_text (text = self .before_model_text )]
@@ -43,6 +57,11 @@ def __init__(self, name='mock_plugin'):
4357 [types .Part .from_text (text = self .after_model_text )]
4458 )
4559 )
60+ self .on_model_error_response = LlmResponse (
61+ content = testing_utils .ModelContent (
62+ [types .Part .from_text (text = self .on_model_error_text )]
63+ )
64+ )
4665
4766 async def before_model_callback (
4867 self , * , callback_context : CallbackContext , llm_request : LlmRequest
@@ -58,6 +77,17 @@ async def after_model_callback(
5877 return None
5978 return self .after_model_response
6079
80+ async def on_model_error_callback (
81+ self ,
82+ * ,
83+ callback_context : CallbackContext ,
84+ llm_request : LlmRequest ,
85+ error : Exception ,
86+ ) -> Optional [LlmResponse ]:
87+ if not self .enable_on_model_error_callback :
88+ return None
89+ return self .on_model_error_response
90+
6191
6292CANONICAL_MODEL_CALLBACK_CONTENT = 'canonical_model_callback_content'
6393
@@ -124,5 +154,36 @@ def test_before_model_callback_fallback_model(mock_plugin):
124154 ]
125155
126156
157+ def test_on_model_error_callback_with_plugin (mock_plugin ):
158+ """Tests that the model error is handled by the plugin."""
159+ mock_model = testing_utils .MockModel .create (error = mock_error , responses = [])
160+ mock_plugin .enable_on_model_error_callback = True
161+ agent = Agent (
162+ name = 'root_agent' ,
163+ model = mock_model ,
164+ )
165+
166+ runner = testing_utils .InMemoryRunner (agent , plugins = [mock_plugin ])
167+
168+ assert testing_utils .simplify_events (runner .run ('test' )) == [
169+ ('root_agent' , mock_plugin .on_model_error_text ),
170+ ]
171+
172+
173+ def test_on_model_error_callback_fallback_to_runner (mock_plugin ):
174+ """Tests that the model error is not handled and falls back to raise from runner."""
175+ mock_model = testing_utils .MockModel .create (error = mock_error , responses = [])
176+ mock_plugin .enable_on_model_error_callback = False
177+ agent = Agent (
178+ name = 'root_agent' ,
179+ model = mock_model ,
180+ )
181+
182+ try :
183+ testing_utils .InMemoryRunner (agent , plugins = [mock_plugin ])
184+ except Exception as e :
185+ assert e == mock_error
186+
187+
127188if __name__ == '__main__' :
128189 pytest .main ([__file__ ])
0 commit comments