|
26 | 26 | from strands.session.repository_session_manager import RepositorySessionManager |
27 | 27 | from strands.telemetry.tracer import serialize |
28 | 28 | from strands.types._events import EventLoopStopEvent, ModelStreamEvent |
| 29 | +from strands.types.agent import ConcurrentInvocationMode |
29 | 30 | from strands.types.content import Messages |
30 | 31 | from strands.types.exceptions import ConcurrencyException, ContextWindowOverflowException, EventLoopException |
31 | 32 | from strands.types.session import Session, SessionAgent, SessionMessage, SessionType |
@@ -2231,20 +2232,17 @@ def test_agent_concurrent_call_raises_exception(): |
2231 | 2232 | {"role": "assistant", "content": [{"text": "world"}]}, |
2232 | 2233 | ] |
2233 | 2234 | ) |
2234 | | - agent = Agent(model=model) |
| 2235 | + agent = Agent(model=model, concurrent_invocation_mode="throw") |
2235 | 2236 |
|
2236 | 2237 | results = [] |
2237 | 2238 | errors = [] |
2238 | | - lock = threading.Lock() |
2239 | 2239 |
|
2240 | 2240 | def invoke(): |
2241 | 2241 | try: |
2242 | 2242 | result = agent("test") |
2243 | | - with lock: |
2244 | | - results.append(result) |
| 2243 | + results.append(result) |
2245 | 2244 | except ConcurrencyException as e: |
2246 | | - with lock: |
2247 | | - errors.append(e) |
| 2245 | + errors.append(e) |
2248 | 2246 |
|
2249 | 2247 | # Start first thread and wait for it to begin streaming |
2250 | 2248 | t1 = threading.Thread(target=invoke) |
@@ -2282,7 +2280,7 @@ def test_agent_concurrent_structured_output_raises_exception(): |
2282 | 2280 | {"role": "assistant", "content": [{"text": "response2"}]}, |
2283 | 2281 | ], |
2284 | 2282 | ) |
2285 | | - agent = Agent(model=model) |
| 2283 | + agent = Agent(model=model, concurrent_invocation_mode="throw") |
2286 | 2284 |
|
2287 | 2285 | results = [] |
2288 | 2286 | errors = [] |
@@ -2320,6 +2318,83 @@ def invoke(): |
2320 | 2318 | assert "concurrent" in str(errors[0]).lower() and "invocation" in str(errors[0]).lower() |
2321 | 2319 |
|
2322 | 2320 |
|
def test_agent_concurrent_call_succeeds_with_unsafe_reentrant_mode():
    """Test that concurrent __call__() calls succeed when concurrent_invocation_mode is 'unsafe_reentrant'.

    Two threads invoke the same agent; the second is started only after the
    first is known to be blocked inside model.stream(). In 'unsafe_reentrant'
    mode neither invocation may raise, and both must complete with a result.
    """
    model = SyncEventMockedModel(
        [
            {"role": "assistant", "content": [{"text": "hello"}]},
            {"role": "assistant", "content": [{"text": "world"}]},
        ]
    )
    agent = Agent(model=model, concurrent_invocation_mode="unsafe_reentrant")

    results = []
    errors = []
    lock = threading.Lock()

    def invoke():
        try:
            result = agent("test")
            with lock:
                results.append(result)
        # Catch ANY exception (not just ConcurrencyException) so an unexpected
        # failure surfaces in the assertion message below instead of silently
        # killing the thread and producing an opaque len(results) mismatch.
        except Exception as e:
            with lock:
                errors.append(e)

    # Start first thread and wait for it to begin streaming.
    t1 = threading.Thread(target=invoke)
    t1.start()
    # Bounded wait: a regression that never reaches model.stream() should fail
    # the test rather than hang the suite forever.
    assert model.started_event.wait(timeout=5), "first invocation never started streaming"

    # Start second thread while first is still running.
    t2 = threading.Thread(target=invoke)
    t2.start()

    # Let both threads proceed.
    model.proceed_event.set()
    # Bounded joins for the same reason: a deadlock becomes a test failure,
    # not a hung CI job. The assertions below then report what went wrong.
    t1.join(timeout=5)
    t2.join(timeout=5)
    assert not t1.is_alive() and not t2.is_alive(), "invocation threads did not finish in time"

    # Both should succeed; no exception of any kind should have been raised.
    assert len(errors) == 0, f"Expected 0 errors, got {len(errors)}: {errors}"
    assert len(results) == 2, f"Expected 2 successes, got {len(results)}"
| 2362 | + |
def test_agent_concurrent_invocation_mode_default_is_throw():
    """Test that the default concurrent_invocation_mode is 'throw'."""
    mocked_model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello"}]}])
    default_agent = Agent(model=mocked_model)

    # An agent constructed without an explicit mode must default to "throw".
    assert default_agent._concurrent_invocation_mode == "throw"
| 2370 | + |
| 2371 | + |
def test_agent_concurrent_invocation_mode_stores_value():
    """Test that concurrent_invocation_mode is stored correctly as instance variable."""
    mocked_model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello"}]}])

    # Each supported string value must round-trip onto the private attribute.
    for mode in ("throw", "unsafe_reentrant"):
        configured = Agent(model=mocked_model, concurrent_invocation_mode=mode)
        assert configured._concurrent_invocation_mode == mode
| 2381 | + |
| 2382 | + |
def test_agent_concurrent_invocation_mode_accepts_enum():
    """Test that concurrent_invocation_mode accepts enum values as well as strings."""

    mocked_model = MockedModelProvider([{"role": "assistant", "content": [{"text": "hello"}]}])

    # Each enum member must be stored so that the attribute compares equal both
    # to the member itself and to its string value (string-enum semantics).
    expectations = [
        (ConcurrentInvocationMode.THROW, "throw"),
        (ConcurrentInvocationMode.UNSAFE_REENTRANT, "unsafe_reentrant"),
    ]
    for member, text in expectations:
        configured = Agent(model=mocked_model, concurrent_invocation_mode=member)
        assert configured._concurrent_invocation_mode == text
        assert configured._concurrent_invocation_mode == member
| 2397 | + |
2323 | 2398 | @pytest.mark.asyncio |
2324 | 2399 | async def test_agent_sequential_invocations_work(): |
2325 | 2400 | """Test that sequential invocations work correctly after lock is released.""" |
|
0 commit comments