Skip to content

Commit c5c6744

Browse files
zastrowmstrands-agent
authored andcommitted
feat(agent): add concurrent_invocation_mode parameter (strands-agents#1707)
Co-authored-by: Strands Agent <217235299+strands-agent@users.noreply.github.com> Co-authored-by: Mackenzie Zastrow <zastrowm@users.noreply.github.com>
1 parent 23e5a18 commit c5c6744

3 files changed

Lines changed: 116 additions & 15 deletions

File tree

src/strands/agent/agent.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
from ..tools.structured_output._structured_output_context import StructuredOutputContext
5555
from ..tools.watcher import ToolWatcher
5656
from ..types._events import AgentResultEvent, EventLoopStopEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent
57-
from ..types.agent import AgentInput
57+
from ..types.agent import AgentInput, ConcurrentInvocationMode
5858
from ..types.content import ContentBlock, Message, Messages, SystemContentBlock
5959
from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException
6060
from ..types.traces import AttributeValue
@@ -129,6 +129,7 @@ def __init__(
129129
structured_output_prompt: str | None = None,
130130
tool_executor: ToolExecutor | None = None,
131131
retry_strategy: ModelRetryStrategy | _DefaultRetryStrategySentinel | None = _DEFAULT_RETRY_STRATEGY,
132+
concurrent_invocation_mode: ConcurrentInvocationMode = ConcurrentInvocationMode.THROW,
132133
):
133134
"""Initialize the Agent with the specified configuration.
134135
@@ -186,6 +187,11 @@ def __init__(
186187
retry_strategy: Strategy for retrying model calls on throttling or other transient errors.
187188
Defaults to ModelRetryStrategy with max_attempts=6, initial_delay=4s, max_delay=240s.
188189
Implement a custom HookProvider for custom retry logic, or pass None to disable retries.
190+
concurrent_invocation_mode: Mode controlling concurrent invocation behavior.
191+
Defaults to "throw" which raises ConcurrencyException if concurrent invocation is attempted.
192+
Set to "unsafe_reentrant" to skip lock acquisition entirely, allowing concurrent invocations.
193+
Warning: "unsafe_reentrant" makes no guarantees about resulting behavior and is provided
194+
only for advanced use cases where the caller understands the risks.
189195
190196
Raises:
191197
ValueError: If agent id contains path separators.
@@ -263,6 +269,7 @@ def __init__(
263269
# Using threading.Lock instead of asyncio.Lock because run_async() creates
264270
# separate event loops in different threads, so asyncio.Lock wouldn't work
265271
self._invocation_lock = threading.Lock()
272+
self._concurrent_invocation_mode = concurrent_invocation_mode
266273

267274
# In the future, we'll have a RetryStrategy base class but until
268275
# that API is determined we only allow ModelRetryStrategy
@@ -622,14 +629,15 @@ async def stream_async(
622629
yield event["data"]
623630
```
624631
"""
625-
# Acquire lock to prevent concurrent invocations
632+
# Conditionally acquire lock based on concurrent_invocation_mode
626633
# Using threading.Lock instead of asyncio.Lock because run_async() creates
627634
# separate event loops in different threads
628-
acquired = self._invocation_lock.acquire(blocking=False)
629-
if not acquired:
630-
raise ConcurrencyException(
631-
"Agent is already processing a request. Concurrent invocations are not supported."
632-
)
635+
if self._concurrent_invocation_mode == ConcurrentInvocationMode.THROW:
636+
lock_acquired = self._invocation_lock.acquire(blocking=False)
637+
if not lock_acquired:
638+
raise ConcurrencyException(
639+
"Agent is already processing a request. Concurrent invocations are not supported."
640+
)
633641

634642
try:
635643
self._interrupt_state.resume(prompt)
@@ -678,7 +686,8 @@ async def stream_async(
678686
raise
679687

680688
finally:
681-
self._invocation_lock.release()
689+
if self._invocation_lock.locked():
690+
self._invocation_lock.release()
682691

683692
async def _run_loop(
684693
self,

src/strands/types/agent.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,26 @@
33
This module defines the types used for an Agent.
44
"""
55

6+
from enum import Enum
67
from typing import TypeAlias
78

89
from .content import ContentBlock, Messages
910
from .interrupt import InterruptResponseContent
1011

1112
AgentInput: TypeAlias = str | list[ContentBlock] | list[InterruptResponseContent] | Messages | None
13+
14+
15+
class ConcurrentInvocationMode(str, Enum):
16+
"""Mode controlling concurrent invocation behavior.
17+
18+
Values:
19+
THROW: Raises ConcurrencyException if concurrent invocation is attempted (default).
20+
UNSAFE_REENTRANT: Allows concurrent invocations without locking.
21+
22+
Warning:
23+
The ``UNSAFE_REENTRANT`` mode makes no guarantees about resulting behavior and is
24+
provided only for advanced use cases where the caller understands the risks.
25+
"""
26+
27+
THROW = "throw"
28+
UNSAFE_REENTRANT = "unsafe_reentrant"

tests/strands/agent/test_agent.py

Lines changed: 82 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from strands.session.repository_session_manager import RepositorySessionManager
2727
from strands.telemetry.tracer import serialize
2828
from strands.types._events import EventLoopStopEvent, ModelStreamEvent
29+
from strands.types.agent import ConcurrentInvocationMode
2930
from strands.types.content import Messages
3031
from strands.types.exceptions import ConcurrencyException, ContextWindowOverflowException, EventLoopException
3132
from strands.types.session import Session, SessionAgent, SessionMessage, SessionType
@@ -2231,20 +2232,17 @@ def test_agent_concurrent_call_raises_exception():
22312232
{"role": "assistant", "content": [{"text": "world"}]},
22322233
]
22332234
)
2234-
agent = Agent(model=model)
2235+
agent = Agent(model=model, concurrent_invocation_mode="throw")
22352236

22362237
results = []
22372238
errors = []
2238-
lock = threading.Lock()
22392239

22402240
def invoke():
22412241
try:
22422242
result = agent("test")
2243-
with lock:
2244-
results.append(result)
2243+
results.append(result)
22452244
except ConcurrencyException as e:
2246-
with lock:
2247-
errors.append(e)
2245+
errors.append(e)
22482246

22492247
# Start first thread and wait for it to begin streaming
22502248
t1 = threading.Thread(target=invoke)
@@ -2282,7 +2280,7 @@ def test_agent_concurrent_structured_output_raises_exception():
22822280
{"role": "assistant", "content": [{"text": "response2"}]},
22832281
],
22842282
)
2285-
agent = Agent(model=model)
2283+
agent = Agent(model=model, concurrent_invocation_mode="throw")
22862284

22872285
results = []
22882286
errors = []
@@ -2320,6 +2318,83 @@ def invoke():
23202318
assert "concurrent" in str(errors[0]).lower() and "invocation" in str(errors[0]).lower()
23212319

23222320

2321+
def test_agent_concurrent_call_succeeds_with_unsafe_reentrant_mode():
2322+
"""Test that concurrent __call__() calls succeed when concurrent_invocation_mode is 'unsafe_reentrant'."""
2323+
model = SyncEventMockedModel(
2324+
[
2325+
{"role": "assistant", "content": [{"text": "hello"}]},
2326+
{"role": "assistant", "content": [{"text": "world"}]},
2327+
]
2328+
)
2329+
agent = Agent(model=model, concurrent_invocation_mode="unsafe_reentrant")
2330+
2331+
results = []
2332+
errors = []
2333+
lock = threading.Lock()
2334+
2335+
def invoke():
2336+
try:
2337+
result = agent("test")
2338+
with lock:
2339+
results.append(result)
2340+
except ConcurrencyException as e:
2341+
with lock:
2342+
errors.append(e)
2343+
2344+
# Start first thread and wait for it to begin streaming
2345+
t1 = threading.Thread(target=invoke)
2346+
t1.start()
2347+
model.started_event.wait() # Wait until first thread is in the model.stream()
2348+
2349+
# Start second thread while first is still running
2350+
t2 = threading.Thread(target=invoke)
2351+
t2.start()
2352+
2353+
# Let both threads proceed
2354+
model.proceed_event.set()
2355+
t1.join()
2356+
t2.join()
2357+
2358+
# Both should succeed, no ConcurrencyException raised
2359+
assert len(errors) == 0, f"Expected 0 errors, got {len(errors)}: {errors}"
2360+
assert len(results) == 2, f"Expected 2 successes, got {len(results)}"
2361+
2362+
2363+
def test_agent_concurrent_invocation_mode_default_is_throw():
2364+
"""Test that the default concurrent_invocation_mode is 'throw'."""
2365+
model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello"}]}])
2366+
agent = Agent(model=model)
2367+
2368+
# Verify the default mode
2369+
assert agent._concurrent_invocation_mode == "throw"
2370+
2371+
2372+
def test_agent_concurrent_invocation_mode_stores_value():
2373+
"""Test that concurrent_invocation_mode is stored correctly as instance variable."""
2374+
model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello"}]}])
2375+
2376+
agent_throw = Agent(model=model, concurrent_invocation_mode="throw")
2377+
assert agent_throw._concurrent_invocation_mode == "throw"
2378+
2379+
agent_reentrant = Agent(model=model, concurrent_invocation_mode="unsafe_reentrant")
2380+
assert agent_reentrant._concurrent_invocation_mode == "unsafe_reentrant"
2381+
2382+
2383+
def test_agent_concurrent_invocation_mode_accepts_enum():
2384+
"""Test that concurrent_invocation_mode accepts enum values as well as strings."""
2385+
2386+
model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello"}]}])
2387+
2388+
# Using enum values
2389+
agent_throw = Agent(model=model, concurrent_invocation_mode=ConcurrentInvocationMode.THROW)
2390+
assert agent_throw._concurrent_invocation_mode == "throw"
2391+
assert agent_throw._concurrent_invocation_mode == ConcurrentInvocationMode.THROW
2392+
2393+
agent_reentrant = Agent(model=model, concurrent_invocation_mode=ConcurrentInvocationMode.UNSAFE_REENTRANT)
2394+
assert agent_reentrant._concurrent_invocation_mode == "unsafe_reentrant"
2395+
assert agent_reentrant._concurrent_invocation_mode == ConcurrentInvocationMode.UNSAFE_REENTRANT
2396+
2397+
23232398
@pytest.mark.asyncio
23242399
async def test_agent_sequential_invocations_work():
23252400
"""Test that sequential invocations work correctly after lock is released."""

0 commit comments

Comments
 (0)