diff --git a/src/a2a/contrib/tasks/vertex_task_store.py b/src/a2a/contrib/tasks/vertex_task_store.py index 1b5d852da..ccd9fffba 100644 --- a/src/a2a/contrib/tasks/vertex_task_store.py +++ b/src/a2a/contrib/tasks/vertex_task_store.py @@ -44,9 +44,7 @@ def __init__( self._client = client self._agent_engine_resource_id = agent_engine_resource_id - async def save( - self, task: Task, context: ServerCallContext | None = None - ) -> None: + async def save(self, task: Task, context: ServerCallContext) -> None: """Saves or updates a task in the store.""" compat_task = to_compat_task(task) previous_task = await self._get_stored_task(compat_task.id) @@ -206,7 +204,7 @@ async def _get_stored_task( return a2a_task async def get( - self, task_id: str, context: ServerCallContext | None = None + self, task_id: str, context: ServerCallContext ) -> Task | None: """Retrieves a task from the database by ID.""" a2a_task = await self._get_stored_task(task_id) @@ -217,13 +215,11 @@ async def get( async def list( self, params: ListTasksRequest, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> ListTasksResponse: """Retrieves a list of tasks from the store.""" raise NotImplementedError - async def delete( - self, task_id: str, context: ServerCallContext | None = None - ) -> None: + async def delete(self, task_id: str, context: ServerCallContext) -> None: """The backend doesn't support deleting tasks, so this is not implemented.""" raise NotImplementedError diff --git a/src/a2a/server/agent_execution/context.py b/src/a2a/server/agent_execution/context.py index 73a4a9f4e..91284f37c 100644 --- a/src/a2a/server/agent_execution/context.py +++ b/src/a2a/server/agent_execution/context.py @@ -26,35 +26,35 @@ class RequestContext: def __init__( # noqa: PLR0913 self, + call_context: ServerCallContext, request: SendMessageRequest | None = None, task_id: str | None = None, context_id: str | None = None, task: Task | None = None, related_tasks: list[Task] | None = None, - call_context: ServerCallContext | None = None, task_id_generator: IDGenerator | None = None, context_id_generator: IDGenerator | None = None, ): """Initializes the RequestContext. Args: + call_context: The server call context associated with this request. request: The incoming `SendMessageRequest` request payload. task_id: The ID of the task explicitly provided in the request or path. context_id: The ID of the context explicitly provided in the request or path. task: The existing `Task` object retrieved from the store, if any. related_tasks: A list of other tasks related to the current request (e.g., for tool use). - call_context: The server call context associated with this request. task_id_generator: ID generator for new task IDs. Defaults to UUID generator. context_id_generator: ID generator for new context IDs. Defaults to UUID generator. """ if related_tasks is None: related_tasks = [] + self._call_context = call_context self._params = request self._task_id = task_id self._context_id = context_id self._current_task = task self._related_tasks = related_tasks - self._call_context = call_context self._task_id_generator = ( task_id_generator if task_id_generator else UUIDGenerator() ) @@ -140,7 +140,7 @@ def configuration(self) -> SendMessageConfiguration | None: return self._params.configuration if self._params else None @property - def call_context(self) -> ServerCallContext | None: + def call_context(self) -> ServerCallContext: """The server call context associated with this request.""" return self._call_context @@ -157,22 +157,17 @@ def add_activated_extension(self, uri: str) -> None: This causes the extension to be indicated back to the client in the response. """ - if self._call_context: - self._call_context.activated_extensions.add(uri) + self._call_context.activated_extensions.add(uri) @property def tenant(self) -> str: """The tenant associated with this request.""" - return self._call_context.tenant if self._call_context else '' + return self._call_context.tenant @property def requested_extensions(self) -> set[str]: """Extensions that the client requested to activate.""" - return ( - self._call_context.requested_extensions - if self._call_context - else set() - ) + return self._call_context.requested_extensions def _check_or_generate_task_id(self) -> None: """Ensures a task ID is present, generating one if necessary.""" diff --git a/src/a2a/server/agent_execution/request_context_builder.py b/src/a2a/server/agent_execution/request_context_builder.py index 984a10149..cab82b401 100644 --- a/src/a2a/server/agent_execution/request_context_builder.py +++ b/src/a2a/server/agent_execution/request_context_builder.py @@ -11,10 +11,10 @@ class RequestContextBuilder(ABC): @abstractmethod async def build( self, + context: ServerCallContext, params: SendMessageRequest | None = None, task_id: str | None = None, context_id: str | None = None, task: Task | None = None, - context: ServerCallContext | None = None, ) -> RequestContext: pass diff --git a/src/a2a/server/agent_execution/simple_request_context_builder.py b/src/a2a/server/agent_execution/simple_request_context_builder.py index 9a1223afa..5f2b7c521 100644 --- a/src/a2a/server/agent_execution/simple_request_context_builder.py +++ b/src/a2a/server/agent_execution/simple_request_context_builder.py @@ -35,11 +35,11 @@ def __init__( async def build( self, + context: ServerCallContext, params: SendMessageRequest | None = None, task_id: str | None = None, context_id: str | None = None, task: Task | None = None, - context: ServerCallContext | None = None, ) -> RequestContext: """Builds the request context for an agent execution. @@ -48,11 +48,11 @@ async def build( referenced in `params.message.reference_task_ids` from the `task_store`. Args: + context: The server call context, containing metadata about the call. params: The parameters of the incoming message send request. task_id: The ID of the task being executed. context_id: The ID of the current execution context. task: The primary task object associated with the request. - context: The server call context, containing metadata about the call. Returns: An instance of RequestContext populated with the provided information @@ -68,19 +68,19 @@ async def build( ): tasks = await asyncio.gather( *[ - self._task_store.get(task_id) + self._task_store.get(task_id, context) for task_id in params.message.reference_task_ids ] ) related_tasks = [x for x in tasks if x is not None] return RequestContext( + call_context=context, request=params, task_id=task_id, context_id=context_id, task=task, related_tasks=related_tasks, - call_context=context, task_id_generator=self._task_id_generator, context_id_generator=self._context_id_generator, ) diff --git a/src/a2a/server/owner_resolver.py b/src/a2a/server/owner_resolver.py index 798eb8c9b..4fca42d24 100644 --- a/src/a2a/server/owner_resolver.py +++ b/src/a2a/server/owner_resolver.py @@ -4,13 +4,10 @@ # Definition -OwnerResolver = Callable[[ServerCallContext | None], str] +OwnerResolver = Callable[[ServerCallContext], str] # Example Default Implementation -def resolve_user_scope(context: ServerCallContext | None) -> str: +def resolve_user_scope(context: ServerCallContext) -> str: """Resolves the owner scope based on the user in the context.""" - if not context: - return 'unknown' - # Example: Basic user name. Adapt as needed for your user model. return context.user.user_name diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index d1835073a..ac8c5778f 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -196,7 +196,8 @@ async def on_cancel_task( await self.agent_executor.cancel( RequestContext( - None, + call_context=context, + request=None, task_id=task.id, context_id=task.context_id, task=task, @@ -290,7 +291,7 @@ async def _setup_message_execution( await self._push_config_store.set_info( task_id, params.configuration.task_push_notification_config, - context or ServerCallContext(), + context, ) queue = await self._queue_manager.create_or_tap(task_id) @@ -504,7 +505,7 @@ async def on_create_task_push_notification_config( await self._push_config_store.set_info( task_id, params, - context or ServerCallContext(), + context, ) return params @@ -529,10 +530,7 @@ async def on_get_task_push_notification_config( raise TaskNotFoundError push_notification_configs: list[TaskPushNotificationConfig] = ( - await self._push_config_store.get_info( - task_id, context or ServerCallContext() - ) - or [] + await self._push_config_store.get_info(task_id, context) or [] ) for config in push_notification_configs: @@ -603,7 +601,7 @@ async def on_list_task_push_notification_configs( raise TaskNotFoundError push_notification_config_list = await self._push_config_store.get_info( - task_id, context or ServerCallContext() + task_id, context ) return ListTaskPushNotificationConfigsResponse( @@ -629,6 +627,4 @@ async def on_delete_task_push_notification_config( if not task: raise TaskNotFoundError - await self._push_config_store.delete_info( - task_id, context or ServerCallContext(), config_id - ) + await self._push_config_store.delete_info(task_id, context, config_id) diff --git a/src/a2a/server/tasks/copying_task_store.py b/src/a2a/server/tasks/copying_task_store.py index 6bfda5e74..f7f41bf1f 100644 --- a/src/a2a/server/tasks/copying_task_store.py +++ b/src/a2a/server/tasks/copying_task_store.py @@ -24,16 +24,14 @@ class CopyingTaskStoreAdapter(TaskStore): def __init__(self, underlying_store: TaskStore): self._store = underlying_store - async def save( - self, task: Task, context: ServerCallContext | None = None - ) -> None: + async def save(self, task: Task, context: ServerCallContext) -> None: """Saves a copy of the task to the underlying store.""" task_copy = Task() task_copy.CopyFrom(task) await self._store.save(task_copy, context) async def get( - self, task_id: str, context: ServerCallContext | None = None + self, task_id: str, context: ServerCallContext ) -> Task | None: """Retrieves a task from the underlying store and returns a copy.""" task = await self._store.get(task_id, context) @@ -46,7 +44,7 @@ async def get( async def list( self, params: ListTasksRequest, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> ListTasksResponse: """Retrieves a list of tasks from the underlying store and returns a copy.""" response = await self._store.list(params, context) @@ -54,8 +52,6 @@ async def list( response_copy.CopyFrom(response) return response_copy - async def delete( - self, task_id: str, context: ServerCallContext | None = None - ) -> None: + async def delete(self, task_id: str, context: ServerCallContext) -> None: """Deletes a task from the underlying store.""" await self._store.delete(task_id, context) diff --git a/src/a2a/server/tasks/database_task_store.py b/src/a2a/server/tasks/database_task_store.py index 2c95da2ca..62a760b24 100644 --- a/src/a2a/server/tasks/database_task_store.py +++ b/src/a2a/server/tasks/database_task_store.py @@ -169,9 +169,7 @@ def _from_orm(self, task_model: TaskModel) -> Task: # Legacy conversion return compat_task_model_to_core(task_model) - async def save( - self, task: Task, context: ServerCallContext | None = None - ) -> None: + async def save(self, task: Task, context: ServerCallContext) -> None: """Saves or updates a task in the database for the resolved owner.""" await self._ensure_initialized() owner = self.owner_resolver(context) @@ -185,7 +183,7 @@ async def save( ) async def get( - self, task_id: str, context: ServerCallContext | None = None + self, task_id: str, context: ServerCallContext ) -> Task | None: """Retrieves a task from the database by ID, for the given owner.""" await self._ensure_initialized() @@ -216,7 +214,7 @@ async def get( async def list( self, params: a2a_pb2.ListTasksRequest, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> a2a_pb2.ListTasksResponse: """Retrieves tasks from the database based on provided parameters, for the given owner.""" await self._ensure_initialized() @@ -315,9 +313,7 @@ async def list( page_size=page_size, ) - async def delete( - self, task_id: str, context: ServerCallContext | None = None - ) -> None: + async def delete(self, task_id: str, context: ServerCallContext) -> None: """Deletes a task from the database by ID, for the given owner.""" await self._ensure_initialized() owner = self.owner_resolver(context) diff --git a/src/a2a/server/tasks/inmemory_task_store.py b/src/a2a/server/tasks/inmemory_task_store.py index f887b77ba..75d2269bc 100644 --- a/src/a2a/server/tasks/inmemory_task_store.py +++ b/src/a2a/server/tasks/inmemory_task_store.py @@ -35,9 +35,7 @@ def __init__( def _get_owner_tasks(self, owner: str) -> dict[str, Task]: return self.tasks.get(owner, {}) - async def save( - self, task: Task, context: ServerCallContext | None = None - ) -> None: + async def save(self, task: Task, context: ServerCallContext) -> None: """Saves or updates a task in the in-memory store for the resolved owner.""" owner = self.owner_resolver(context) if owner not in self.tasks: @@ -50,7 +48,7 @@ async def save( ) async def get( - self, task_id: str, context: ServerCallContext | None = None + self, task_id: str, context: ServerCallContext ) -> Task | None: """Retrieves a task from the in-memory store by ID, for the given owner.""" owner = self.owner_resolver(context) @@ -77,7 +75,7 @@ async def get( async def list( self, params: a2a_pb2.ListTasksRequest, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> a2a_pb2.ListTasksResponse: """Retrieves a list of tasks from the store, for the given owner.""" owner = self.owner_resolver(context) @@ -156,9 +154,7 @@ async def list( page_size=page_size, ) - async def delete( - self, task_id: str, context: ServerCallContext | None = None - ) -> None: + async def delete(self, task_id: str, context: ServerCallContext) -> None: """Deletes a task from the in-memory store by ID, for the given owner.""" owner = self.owner_resolver(context) async with self.lock: @@ -211,14 +207,12 @@ def __init__( CopyingTaskStoreAdapter(self._impl) if use_copying else self._impl ) - async def save( - self, task: Task, context: ServerCallContext | None = None - ) -> None: + async def save(self, task: Task, context: ServerCallContext) -> None: """Saves or updates a task in the store.""" await self._store.save(task, context) async def get( - self, task_id: str, context: ServerCallContext | None = None + self, task_id: str, context: ServerCallContext ) -> Task | None: """Retrieves a task from the store by ID.""" return await self._store.get(task_id, context) @@ -226,13 +220,11 @@ async def get( async def list( self, params: a2a_pb2.ListTasksRequest, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> a2a_pb2.ListTasksResponse: """Retrieves a list of tasks from the store.""" return await self._store.list(params, context) - async def delete( - self, task_id: str, context: ServerCallContext | None = None - ) -> None: + async def delete(self, task_id: str, context: ServerCallContext) -> None: """Deletes a task from the store by ID.""" await self._store.delete(task_id, context) diff --git a/src/a2a/server/tasks/task_manager.py b/src/a2a/server/tasks/task_manager.py index 440100b1f..905b11af3 100644 --- a/src/a2a/server/tasks/task_manager.py +++ b/src/a2a/server/tasks/task_manager.py @@ -27,31 +27,31 @@ class TaskManager: def __init__( self, + task_store: TaskStore, + context: ServerCallContext, task_id: str | None, context_id: str | None, - task_store: TaskStore, initial_message: Message | None, - context: ServerCallContext | None = None, ): """Initializes the TaskManager. Args: + task_store: The `TaskStore` instance for persistence. + context: The `ServerCallContext` that this task is produced under. task_id: The ID of the task, if known from the request. context_id: The ID of the context, if known from the request. - task_store: The `TaskStore` instance for persistence. initial_message: The `Message` that initiated the task, if any. Used when creating a new task object. - context: The `ServerCallContext` that this task is produced under. """ if task_id is not None and not (isinstance(task_id, str) and task_id): raise ValueError('Task ID must be a non-empty string') + self.task_store = task_store + self._call_context: ServerCallContext = context self.task_id = task_id self.context_id = context_id - self.task_store = task_store self._initial_message = initial_message self._current_task: Task | None = None - self._call_context: ServerCallContext | None = context logger.debug( 'TaskManager initialized with task_id: %s, context_id: %s', task_id, diff --git a/src/a2a/server/tasks/task_store.py b/src/a2a/server/tasks/task_store.py index a4d3308c0..25e4838d1 100644 --- a/src/a2a/server/tasks/task_store.py +++ b/src/a2a/server/tasks/task_store.py @@ -11,14 +11,12 @@ class TaskStore(ABC): """ @abstractmethod - async def save( - self, task: Task, context: ServerCallContext | None = None - ) -> None: + async def save(self, task: Task, context: ServerCallContext) -> None: """Saves or updates a task in the store.""" @abstractmethod async def get( - self, task_id: str, context: ServerCallContext | None = None + self, task_id: str, context: ServerCallContext ) -> Task | None: """Retrieves a task from the store by ID.""" @@ -26,12 +24,10 @@ async def get( async def list( self, params: ListTasksRequest, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> ListTasksResponse: """Retrieves a list of tasks from the store.""" @abstractmethod - async def delete( - self, task_id: str, context: ServerCallContext | None = None - ) -> None: + async def delete(self, task_id: str, context: ServerCallContext) -> None: """Deletes a task from the store by ID.""" diff --git a/tests/contrib/tasks/test_vertex_task_store.py b/tests/contrib/tasks/test_vertex_task_store.py index 96037c697..75e3bdf08 100644 --- a/tests/contrib/tasks/test_vertex_task_store.py +++ b/tests/contrib/tasks/test_vertex_task_store.py @@ -62,6 +62,7 @@ def backend_type(request) -> str: from a2a.contrib.tasks.vertex_task_store import VertexTaskStore +from a2a.server.context import ServerCallContext from a2a.types.a2a_pb2 import ( Artifact, Part, @@ -140,9 +141,11 @@ async def test_save_task(vertex_store: VertexTaskStore) -> None: task_to_save = Task() task_to_save.CopyFrom(MINIMAL_TASK_OBJ) task_to_save.id = 'save-test-task-2' - await vertex_store.save(task_to_save) + await vertex_store.save(task_to_save, ServerCallContext()) - retrieved_task = await vertex_store.get(task_to_save.id) + retrieved_task = await vertex_store.get( + task_to_save.id, ServerCallContext() + ) assert retrieved_task is not None assert retrieved_task.id == task_to_save.id @@ -156,9 +159,11 @@ async def test_get_task(vertex_store: VertexTaskStore) -> None: task_to_save = Task() task_to_save.CopyFrom(MINIMAL_TASK_OBJ) task_to_save.id = task_id - await vertex_store.save(task_to_save) + await vertex_store.save(task_to_save, ServerCallContext()) - retrieved_task = await vertex_store.get(task_to_save.id) + retrieved_task = await vertex_store.get( + task_to_save.id, ServerCallContext() + ) assert retrieved_task is not None assert retrieved_task.id == task_to_save.id assert retrieved_task.context_id == task_to_save.context_id @@ -170,7 +175,9 @@ async def test_get_nonexistent_task( vertex_store: VertexTaskStore, ) -> None: """Test retrieving a nonexistent task.""" - retrieved_task = await vertex_store.get('nonexistent-task-id') + retrieved_task = await vertex_store.get( + 'nonexistent-task-id', ServerCallContext() + ) assert retrieved_task is None @@ -196,8 +203,8 @@ async def test_save_and_get_detailed_task( test_task.metadata['key1'] = 'value1' test_task.metadata['key2'] = 123 - await vertex_store.save(test_task) - retrieved_task = await vertex_store.get(test_task.id) + await vertex_store.save(test_task, ServerCallContext()) + retrieved_task = await vertex_store.get(test_task.id, ServerCallContext()) assert retrieved_task is not None assert retrieved_task.id == test_task.id @@ -221,9 +228,11 @@ async def test_update_task_status_and_metadata( artifacts=[], history=[], ) - await vertex_store.save(original_task) + await vertex_store.save(original_task, ServerCallContext()) - retrieved_before_update = await vertex_store.get(task_id) + retrieved_before_update = await vertex_store.get( + task_id, ServerCallContext() + ) assert retrieved_before_update is not None assert ( retrieved_before_update.status.state == TaskState.TASK_STATE_SUBMITTED @@ -236,9 +245,11 @@ async def test_update_task_status_and_metadata( updated_task.status.timestamp.FromJsonString('2023-01-02T11:00:00Z') updated_task.metadata.update({'update_key': 'update_value'}) - await vertex_store.save(updated_task) + await vertex_store.save(updated_task, ServerCallContext()) - retrieved_after_update = await vertex_store.get(task_id) + retrieved_after_update = await vertex_store.get( + task_id, ServerCallContext() + ) assert retrieved_after_update is not None assert retrieved_after_update.status.state == TaskState.TASK_STATE_COMPLETED assert retrieved_after_update.metadata == {'update_key': 'update_value'} @@ -260,9 +271,11 @@ async def test_update_task_add_artifact(vertex_store: VertexTaskStore) -> None: ], history=[], ) - await vertex_store.save(original_task) + await vertex_store.save(original_task, ServerCallContext()) - retrieved_before_update = await vertex_store.get(task_id) + retrieved_before_update = await vertex_store.get( + task_id, ServerCallContext() + ) assert retrieved_before_update is not None assert ( retrieved_before_update.status.state == TaskState.TASK_STATE_SUBMITTED @@ -281,9 +294,11 @@ async def test_update_task_add_artifact(vertex_store: VertexTaskStore) -> None: ) ) - await vertex_store.save(updated_task) + await vertex_store.save(updated_task, ServerCallContext()) - retrieved_after_update = await vertex_store.get(task_id) + retrieved_after_update = await vertex_store.get( + task_id, ServerCallContext() + ) assert retrieved_after_update is not None assert retrieved_after_update.status.state == TaskState.TASK_STATE_WORKING @@ -321,9 +336,11 @@ async def test_update_task_update_artifact( ], history=[], ) - await vertex_store.save(original_task) + await vertex_store.save(original_task, ServerCallContext()) - retrieved_before_update = await vertex_store.get(task_id) + retrieved_before_update = await vertex_store.get( + task_id, ServerCallContext() + ) assert retrieved_before_update is not None assert ( retrieved_before_update.status.state == TaskState.TASK_STATE_SUBMITTED @@ -337,9 +354,11 @@ async def test_update_task_update_artifact( updated_task.artifacts[0].parts[0].text = 'ahoy' - await vertex_store.save(updated_task) + await vertex_store.save(updated_task, ServerCallContext()) - retrieved_after_update = await vertex_store.get(task_id) + retrieved_after_update = await vertex_store.get( + task_id, ServerCallContext() + ) assert retrieved_after_update is not None assert retrieved_after_update.status.state == TaskState.TASK_STATE_WORKING @@ -377,9 +396,11 @@ async def test_update_task_delete_artifact( ], history=[], ) - await vertex_store.save(original_task) + await vertex_store.save(original_task, ServerCallContext()) - retrieved_before_update = await vertex_store.get(task_id) + retrieved_before_update = await vertex_store.get( + task_id, ServerCallContext() + ) assert retrieved_before_update is not None assert ( retrieved_before_update.status.state == TaskState.TASK_STATE_SUBMITTED @@ -393,9 +414,11 @@ async def test_update_task_delete_artifact( del updated_task.artifacts[1] - await vertex_store.save(updated_task) + await vertex_store.save(updated_task, ServerCallContext()) - retrieved_after_update = await vertex_store.get(task_id) + retrieved_after_update = await vertex_store.get( + task_id, ServerCallContext() + ) assert retrieved_after_update is not None assert retrieved_after_update.status.state == TaskState.TASK_STATE_WORKING @@ -426,8 +449,10 @@ async def test_metadata_field_mapping( context_id='session-meta-1', status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), ) - await vertex_store.save(task_no_metadata) - retrieved_no_metadata = await vertex_store.get('task-metadata-test-1') + await vertex_store.save(task_no_metadata, ServerCallContext()) + retrieved_no_metadata = await vertex_store.get( + 'task-metadata-test-1', ServerCallContext() + ) assert retrieved_no_metadata is not None assert retrieved_no_metadata.metadata == {} @@ -439,8 +464,10 @@ async def test_metadata_field_mapping( status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), metadata=simple_metadata, ) - await vertex_store.save(task_simple_metadata) - retrieved_simple = await vertex_store.get('task-metadata-test-2') + await vertex_store.save(task_simple_metadata, ServerCallContext()) + retrieved_simple = await vertex_store.get( + 'task-metadata-test-2', ServerCallContext() + ) assert retrieved_simple is not None assert retrieved_simple.metadata == simple_metadata @@ -463,8 +490,10 @@ async def test_metadata_field_mapping( status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), metadata=complex_metadata, ) - await vertex_store.save(task_complex_metadata) - retrieved_complex = await vertex_store.get('task-metadata-test-3') + await vertex_store.save(task_complex_metadata, ServerCallContext()) + retrieved_complex = await vertex_store.get( + 'task-metadata-test-3', ServerCallContext() + ) assert retrieved_complex is not None assert retrieved_complex.metadata == complex_metadata @@ -474,16 +503,18 @@ async def test_metadata_field_mapping( context_id='session-meta-4', status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), ) - await vertex_store.save(task_update_metadata) + await vertex_store.save(task_update_metadata, ServerCallContext()) # Update metadata task_update_metadata.metadata.Clear() task_update_metadata.metadata.update( {'updated': True, 'timestamp': '2024-01-01'} ) - await vertex_store.save(task_update_metadata) + await vertex_store.save(task_update_metadata, ServerCallContext()) - retrieved_updated = await vertex_store.get('task-metadata-test-4') + retrieved_updated = await vertex_store.get( + 'task-metadata-test-4', ServerCallContext() + ) assert retrieved_updated is not None assert retrieved_updated.metadata == { 'updated': True, @@ -492,8 +523,10 @@ async def test_metadata_field_mapping( # Test 5: Update metadata from dict to None task_update_metadata.metadata.Clear() - await vertex_store.save(task_update_metadata) + await vertex_store.save(task_update_metadata, ServerCallContext()) - retrieved_none = await vertex_store.get('task-metadata-test-4') + retrieved_none = await vertex_store.get( + 'task-metadata-test-4', ServerCallContext() + ) assert retrieved_none is not None assert retrieved_none.metadata == {} diff --git a/tests/server/agent_execution/test_context.py b/tests/server/agent_execution/test_context.py index 2e9423324..7ec612986 100644 --- a/tests/server/agent_execution/test_context.py +++ b/tests/server/agent_execution/test_context.py @@ -35,7 +35,7 @@ def mock_task(self) -> Mock: def test_init_without_params(self) -> None: """Test initialization without parameters.""" - context = RequestContext() + context = RequestContext(ServerCallContext()) assert context.message is None assert context.task_id is None assert context.context_id is None @@ -51,7 +51,7 @@ def test_init_with_params_no_ids(self, mock_params: Mock) -> None: uuid.UUID('00000000-0000-0000-0000-000000000002'), ], ): - context = RequestContext(request=mock_params) + context = RequestContext(ServerCallContext(), request=mock_params) assert context.message == mock_params.message assert context.task_id == '00000000-0000-0000-0000-000000000001' @@ -68,7 +68,9 @@ def test_init_with_params_no_ids(self, mock_params: Mock) -> None: def test_init_with_task_id(self, mock_params: Mock) -> None: """Test initialization with task ID provided.""" task_id = 'task-123' - context = RequestContext(request=mock_params, task_id=task_id) + context = RequestContext( + ServerCallContext(), request=mock_params, task_id=task_id + ) assert context.task_id == task_id assert mock_params.message.task_id == task_id @@ -76,7 +78,9 @@ def test_init_with_task_id(self, mock_params: Mock) -> None: def test_init_with_context_id(self, mock_params: Mock) -> None: """Test initialization with context ID provided.""" context_id = 'context-456' - context = RequestContext(request=mock_params, context_id=context_id) + context = RequestContext( + ServerCallContext(), request=mock_params, context_id=context_id + ) assert context.context_id == context_id assert mock_params.message.context_id == context_id @@ -86,7 +90,10 @@ def test_init_with_both_ids(self, mock_params: Mock) -> None: task_id = 'task-123' context_id = 'context-456' context = RequestContext( - request=mock_params, task_id=task_id, context_id=context_id + ServerCallContext(), + request=mock_params, + task_id=task_id, + context_id=context_id, ) assert context.task_id == task_id @@ -96,18 +103,20 @@ def test_init_with_both_ids(self, mock_params: Mock) -> None: def test_init_with_task(self, mock_params: Mock, mock_task: Mock) -> None: """Test initialization with a task object.""" - context = RequestContext(request=mock_params, task=mock_task) + context = RequestContext( + ServerCallContext(), request=mock_params, task=mock_task + ) assert context.current_task == mock_task def test_get_user_input_no_params(self) -> None: """Test get_user_input with no params returns empty string.""" - context = RequestContext() + context = RequestContext(ServerCallContext()) assert context.get_user_input() == '' def test_attach_related_task(self, mock_task: Mock) -> None: """Test attach_related_task adds a task to related_tasks.""" - context = RequestContext() + context = RequestContext(ServerCallContext()) assert len(context.related_tasks) == 0 context.attach_related_task(mock_task) @@ -122,7 +131,7 @@ def test_attach_related_task(self, mock_task: Mock) -> None: def test_current_task_property(self, mock_task: Mock) -> None: """Test current_task getter and setter.""" - context = RequestContext() + context = RequestContext(ServerCallContext()) assert context.current_task is None context.current_task = mock_task @@ -135,7 +144,7 @@ def test_current_task_property(self, mock_task: Mock) -> None: def test_check_or_generate_task_id_no_params(self) -> None: """Test _check_or_generate_task_id with no params does nothing.""" - context = RequestContext() + context = RequestContext(ServerCallContext()) context._check_or_generate_task_id() assert context.task_id is None @@ -146,7 +155,7 @@ def test_check_or_generate_task_id_with_existing_task_id( existing_id = 'existing-task-id' mock_params.message.task_id = existing_id - context = RequestContext(request=mock_params) + context = RequestContext(ServerCallContext(), request=mock_params) # The method is called during initialization assert context.task_id == existing_id @@ -160,7 +169,9 @@ def test_check_or_generate_task_id_with_custom_id_generator( id_generator.generate.return_value = 'custom-task-id' context = RequestContext( - request=mock_params, task_id_generator=id_generator + ServerCallContext(), + request=mock_params, + task_id_generator=id_generator, ) # The method is called during initialization @@ -168,7 +179,7 @@ def test_check_or_generate_task_id_with_custom_id_generator( def test_check_or_generate_context_id_no_params(self) -> None: """Test _check_or_generate_context_id with no params does nothing.""" - context = RequestContext() + context = RequestContext(ServerCallContext()) context._check_or_generate_context_id() assert context.context_id is None @@ -179,7 +190,7 @@ def test_check_or_generate_context_id_with_existing_context_id( existing_id = 'existing-context-id' mock_params.message.context_id = existing_id - context = RequestContext(request=mock_params) + context = RequestContext(ServerCallContext(), request=mock_params) # The method is called during initialization assert context.context_id == existing_id @@ -193,7 +204,9 @@ def test_check_or_generate_context_id_with_custom_id_generator( id_generator.generate.return_value = 'custom-context-id' context = RequestContext( - request=mock_params, context_id_generator=id_generator + ServerCallContext(), + request=mock_params, + context_id_generator=id_generator, ) # The method is called during initialization @@ -205,7 +218,10 @@ def test_init_raises_error_on_task_id_mismatch( """Test that an error is raised if provided task_id mismatches task.id.""" with pytest.raises(InvalidParamsError) as exc_info: RequestContext( - request=mock_params, task_id='wrong-task-id', task=mock_task + ServerCallContext(), + request=mock_params, + task_id='wrong-task-id', + task=mock_task, ) assert 'bad task id' in exc_info.value.message @@ -218,6 +234,7 @@ def test_init_raises_error_on_context_id_mismatch( with pytest.raises(InvalidParamsError) as exc_info: RequestContext( + ServerCallContext(), request=mock_params, task_id=mock_task.id, context_id='wrong-context-id', @@ -229,30 +246,32 @@ def test_init_raises_error_on_context_id_mismatch( def test_with_related_tasks_provided(self, mock_task: Mock) -> None: """Test initialization with related tasks provided.""" related_tasks = [mock_task, Mock(spec=Task)] - context = RequestContext(related_tasks=related_tasks) # type: ignore[arg-type] + context = RequestContext( + ServerCallContext(), related_tasks=related_tasks + ) # type: ignore[arg-type] assert context.related_tasks == related_tasks assert len(context.related_tasks) == 2 def test_message_property_without_params(self) -> None: """Test message property returns None when no params are provided.""" - context = RequestContext() + context = RequestContext(ServerCallContext()) assert context.message is None def test_message_property_with_params(self, mock_params: Mock) -> None: """Test message property returns the message from params.""" - context = RequestContext(request=mock_params) + context = RequestContext(ServerCallContext(), request=mock_params) assert context.message == mock_params.message def test_metadata_property_without_content(self) -> None: """Test metadata property returns empty dict when no content are provided.""" - context = RequestContext() + context = RequestContext(ServerCallContext()) assert context.metadata == {} def test_metadata_property_with_content(self, mock_params: Mock) -> None: """Test metadata property returns the metadata from params.""" mock_params.metadata = {'key': 'value'} - context = RequestContext(request=mock_params) + context = RequestContext(ServerCallContext(), request=mock_params) assert context.metadata == {'key': 'value'} def test_init_with_existing_ids_in_message( @@ -262,7 +281,7 @@ def test_init_with_existing_ids_in_message( mock_message.task_id = 'existing-task-id' mock_message.context_id = 'existing-context-id' - context = RequestContext(request=mock_params) + context = RequestContext(ServerCallContext(), request=mock_params) assert context.task_id == 'existing-task-id' assert context.context_id == 'existing-context-id' @@ -275,7 +294,10 @@ def test_init_with_task_id_and_existing_task_id_match( mock_params.message.task_id = mock_task.id context = RequestContext( - request=mock_params, task_id=mock_task.id, task=mock_task + ServerCallContext(), + request=mock_params, + task_id=mock_task.id, + task=mock_task, ) assert context.task_id == mock_task.id @@ -289,6 +311,7 @@ def test_init_with_context_id_and_existing_context_id_match( mock_params.message.context_id = mock_task.context_id context = RequestContext( + ServerCallContext(), request=mock_params, task_id=mock_task.id, context_id=mock_task.context_id, diff --git a/tests/server/agent_execution/test_simple_request_context_builder.py b/tests/server/agent_execution/test_simple_request_context_builder.py index caab48342..ef374e364 100644 --- a/tests/server/agent_execution/test_simple_request_context_builder.py +++ b/tests/server/agent_execution/test_simple_request_context_builder.py @@ -127,10 +127,12 @@ async def test_build_populate_true_with_reference_task_ids(self) -> None: mock_ref_task1 = create_sample_task(task_id=ref_task_id1) mock_ref_task3 = create_sample_task(task_id=ref_task_id3) + server_call_context = ServerCallContext(user=UnauthenticatedUser()) + # Configure task_store.get mock # Note: AsyncMock side_effect needs to handle multiple calls if they have different args. # A simple way is a list of return values, or a function. - async def get_side_effect(task_id): + async def get_side_effect(task_id, server_call_context): if task_id == ref_task_id1: return mock_ref_task1 if task_id == ref_task_id3: @@ -144,7 +146,6 @@ async def get_side_effect(task_id): reference_task_ids=[ref_task_id1, ref_task_id2, ref_task_id3] ) ) - server_call_context = ServerCallContext(user=UnauthenticatedUser()) request_context = await builder.build( params=params, @@ -155,9 +156,15 @@ async def get_side_effect(task_id): ) self.assertEqual(self.mock_task_store.get.call_count, 3) - self.mock_task_store.get.assert_any_call(ref_task_id1) - self.mock_task_store.get.assert_any_call(ref_task_id2) - self.mock_task_store.get.assert_any_call(ref_task_id3) + self.mock_task_store.get.assert_any_call( + ref_task_id1, server_call_context + ) + self.mock_task_store.get.assert_any_call( + ref_task_id2, server_call_context + ) + self.mock_task_store.get.assert_any_call( + ref_task_id3, server_call_context + ) self.assertIsNotNone(request_context.related_tasks) self.assertEqual( diff --git a/tests/server/tasks/test_database_task_store.py b/tests/server/tasks/test_database_task_store.py index ff2ab1938..021345a7e 100644 --- a/tests/server/tasks/test_database_task_store.py +++ b/tests/server/tasks/test_database_task_store.py @@ -56,6 +56,9 @@ def user_name(self) -> str: return self._user_name +TEST_CONTEXT = ServerCallContext(user=SampleUser('test_user')) + + # DSNs for different databases SQLITE_TEST_DSN = ( 'sqlite+aiosqlite:///file:testdb?mode=memory&cache=shared&uri=true' @@ -170,13 +173,17 @@ async def test_save_task(db_store_parameterized: DatabaseTaskStore) -> None: task_to_save.id = ( f'save-task-{db_store_parameterized.engine.url.drivername}' ) - await db_store_parameterized.save(task_to_save) + await db_store_parameterized.save(task_to_save, TEST_CONTEXT) - retrieved_task = await db_store_parameterized.get(task_to_save.id) + retrieved_task = await db_store_parameterized.get( + task_to_save.id, TEST_CONTEXT + ) assert retrieved_task is not None assert retrieved_task.id == task_to_save.id assert MessageToDict(retrieved_task) == MessageToDict(task_to_save) - await db_store_parameterized.delete(task_to_save.id) # Cleanup + await db_store_parameterized.delete( + task_to_save.id, TEST_CONTEXT + ) # Cleanup @pytest.mark.asyncio @@ -186,14 +193,18 @@ async def test_get_task(db_store_parameterized: DatabaseTaskStore) -> None: task_to_save = Task() task_to_save.CopyFrom(MINIMAL_TASK_OBJ) task_to_save.id = task_id - await db_store_parameterized.save(task_to_save) + await db_store_parameterized.save(task_to_save, TEST_CONTEXT) - retrieved_task = await db_store_parameterized.get(task_to_save.id) + retrieved_task = await db_store_parameterized.get( + task_to_save.id, TEST_CONTEXT + ) assert retrieved_task is not None assert retrieved_task.id == task_to_save.id assert retrieved_task.context_id == task_to_save.context_id assert retrieved_task.status.state == TaskState.TASK_STATE_SUBMITTED - await db_store_parameterized.delete(task_to_save.id) # Cleanup + await db_store_parameterized.delete( + task_to_save.id, TEST_CONTEXT + ) # Cleanup @pytest.mark.asyncio @@ -321,9 +332,9 @@ async def test_list_tasks( ), ] for task in tasks_to_create: - await db_store_parameterized.save(task) + await db_store_parameterized.save(task, TEST_CONTEXT) - page = await db_store_parameterized.list(params) + page = await db_store_parameterized.list(params, TEST_CONTEXT) retrieved_ids = [task.id for task in page.tasks] assert retrieved_ids == expected_ids @@ -333,7 +344,7 @@ async def test_list_tasks( # Cleanup for task in tasks_to_create: - await db_store_parameterized.delete(task.id) + await db_store_parameterized.delete(task.id, TEST_CONTEXT) @pytest.mark.asyncio @@ -381,16 +392,16 @@ async def test_list_tasks_fails( ), ] for task in tasks_to_create: - await db_store_parameterized.save(task) + await db_store_parameterized.save(task, TEST_CONTEXT) with pytest.raises(InvalidParamsError) as excinfo: - await db_store_parameterized.list(params) + await db_store_parameterized.list(params, TEST_CONTEXT) assert expected_error_message in str(excinfo.value) # Cleanup for task in tasks_to_create: - await db_store_parameterized.delete(task.id) + await db_store_parameterized.delete(task.id, TEST_CONTEXT) @pytest.mark.asyncio @@ -398,7 +409,9 @@ async def test_get_nonexistent_task( db_store_parameterized: DatabaseTaskStore, ) -> None: """Test retrieving a nonexistent task.""" - retrieved_task = await db_store_parameterized.get('nonexistent-task-id') + retrieved_task = await db_store_parameterized.get( + 'nonexistent-task-id', TEST_CONTEXT + ) assert retrieved_task is None @@ -409,13 +422,23 @@ async def test_delete_task(db_store_parameterized: DatabaseTaskStore) -> None: task_to_save_and_delete = Task() task_to_save_and_delete.CopyFrom(MINIMAL_TASK_OBJ) task_to_save_and_delete.id = task_id - await db_store_parameterized.save(task_to_save_and_delete) + await db_store_parameterized.save(task_to_save_and_delete, TEST_CONTEXT) assert ( - await db_store_parameterized.get(task_to_save_and_delete.id) is not None + await db_store_parameterized.get( + task_to_save_and_delete.id, TEST_CONTEXT + ) + is not None + ) + await db_store_parameterized.delete( + task_to_save_and_delete.id, TEST_CONTEXT + ) + assert ( + await db_store_parameterized.get( + task_to_save_and_delete.id, TEST_CONTEXT + ) + is None ) - await db_store_parameterized.delete(task_to_save_and_delete.id) - assert await db_store_parameterized.get(task_to_save_and_delete.id) is None @pytest.mark.asyncio @@ -423,7 +446,9 @@ async def test_delete_nonexistent_task( db_store_parameterized: DatabaseTaskStore, ) -> None: """Test deleting a nonexistent task. Should not error.""" - await db_store_parameterized.delete('nonexistent-delete-task-id') + await db_store_parameterized.delete( + 'nonexistent-delete-task-id', TEST_CONTEXT + ) @pytest.mark.asyncio @@ -455,8 +480,10 @@ async def test_save_and_get_detailed_task( ], ) - await db_store_parameterized.save(test_task) - retrieved_task = await db_store_parameterized.get(test_task.id) + await db_store_parameterized.save(test_task, TEST_CONTEXT) + retrieved_task = await db_store_parameterized.get( + test_task.id, TEST_CONTEXT + ) assert retrieved_task is not None assert retrieved_task.id == test_task.id @@ -479,8 +506,8 @@ async def test_save_and_get_detailed_task( == MessageToDict(test_task)['history'] ) - await db_store_parameterized.delete(test_task.id) - assert await db_store_parameterized.get(test_task.id) is None + await db_store_parameterized.delete(test_task.id, TEST_CONTEXT) + assert await db_store_parameterized.get(test_task.id, TEST_CONTEXT) is None @pytest.mark.asyncio @@ -498,9 +525,11 @@ async def test_update_task(db_store_parameterized: DatabaseTaskStore) -> None: artifacts=[], history=[], ) - await db_store_parameterized.save(original_task) + await db_store_parameterized.save(original_task, TEST_CONTEXT) - retrieved_before_update = await db_store_parameterized.get(task_id) + retrieved_before_update = await db_store_parameterized.get( + task_id, TEST_CONTEXT + ) assert retrieved_before_update is not None assert ( retrieved_before_update.status.state == TaskState.TASK_STATE_SUBMITTED @@ -516,16 +545,18 @@ async def test_update_task(db_store_parameterized: DatabaseTaskStore) -> None: updated_task.status.timestamp.FromDatetime(updated_timestamp) updated_task.metadata['update_key'] = 'update_value' - await db_store_parameterized.save(updated_task) + await db_store_parameterized.save(updated_task, TEST_CONTEXT) - retrieved_after_update = await db_store_parameterized.get(task_id) + retrieved_after_update = await db_store_parameterized.get( + task_id, TEST_CONTEXT + ) assert retrieved_after_update is not None assert retrieved_after_update.status.state == TaskState.TASK_STATE_COMPLETED assert dict(retrieved_after_update.metadata) == { 'update_key': 'update_value' } - await db_store_parameterized.delete(task_id) + await db_store_parameterized.delete(task_id, TEST_CONTEXT) @pytest.mark.asyncio @@ -547,9 +578,9 @@ async def test_metadata_field_mapping( context_id='session-meta-1', status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), ) - await db_store_parameterized.save(task_no_metadata) + await db_store_parameterized.save(task_no_metadata, TEST_CONTEXT) retrieved_no_metadata = await db_store_parameterized.get( - 'task-metadata-test-1' + 'task-metadata-test-1', TEST_CONTEXT ) assert retrieved_no_metadata is not None # Proto Struct is empty, not None @@ -563,8 +594,10 @@ async def test_metadata_field_mapping( status=TaskStatus(state=TaskState.TASK_STATE_WORKING), metadata=simple_metadata, ) - await db_store_parameterized.save(task_simple_metadata) - retrieved_simple = await db_store_parameterized.get('task-metadata-test-2') + await db_store_parameterized.save(task_simple_metadata, TEST_CONTEXT) + retrieved_simple = await db_store_parameterized.get( + 'task-metadata-test-2', TEST_CONTEXT + ) assert retrieved_simple is not None assert dict(retrieved_simple.metadata) == simple_metadata @@ -586,8 +619,10 @@ async def test_metadata_field_mapping( status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), metadata=complex_metadata, ) - await db_store_parameterized.save(task_complex_metadata) - retrieved_complex = await db_store_parameterized.get('task-metadata-test-3') + await db_store_parameterized.save(task_complex_metadata, TEST_CONTEXT) + retrieved_complex = await db_store_parameterized.get( + 'task-metadata-test-3', TEST_CONTEXT + ) assert retrieved_complex is not None # Convert proto Struct to dict for comparison retrieved_meta = MessageToDict(retrieved_complex.metadata) @@ -599,14 +634,16 @@ async def test_metadata_field_mapping( context_id='session-meta-4', status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), ) - await db_store_parameterized.save(task_update_metadata) + await db_store_parameterized.save(task_update_metadata, TEST_CONTEXT) # Update metadata task_update_metadata.metadata['updated'] = True task_update_metadata.metadata['timestamp'] = '2024-01-01' - await db_store_parameterized.save(task_update_metadata) + await db_store_parameterized.save(task_update_metadata, TEST_CONTEXT) - retrieved_updated = await db_store_parameterized.get('task-metadata-test-4') + retrieved_updated = await db_store_parameterized.get( + 'task-metadata-test-4', TEST_CONTEXT + ) assert retrieved_updated is not None assert dict(retrieved_updated.metadata) == { 'updated': True, @@ -615,17 +652,19 @@ async def test_metadata_field_mapping( # Test 5: Clear metadata (set to empty) task_update_metadata.metadata.Clear() - await db_store_parameterized.save(task_update_metadata) + await db_store_parameterized.save(task_update_metadata, TEST_CONTEXT) - retrieved_none = await db_store_parameterized.get('task-metadata-test-4') + retrieved_none = await db_store_parameterized.get( + 'task-metadata-test-4', TEST_CONTEXT + ) assert retrieved_none is not None assert len(retrieved_none.metadata) == 0 # Cleanup - await db_store_parameterized.delete('task-metadata-test-1') - await db_store_parameterized.delete('task-metadata-test-2') - await db_store_parameterized.delete('task-metadata-test-3') - await db_store_parameterized.delete('task-metadata-test-4') + await db_store_parameterized.delete('task-metadata-test-1', TEST_CONTEXT) + await db_store_parameterized.delete('task-metadata-test-2', TEST_CONTEXT) + await db_store_parameterized.delete('task-metadata-test-3', TEST_CONTEXT) + await db_store_parameterized.delete('task-metadata-test-4', TEST_CONTEXT) @pytest.mark.asyncio @@ -874,7 +913,7 @@ async def test_core_to_0_3_model_conversion( ) # 1. Save the task (will use core_to_compat_task_model) - await store.save(original_task) + await store.save(original_task, TEST_CONTEXT) # 2. Verify it's stored in v0.3 format directly in DB async with store.async_session_maker() as session: @@ -882,17 +921,18 @@ async def test_core_to_0_3_model_conversion( assert db_task is not None assert db_task.protocol_version == '0.3' # v0.3 status JSON uses string for state + assert isinstance(db_task.status, dict) assert db_task.status['state'] == 'working' # 3. Retrieve the task (will use compat_task_model_to_core) - retrieved_task = await store.get(task_id) + retrieved_task = await store.get(task_id, context=TEST_CONTEXT) assert retrieved_task is not None assert retrieved_task.id == original_task.id assert retrieved_task.status.state == TaskState.TASK_STATE_WORKING assert dict(retrieved_task.metadata) == {'key': 'value'} # Reset conversion attributes store.core_to_model_conversion = None - await store.delete('v03-persistence-task') + await store.delete('v03-persistence-task', TEST_CONTEXT) # Ensure aiosqlite, asyncpg, and aiomysql are installed in the test environment (added to pyproject.toml). diff --git a/tests/server/tasks/test_inmemory_task_store.py b/tests/server/tasks/test_inmemory_task_store.py index af3531e33..f04a69170 100644 --- a/tests/server/tasks/test_inmemory_task_store.py +++ b/tests/server/tasks/test_inmemory_task_store.py @@ -25,6 +25,9 @@ def user_name(self) -> str: return self._user_name +TEST_CONTEXT = ServerCallContext(user=SampleUser('test_user')) + + def create_minimal_task( task_id: str = 'task-abc', context_id: str = 'session-xyz' ) -> Task: @@ -41,8 +44,8 @@ async def test_in_memory_task_store_save_and_get() -> None: """Test saving and retrieving a task from the in-memory store.""" store = InMemoryTaskStore() task = create_minimal_task() - await store.save(task) - retrieved_task = await store.get('task-abc') + await store.save(task, TEST_CONTEXT) + retrieved_task = await store.get('task-abc', TEST_CONTEXT) assert retrieved_task == task @@ -50,7 +53,7 @@ async def test_in_memory_task_store_save_and_get() -> None: async def test_in_memory_task_store_get_nonexistent() -> None: """Test retrieving a nonexistent task.""" store = InMemoryTaskStore() - retrieved_task = await store.get('nonexistent') + retrieved_task = await store.get('nonexistent', TEST_CONTEXT) assert retrieved_task is None @@ -179,9 +182,9 @@ async def test_list_tasks( ), ] for task in tasks_to_create: - await store.save(task) + await store.save(task, TEST_CONTEXT) - page = await store.list(params) + page = await store.list(params, TEST_CONTEXT) retrieved_ids = [task.id for task in page.tasks] assert retrieved_ids == expected_ids @@ -191,7 +194,7 @@ async def test_list_tasks( # Cleanup for task in tasks_to_create: - await store.delete(task.id) + await store.delete(task.id, TEST_CONTEXT) @pytest.mark.asyncio @@ -238,16 +241,16 @@ async def test_list_tasks_fails( ), ] for task in tasks_to_create: - await store.save(task) + await store.save(task, TEST_CONTEXT) with pytest.raises(InvalidParamsError) as excinfo: - await store.list(params) + await store.list(params, TEST_CONTEXT) assert expected_error_message in str(excinfo.value) # Cleanup for task in tasks_to_create: - await store.delete(task.id) + await store.delete(task.id, TEST_CONTEXT) @pytest.mark.asyncio @@ -255,9 +258,9 @@ async def test_in_memory_task_store_delete() -> None: """Test deleting a task from the store.""" store = InMemoryTaskStore() task = create_minimal_task() - await store.save(task) - await store.delete('task-abc') - retrieved_task = await store.get('task-abc') + await store.save(task, TEST_CONTEXT) + await store.delete('task-abc', TEST_CONTEXT) + retrieved_task = await store.get('task-abc', TEST_CONTEXT) assert retrieved_task is None @@ -265,7 +268,7 @@ async def test_in_memory_task_store_delete() -> None: async def test_in_memory_task_store_delete_nonexistent() -> None: """Test deleting a nonexistent task.""" store = InMemoryTaskStore() - await store.delete('nonexistent') + await store.delete('nonexistent', TEST_CONTEXT) @pytest.mark.asyncio @@ -341,10 +344,10 @@ async def test_inmemory_task_store_copying_behavior(use_copying: bool): original_task = Task( id='test_task', status=TaskStatus(state=TaskState.TASK_STATE_WORKING) ) - await store.save(original_task) + await store.save(original_task, TEST_CONTEXT) # Retrieve it - retrieved_task = await store.get('test_task') + retrieved_task = await store.get('test_task', TEST_CONTEXT) assert retrieved_task is not None if use_copying: @@ -356,7 +359,7 @@ async def test_inmemory_task_store_copying_behavior(use_copying: bool): retrieved_task.status.state = TaskState.TASK_STATE_COMPLETED # Retrieve it again, it should NOT be modified in the store if use_copying=True - retrieved_task_2 = await store.get('test_task') + retrieved_task_2 = await store.get('test_task', TEST_CONTEXT) assert retrieved_task_2 is not None if use_copying: diff --git a/tests/server/tasks/test_task_manager.py b/tests/server/tasks/test_task_manager.py index 381f71593..bdfbf525c 100644 --- a/tests/server/tasks/test_task_manager.py +++ b/tests/server/tasks/test_task_manager.py @@ -3,8 +3,9 @@ import pytest +from a2a.auth.user import User +from a2a.server.context import ServerCallContext from a2a.server.tasks import TaskManager -from a2a.utils.errors import InvalidParamsError from a2a.types.a2a_pb2 import ( Artifact, Message, @@ -19,6 +20,24 @@ from a2a.utils.errors import InvalidParamsError +class SampleUser(User): + """A test implementation of the User interface.""" + + def __init__(self, user_name: str): + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return self._user_name + + +TEST_CONTEXT = ServerCallContext(user=SampleUser('test_user')) + + # Create proto task instead of dict def create_minimal_task( task_id: str = 'task-abc', @@ -49,6 +68,7 @@ def task_manager(mock_task_store: AsyncMock) -> TaskManager: context_id=MINIMAL_CONTEXT_ID, task_store=mock_task_store, initial_message=None, + context=TEST_CONTEXT, ) @@ -63,6 +83,7 @@ def test_task_manager_invalid_task_id( context_id='test_context', task_store=mock_task_store, initial_message=None, + context=TEST_CONTEXT, ) @@ -75,7 +96,7 @@ async def test_get_task_existing( mock_task_store.get.return_value = expected_task retrieved_task = await task_manager.get_task() assert retrieved_task == expected_task - mock_task_store.get.assert_called_once_with(MINIMAL_TASK_ID, None) + mock_task_store.get.assert_called_once_with(MINIMAL_TASK_ID, TEST_CONTEXT) @pytest.mark.asyncio @@ -86,7 +107,7 @@ async def test_get_task_nonexistent( mock_task_store.get.return_value = None retrieved_task = await task_manager.get_task() assert retrieved_task is None - mock_task_store.get.assert_called_once_with(MINIMAL_TASK_ID, None) + mock_task_store.get.assert_called_once_with(MINIMAL_TASK_ID, TEST_CONTEXT) @pytest.mark.asyncio @@ -96,7 +117,7 @@ async def test_save_task_event_new_task( """Test saving a new task.""" task = create_minimal_task() await task_manager.save_task_event(task) - mock_task_store.save.assert_called_once_with(task, None) + mock_task_store.save.assert_called_once_with(task, TEST_CONTEXT) @pytest.mark.asyncio @@ -188,7 +209,7 @@ async def test_ensure_task_existing( ) retrieved_task = await task_manager.ensure_task(event) assert retrieved_task == expected_task - mock_task_store.get.assert_called_once_with(MINIMAL_TASK_ID, None) + mock_task_store.get.assert_called_once_with(MINIMAL_TASK_ID, TEST_CONTEXT) @pytest.mark.asyncio @@ -202,6 +223,7 @@ async def test_ensure_task_nonexistent( context_id=None, task_store=mock_task_store, initial_message=None, + context=TEST_CONTEXT, ) event = TaskStatusUpdateEvent( task_id='new-task', @@ -212,7 +234,7 @@ async def test_ensure_task_nonexistent( assert new_task.id == 'new-task' assert new_task.context_id == 'some-context' assert new_task.status.state == TaskState.TASK_STATE_SUBMITTED - mock_task_store.save.assert_called_once_with(new_task, None) + mock_task_store.save.assert_called_once_with(new_task, TEST_CONTEXT) assert task_manager_without_id.task_id == 'new-task' assert task_manager_without_id.context_id == 'some-context' @@ -233,7 +255,7 @@ async def test_save_task( """Test saving a task.""" task = create_minimal_task() await task_manager._save_task(task) # type: ignore - mock_task_store.save.assert_called_once_with(task, None) + mock_task_store.save.assert_called_once_with(task, TEST_CONTEXT) @pytest.mark.asyncio @@ -263,6 +285,7 @@ async def test_save_task_event_new_task_no_task_id( context_id=None, task_store=mock_task_store, initial_message=None, + context=TEST_CONTEXT, ) task = Task( id='new-task-id', @@ -270,7 +293,7 @@ async def test_save_task_event_new_task_no_task_id( status=TaskStatus(state=TaskState.TASK_STATE_WORKING), ) await task_manager_without_id.save_task_event(task) - mock_task_store.save.assert_called_once_with(task, None) + mock_task_store.save.assert_called_once_with(task, TEST_CONTEXT) assert task_manager_without_id.task_id == 'new-task-id' assert task_manager_without_id.context_id == 'some-context' # initial submit should be updated to working @@ -287,6 +310,7 @@ async def test_get_task_no_task_id( context_id='some-context', task_store=mock_task_store, initial_message=None, + context=TEST_CONTEXT, ) retrieved_task = await task_manager_without_id.get_task() assert retrieved_task is None @@ -303,6 +327,7 @@ async def test_save_task_event_no_task_existing( context_id=None, task_store=mock_task_store, initial_message=None, + context=TEST_CONTEXT, ) mock_task_store.get.return_value = None event = TaskStatusUpdateEvent( diff --git a/tests/server/test_owner_resolver.py b/tests/server/test_owner_resolver.py index 5bac5c605..dffee863e 100644 --- a/tests/server/test_owner_resolver.py +++ b/tests/server/test_owner_resolver.py @@ -19,13 +19,13 @@ def user_name(self) -> str: return self._user_name -def test_resolve_user_scope_valid_user(): - """Test resolve_user_scope with a valid user in the context.""" +def test_resolve_user_scope_with_authenticated_user(): + """Test resolve_user_scope with an authenticated user in the context.""" user = SampleUser(user_name='SampleUser') context = ServerCallContext(user=user) assert resolve_user_scope(context) == 'SampleUser' -def test_resolve_user_scope_no_context(): - """Test resolve_user_scope when the context is None.""" - assert resolve_user_scope(None) == 'unknown' +def test_resolve_user_default_context(): + """Test resolve_user_scope with default context.""" + assert resolve_user_scope(ServerCallContext()) == ''