Skip to content

Commit 59670d2

Browse files
XinranTangcopybara-github
authored andcommitted
feat: Support resuming from a paused invocation starting from a sub-agent
PiperOrigin-RevId: 817766247
1 parent bddc70b commit 59670d2

3 files changed

Lines changed: 162 additions & 8 deletions

File tree

src/google/adk/runners.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -381,12 +381,11 @@ async def _run_with_trace(
381381
run_config=run_config,
382382
state_delta=state_delta,
383383
)
384-
if invocation_context.end_of_agents.get(self.agent.name):
385-
# Directly return if the root agent has already ended.
386-
# TODO: Handle the case where the invocation-to-resume started from
387-
# a sub_agent:
388-
# invocation1: root_agent -> sub_agent1
389-
# invocation2: sub_agent1 [paused][resume]
384+
if invocation_context.end_of_agents.get(
385+
invocation_context.agent.name
386+
):
387+
# Directly return if the current agent in invocation context is
388+
# already final.
390389
return
391390
else:
392391
invocation_context = await self._setup_context_for_new_invocation(
@@ -869,6 +868,13 @@ async def _setup_context_for_resumed_invocation(
869868
)
870869
# Step 4: Populate agent states for the current invocation.
871870
invocation_context.populate_invocation_agent_states()
871+
# Step 5: Set agent to run for the invocation.
872+
#
873+
# If the root agent is not found in end_of_agents, it means the invocation
874+
# started from a sub-agent and paused on a sub-agent.
875+
# We should find the appropriate agent to run to continue the invocation.
876+
if self.agent.name not in invocation_context.end_of_agents:
877+
invocation_context.agent = self._find_agent_to_run(session, self.agent)
872878
return invocation_context
873879

874880
def _find_user_message_for_invocation(
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for edge cases of resuming invocations."""
15+
16+
import copy
17+
18+
from google.adk.agents.llm_agent import LlmAgent
19+
from google.adk.apps.app import App
20+
from google.adk.apps.app import ResumabilityConfig
21+
from google.genai.types import Part
22+
import pytest
23+
24+
from .. import testing_utils
25+
26+
27+
def transfer_call_part(agent_name: str) -> Part:
28+
return Part.from_function_call(
29+
name="transfer_to_agent", args={"agent_name": agent_name}
30+
)
31+
32+
33+
TRANSFER_RESPONSE_PART = Part.from_function_response(
34+
name="transfer_to_agent", response={"result": None}
35+
)
36+
37+
38+
@pytest.mark.asyncio
39+
async def test_resume_invocation_from_sub_agent():
40+
"""A test case for an edge case, where an invocation-to-resume starts from a sub-agent.
41+
42+
For example:
43+
invocation1: root_agent -> sub_agent
44+
invocation2: sub_agent [paused][resume]
45+
"""
46+
# Step 1: Setup
47+
# root_agent -> sub_agent
48+
sub_agent = LlmAgent(
49+
name="sub_agent",
50+
model=testing_utils.MockModel.create(
51+
responses=[
52+
"first response from sub_agent",
53+
"second response from sub_agent",
54+
"third response from sub_agent",
55+
]
56+
),
57+
)
58+
root_agent = LlmAgent(
59+
name="root_agent",
60+
model=testing_utils.MockModel.create(
61+
responses=[transfer_call_part(sub_agent.name)]
62+
),
63+
sub_agents=[sub_agent],
64+
)
65+
runner = testing_utils.InMemoryRunner(
66+
app=App(
67+
name="test_app",
68+
root_agent=root_agent,
69+
resumability_config=ResumabilityConfig(is_resumable=True),
70+
)
71+
)
72+
73+
# Step 2: Run the first invocation
74+
# Expect the invocation to start from root_agent and transferred to sub_agent.
75+
invocation_1_events = runner.run("test user query")
76+
assert testing_utils.simplify_resumable_app_events(
77+
copy.deepcopy(invocation_1_events)
78+
) == [
79+
(
80+
root_agent.name,
81+
transfer_call_part(sub_agent.name),
82+
),
83+
(
84+
root_agent.name,
85+
TRANSFER_RESPONSE_PART,
86+
),
87+
(
88+
sub_agent.name,
89+
"first response from sub_agent",
90+
),
91+
(
92+
sub_agent.name,
93+
testing_utils.END_OF_AGENT,
94+
),
95+
(
96+
root_agent.name,
97+
testing_utils.END_OF_AGENT,
98+
),
99+
]
100+
101+
# Step 3: Run the second invocation
102+
# Expect the invocation to directly start from sub_agent.
103+
invocation_2_events = runner.run(
104+
"test user query 2",
105+
)
106+
assert testing_utils.simplify_resumable_app_events(
107+
copy.deepcopy(invocation_2_events)
108+
) == [
109+
(
110+
sub_agent.name,
111+
"second response from sub_agent",
112+
),
113+
(sub_agent.name, testing_utils.END_OF_AGENT),
114+
]
115+
# Asserts the invocation will be a no-op if the current agent in context is
116+
# already final.
117+
assert not await runner.run_async(
118+
invocation_id=invocation_2_events[0].invocation_id
119+
)
120+
121+
# Step 4: Copy all session.events[:-1] to a new session
122+
# This is to simulate the case where we pause on the second invocation.
123+
session_id = runner.session_id
124+
session = await runner.runner.session_service.get_session(
125+
app_name="test_app", user_id="test_user", session_id=session_id
126+
)
127+
new_session = await runner.runner.session_service.create_session(
128+
app_name=session.app_name, user_id=session.user_id
129+
)
130+
for event in session.events[:-1]:
131+
await runner.runner.session_service.append_event(new_session, event)
132+
runner.session_id = new_session.id
133+
134+
# Step 5: Resume the second invocation
135+
resumed_invocation_2_events = await runner.run_async(
136+
invocation_id=invocation_2_events[0].invocation_id
137+
)
138+
assert testing_utils.simplify_resumable_app_events(
139+
copy.deepcopy(resumed_invocation_2_events)
140+
) == [
141+
(
142+
sub_agent.name,
143+
"third response from sub_agent",
144+
),
145+
(sub_agent.name, testing_utils.END_OF_AGENT),
146+
]

tests/unittests/testing_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -274,14 +274,16 @@ def run(self, new_message: types.ContentUnion) -> list[Event]:
274274
)
275275

276276
async def run_async(
277-
self, new_message: types.ContentUnion, invocation_id: Optional[str] = None
277+
self,
278+
new_message: Optional[types.ContentUnion] = None,
279+
invocation_id: Optional[str] = None,
278280
) -> list[Event]:
279281
events = []
280282
async for event in self.runner.run_async(
281283
user_id=self.session.user_id,
282284
session_id=self.session.id,
283285
invocation_id=invocation_id,
284-
new_message=get_user_content(new_message),
286+
new_message=get_user_content(new_message) if new_message else None,
285287
):
286288
events.append(event)
287289
return events

0 commit comments

Comments
 (0)