@@ -227,11 +227,18 @@ async def run_live(
227227 """
228228 with tracer .start_as_current_span (f'agent_run [{ self .name } ]' ):
229229 ctx = self ._create_invocation_context (parent_context )
230- # TODO(hangfei): support before/after_agent_callback
230+
231+ if event := await self .__handle_before_agent_callback (ctx ):
232+ yield event
233+ if ctx .end_invocation :
234+ return
231235
232236 async for event in self ._run_live_impl (ctx ):
233237 yield event
234238
239+ if event := await self .__handle_after_agent_callback (ctx ):
240+ yield event
241+
235242 async def _run_async_impl (
236243 self , ctx : InvocationContext
237244 ) -> AsyncGenerator [Event , None ]:
@@ -335,82 +342,117 @@ async def __handle_before_agent_callback(
335342 ) -> Optional [Event ]:
336343 """Runs the before_agent_callback if it exists.
337344
345+ Args:
346+ ctx: InvocationContext, the invocation context for this agent.
347+
338348 Returns:
339349 Optional[Event]: an event if callback provides content or changed state.
340350 """
341- ret_event = None
342-
343- if not self .canonical_before_agent_callbacks :
344- return ret_event
345-
346351 callback_context = CallbackContext (ctx )
347352
348- for callback in self .canonical_before_agent_callbacks :
349- before_agent_callback_content = callback (
350- callback_context = callback_context
351- )
352- if inspect .isawaitable (before_agent_callback_content ):
353- before_agent_callback_content = await before_agent_callback_content
354- if before_agent_callback_content :
355- ret_event = Event (
356- invocation_id = ctx .invocation_id ,
357- author = self .name ,
358- branch = ctx .branch ,
359- content = before_agent_callback_content ,
360- actions = callback_context ._event_actions ,
353+ # Run callbacks from the plugins.
354+ before_agent_callback_content = (
355+ await ctx .plugin_manager .run_before_agent_callback (
356+ agent = self , callback_context = callback_context
361357 )
362- ctx .end_invocation = True
363- return ret_event
358+ )
364359
365- if callback_context .state .has_delta ():
360+ # If no overrides are provided from the plugins, further run the canonical
361+ # callbacks.
362+ if (
363+ not before_agent_callback_content
364+ and self .canonical_before_agent_callbacks
365+ ):
366+ for callback in self .canonical_before_agent_callbacks :
367+ before_agent_callback_content = callback (
368+ callback_context = callback_context
369+ )
370+ if inspect .isawaitable (before_agent_callback_content ):
371+ before_agent_callback_content = await before_agent_callback_content
372+ if before_agent_callback_content :
373+ break
374+
375+ # Process the override content if exists, and further process the state
376+ # change if exists.
377+ if before_agent_callback_content :
366378 ret_event = Event (
379+ invocation_id = ctx .invocation_id ,
380+ author = self .name ,
381+ branch = ctx .branch ,
382+ content = before_agent_callback_content ,
383+ actions = callback_context ._event_actions ,
384+ )
385+ ctx .end_invocation = True
386+ return ret_event
387+
388+ if callback_context .state .has_delta ():
389+ return Event (
367390 invocation_id = ctx .invocation_id ,
368391 author = self .name ,
369392 branch = ctx .branch ,
370393 actions = callback_context ._event_actions ,
371394 )
372395
373- return ret_event
396+ return None
374397
375398 async def __handle_after_agent_callback (
376399 self , invocation_context : InvocationContext
377400 ) -> Optional [Event ]:
378401 """Runs the after_agent_callback if it exists.
379402
403+ Args:
404+ invocation_context: InvocationContext, the invocation context for this
405+ agent.
406+
380407 Returns:
381408 Optional[Event]: an event if callback provides content or changed state.
382409 """
383- ret_event = None
384-
385- if not self .canonical_after_agent_callbacks :
386- return ret_event
387410
388411 callback_context = CallbackContext (invocation_context )
389412
390- for callback in self .canonical_after_agent_callbacks :
391- after_agent_callback_content = callback (callback_context = callback_context )
392- if inspect .isawaitable (after_agent_callback_content ):
393- after_agent_callback_content = await after_agent_callback_content
394- if after_agent_callback_content :
395- ret_event = Event (
396- invocation_id = invocation_context .invocation_id ,
397- author = self .name ,
398- branch = invocation_context .branch ,
399- content = after_agent_callback_content ,
400- actions = callback_context ._event_actions ,
413+ # Run callbacks from the plugins.
414+ after_agent_callback_content = (
415+ await invocation_context .plugin_manager .run_after_agent_callback (
416+ agent = self , callback_context = callback_context
401417 )
402- return ret_event
418+ )
403419
404- if callback_context .state .has_delta ():
420+ # If no overrides are provided from the plugins, further run the canonical
421+ # callbacks.
422+ if (
423+ not after_agent_callback_content
424+ and self .canonical_after_agent_callbacks
425+ ):
426+ for callback in self .canonical_after_agent_callbacks :
427+ after_agent_callback_content = callback (
428+ callback_context = callback_context
429+ )
430+ if inspect .isawaitable (after_agent_callback_content ):
431+ after_agent_callback_content = await after_agent_callback_content
432+ if after_agent_callback_content :
433+ break
434+
435+ # Process the override content if exists, and further process the state
436+ # change if exists.
437+ if after_agent_callback_content :
405438 ret_event = Event (
406439 invocation_id = invocation_context .invocation_id ,
407440 author = self .name ,
408441 branch = invocation_context .branch ,
409442 content = after_agent_callback_content ,
410443 actions = callback_context ._event_actions ,
411444 )
445+ return ret_event
412446
413- return ret_event
447+ if callback_context .state .has_delta ():
448+ return Event (
449+ invocation_id = invocation_context .invocation_id ,
450+ author = self .name ,
451+ branch = invocation_context .branch ,
452+ content = after_agent_callback_content ,
453+ actions = callback_context ._event_actions ,
454+ )
455+ return None
414456
415457 @override
416458 def model_post_init (self , __context : Any ) -> None :
0 commit comments