Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from ..tools.structured_output._structured_output_context import StructuredOutputContext
from ..tools.watcher import ToolWatcher
from ..types._events import AgentResultEvent, EventLoopStopEvent, InitEventLoopEvent, ModelStreamChunkEvent, TypedEvent
from ..types.agent import AgentInput
from ..types.agent import AgentInput, ConcurrentInvocationMode
from ..types.content import ContentBlock, Message, Messages, SystemContentBlock
from ..types.exceptions import ConcurrencyException, ContextWindowOverflowException
from ..types.traces import AttributeValue
Expand Down Expand Up @@ -129,6 +129,7 @@ def __init__(
structured_output_prompt: str | None = None,
tool_executor: ToolExecutor | None = None,
retry_strategy: ModelRetryStrategy | _DefaultRetryStrategySentinel | None = _DEFAULT_RETRY_STRATEGY,
concurrent_invocation_mode: ConcurrentInvocationMode = "throw",
):
"""Initialize the Agent with the specified configuration.

Expand Down Expand Up @@ -186,6 +187,10 @@ def __init__(
retry_strategy: Strategy for retrying model calls on throttling or other transient errors.
Defaults to ModelRetryStrategy with max_attempts=6, initial_delay=4s, max_delay=240s.
Implement a custom HookProvider for custom retry logic, or pass None to disable retries.
concurrent_invocation_mode: Mode controlling concurrent invocation behavior.
Defaults to "throw" which raises ConcurrencyException if concurrent invocation is attempted.
Set to "unsafe_reentrant" to skip lock acquisition entirely, allowing concurrent invocations
(restores pre-locking behavior, use with caution).

Raises:
ValueError: If agent id contains path separators.
Expand Down Expand Up @@ -263,6 +268,7 @@ def __init__(
# Using threading.Lock instead of asyncio.Lock because run_async() creates
# separate event loops in different threads, so asyncio.Lock wouldn't work
self._invocation_lock = threading.Lock()
self._concurrent_invocation_mode = concurrent_invocation_mode
Comment thread
mehtarac marked this conversation as resolved.

# In the future, we'll have a RetryStrategy base class but until
# that API is determined we only allow ModelRetryStrategy
Expand Down Expand Up @@ -622,14 +628,16 @@ async def stream_async(
yield event["data"]
```
"""
# Acquire lock to prevent concurrent invocations
# Conditionally acquire lock based on concurrent_invocation_mode
# Using threading.Lock instead of asyncio.Lock because run_async() creates
# separate event loops in different threads
acquired = self._invocation_lock.acquire(blocking=False)
if not acquired:
raise ConcurrencyException(
"Agent is already processing a request. Concurrent invocations are not supported."
)
lock_acquired = False
if self._concurrent_invocation_mode == "throw":
Comment thread
zastrowm marked this conversation as resolved.
Outdated
lock_acquired = self._invocation_lock.acquire(blocking=False)
if not lock_acquired:
raise ConcurrencyException(
"Agent is already processing a request. Concurrent invocations are not supported."
)

try:
self._interrupt_state.resume(prompt)
Expand Down Expand Up @@ -678,7 +686,8 @@ async def stream_async(
raise

finally:
self._invocation_lock.release()
if lock_acquired:
self._invocation_lock.release()
Comment thread
zastrowm marked this conversation as resolved.

async def _run_loop(
self,
Expand Down
10 changes: 9 additions & 1 deletion src/strands/types/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,17 @@
This module defines the types used for an Agent.
"""

from typing import TypeAlias
from typing import Literal, TypeAlias

from .content import ContentBlock, Messages
from .interrupt import InterruptResponseContent

AgentInput: TypeAlias = str | list[ContentBlock] | list[InterruptResponseContent] | Messages | None

ConcurrentInvocationMode = Literal["throw", "unsafe_reentrant"]
"""Mode controlling concurrent invocation behavior.

Values:
throw: Raises ConcurrencyException if concurrent invocation is attempted (default).
unsafe_reentrant: Allows concurrent invocations without locking (unsafe, restores pre-lock behavior).
Comment thread
zastrowm marked this conversation as resolved.
Outdated
"""
66 changes: 64 additions & 2 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2231,7 +2231,7 @@ def test_agent_concurrent_call_raises_exception():
{"role": "assistant", "content": [{"text": "world"}]},
]
)
agent = Agent(model=model)
agent = Agent(model=model, concurrent_invocation_mode="throw")

results = []
errors = []
Expand Down Expand Up @@ -2282,7 +2282,7 @@ def test_agent_concurrent_structured_output_raises_exception():
{"role": "assistant", "content": [{"text": "response2"}]},
],
)
agent = Agent(model=model)
agent = Agent(model=model, concurrent_invocation_mode="throw")

results = []
errors = []
Expand Down Expand Up @@ -2320,6 +2320,68 @@ def invoke():
assert "concurrent" in str(errors[0]).lower() and "invocation" in str(errors[0]).lower()


def test_agent_concurrent_call_succeeds_with_unsafe_reentrant_mode():
    """Test that concurrent __call__() calls succeed when concurrent_invocation_mode is 'unsafe_reentrant'.

    Two threads invoke the same agent; the second starts while the first is still
    inside the model stream. With "unsafe_reentrant" neither call may raise
    ConcurrencyException and both must produce a result.

    Every blocking call is bounded by a timeout so that a regression fails this
    test instead of hanging the whole test run.
    """
    model = SyncEventMockedModel(
        [
            {"role": "assistant", "content": [{"text": "hello"}]},
            {"role": "assistant", "content": [{"text": "world"}]},
        ]
    )
    agent = Agent(model=model, concurrent_invocation_mode="unsafe_reentrant")

    # Generous upper bound; the happy path completes in milliseconds.
    timeout_seconds = 10

    results = []
    errors = []
    lock = threading.Lock()

    def invoke():
        try:
            result = agent("test")
        except ConcurrencyException as e:
            with lock:
                errors.append(e)
        else:
            with lock:
                results.append(result)

    # Daemon threads so a worker that is stuck cannot keep the interpreter alive.
    t1 = threading.Thread(target=invoke, daemon=True)
    t2 = threading.Thread(target=invoke, daemon=True)

    # Start first thread and wait for it to begin streaming
    t1.start()
    try:
        # NOTE(review): assumes started_event is a threading.Event, whose wait() returns
        # True once set and False on timeout - confirm against SyncEventMockedModel.
        started = model.started_event.wait(timeout=timeout_seconds)
        assert started, "First invocation never reached model.stream()"

        # Start second thread while first is still running
        t2.start()
    finally:
        # Always let the workers proceed, even if the assertion above failed,
        # so no thread stays blocked on the model.
        model.proceed_event.set()

    t1.join(timeout=timeout_seconds)
    t2.join(timeout=timeout_seconds)
    assert not t1.is_alive(), "First invocation did not finish in time"
    assert not t2.is_alive(), "Second invocation did not finish in time"

    # Both should succeed, no ConcurrencyException raised
    assert len(errors) == 0, f"Expected 0 errors, got {len(errors)}: {errors}"
    assert len(results) == 2, f"Expected 2 successes, got {len(results)}"


def test_agent_concurrent_invocation_mode_default_is_throw():
    """An Agent created without an explicit concurrent_invocation_mode uses 'throw'."""
    canned_responses = [{"role": "assistant", "content": [{"text": "hello"}]}]
    default_agent = Agent(model=MockedModelProvider(canned_responses))

    # No mode was passed to the constructor, so the stored value is the default.
    assert default_agent._concurrent_invocation_mode == "throw"


def test_agent_concurrent_invocation_mode_stores_value():
    """Each supported concurrent_invocation_mode is kept verbatim on the agent instance."""
    shared_model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello"}]}])

    # Same order as the explicit checks: "throw" first, then "unsafe_reentrant".
    for requested_mode in ("throw", "unsafe_reentrant"):
        configured_agent = Agent(model=shared_model, concurrent_invocation_mode=requested_mode)
        assert configured_agent._concurrent_invocation_mode == requested_mode


@pytest.mark.asyncio
async def test_agent_sequential_invocations_work():
"""Test that sequential invocations work correctly after lock is released."""
Expand Down