Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
2a8668e
Merge branch '1.0-dev' of https://github.com/sokoliva/a2a-python into…
sokoliva Mar 20, 2026
2f94c4e
Merge branch '1.0-dev' of https://github.com/sokoliva/a2a-python into…
sokoliva Mar 20, 2026
72a8a7f
feat: Enforce ServerCallContext in all TaskStore operations
sokoliva Mar 20, 2026
93e3e20
make ServerCallCOntext mandatory in task manager
sokoliva Mar 25, 2026
dc49636
Merge branch '1.0-dev' into Server-Call-Context-Mandatory
sokoliva Mar 25, 2026
2af7880
Make ServerCallContext mandatory in CopyingTaskStore
sokoliva Mar 25, 2026
fbab840
format
sokoliva Mar 25, 2026
cdeb9f9
Merge branch '1.0-dev' into Server-Call-Context-Mandatory
sokoliva Mar 25, 2026
4c9089c
Merge branch '1.0-dev' into Server-Call-Context-Mandatory
sokoliva Mar 25, 2026
3541781
move mandatory parameters to top
sokoliva Mar 26, 2026
ae9cec9
Merge branch 'Server-Call-Context-Mandatory' of https://github.com/so…
sokoliva Mar 26, 2026
a21a16a
Merge branch '1.0-dev' into Server-Call-Context-Mandatory
sokoliva Mar 26, 2026
50cf655
Make `ServerCallContext` mandatory in `RequestContext`
sokoliva Mar 26, 2026
cf3c9fe
Merge branch '1.0-dev' into Server-Call-Context-Mandatory
sokoliva Mar 26, 2026
fddfec7
fix
sokoliva Mar 26, 2026
e77f8d9
Merge branch 'Server-Call-Context-Mandatory' of https://github.com/so…
sokoliva Mar 26, 2026
4cd6708
make ServerCallCOntext mandatory
sokoliva Mar 26, 2026
006cff8
test descr
sokoliva Mar 26, 2026
b9136dc
Merge branch '1.0-dev' into Server-Call-Context-Mandatory
sokoliva Mar 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions src/a2a/contrib/tasks/vertex_task_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ def __init__(
self._client = client
self._agent_engine_resource_id = agent_engine_resource_id

async def save(
self, task: Task, context: ServerCallContext | None = None
) -> None:
async def save(self, task: Task, context: ServerCallContext) -> None:
"""Saves or updates a task in the store."""
compat_task = to_compat_task(task)
previous_task = await self._get_stored_task(compat_task.id)
Expand Down Expand Up @@ -206,7 +204,7 @@ async def _get_stored_task(
return a2a_task

async def get(
self, task_id: str, context: ServerCallContext | None = None
self, task_id: str, context: ServerCallContext
) -> Task | None:
"""Retrieves a task from the database by ID."""
a2a_task = await self._get_stored_task(task_id)
Expand All @@ -217,13 +215,11 @@ async def get(
async def list(
self,
params: ListTasksRequest,
context: ServerCallContext | None = None,
context: ServerCallContext,
) -> ListTasksResponse:
"""Retrieves a list of tasks from the store."""
raise NotImplementedError

async def delete(
self, task_id: str, context: ServerCallContext | None = None
) -> None:
async def delete(self, task_id: str, context: ServerCallContext) -> None:
"""The backend doesn't support deleting tasks, so this is not implemented."""
raise NotImplementedError
19 changes: 7 additions & 12 deletions src/a2a/server/agent_execution/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,35 +26,35 @@ class RequestContext:

def __init__( # noqa: PLR0913
self,
call_context: ServerCallContext,
request: SendMessageRequest | None = None,
task_id: str | None = None,
context_id: str | None = None,
task: Task | None = None,
related_tasks: list[Task] | None = None,
call_context: ServerCallContext | None = None,
task_id_generator: IDGenerator | None = None,
context_id_generator: IDGenerator | None = None,
):
"""Initializes the RequestContext.

Args:
call_context: The server call context associated with this request.
request: The incoming `SendMessageRequest` request payload.
task_id: The ID of the task explicitly provided in the request or path.
context_id: The ID of the context explicitly provided in the request or path.
task: The existing `Task` object retrieved from the store, if any.
related_tasks: A list of other tasks related to the current request (e.g., for tool use).
call_context: The server call context associated with this request.
task_id_generator: ID generator for new task IDs. Defaults to UUID generator.
context_id_generator: ID generator for new context IDs. Defaults to UUID generator.
"""
if related_tasks is None:
related_tasks = []
self._call_context = call_context
self._params = request
self._task_id = task_id
self._context_id = context_id
self._current_task = task
self._related_tasks = related_tasks
self._call_context = call_context
self._task_id_generator = (
task_id_generator if task_id_generator else UUIDGenerator()
)
Expand Down Expand Up @@ -140,7 +140,7 @@ def configuration(self) -> SendMessageConfiguration | None:
return self._params.configuration if self._params else None

@property
def call_context(self) -> ServerCallContext | None:
def call_context(self) -> ServerCallContext:
"""The server call context associated with this request."""
return self._call_context

Expand All @@ -157,22 +157,17 @@ def add_activated_extension(self, uri: str) -> None:
This causes the extension to be indicated back to the client in the
response.
"""
if self._call_context:
self._call_context.activated_extensions.add(uri)
self._call_context.activated_extensions.add(uri)

@property
def tenant(self) -> str:
"""The tenant associated with this request."""
return self._call_context.tenant if self._call_context else ''
return self._call_context.tenant

@property
def requested_extensions(self) -> set[str]:
"""Extensions that the client requested to activate."""
return (
self._call_context.requested_extensions
if self._call_context
else set()
)
return self._call_context.requested_extensions

def _check_or_generate_task_id(self) -> None:
"""Ensures a task ID is present, generating one if necessary."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ class RequestContextBuilder(ABC):
@abstractmethod
async def build(
self,
context: ServerCallContext,
params: SendMessageRequest | None = None,
task_id: str | None = None,
context_id: str | None = None,
task: Task | None = None,
context: ServerCallContext | None = None,
) -> RequestContext:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ def __init__(

async def build(
self,
context: ServerCallContext,
params: SendMessageRequest | None = None,
task_id: str | None = None,
context_id: str | None = None,
task: Task | None = None,
context: ServerCallContext | None = None,
) -> RequestContext:
"""Builds the request context for an agent execution.

Expand All @@ -48,11 +48,11 @@ async def build(
referenced in `params.message.reference_task_ids` from the `task_store`.

Args:
context: The server call context, containing metadata about the call.
params: The parameters of the incoming message send request.
task_id: The ID of the task being executed.
context_id: The ID of the current execution context.
task: The primary task object associated with the request.
context: The server call context, containing metadata about the call.

Returns:
An instance of RequestContext populated with the provided information
Expand All @@ -68,19 +68,19 @@ async def build(
):
tasks = await asyncio.gather(
*[
self._task_store.get(task_id)
self._task_store.get(task_id, context)
for task_id in params.message.reference_task_ids
]
)
related_tasks = [x for x in tasks if x is not None]

return RequestContext(
call_context=context,
request=params,
task_id=task_id,
context_id=context_id,
task=task,
related_tasks=related_tasks,
call_context=context,
task_id_generator=self._task_id_generator,
context_id_generator=self._context_id_generator,
)
7 changes: 2 additions & 5 deletions src/a2a/server/owner_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@


# Definition
OwnerResolver = Callable[[ServerCallContext | None], str]
OwnerResolver = Callable[[ServerCallContext], str]


# Example Default Implementation
def resolve_user_scope(context: ServerCallContext | None) -> str:
def resolve_user_scope(context: ServerCallContext) -> str:
"""Resolves the owner scope based on the user in the context."""
if not context:
return 'unknown'
# Example: Basic user name. Adapt as needed for your user model.
return context.user.user_name
18 changes: 7 additions & 11 deletions src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,8 @@ async def on_cancel_task(

await self.agent_executor.cancel(
RequestContext(
None,
call_context=context,
request=None,
task_id=task.id,
context_id=task.context_id,
task=task,
Expand Down Expand Up @@ -290,7 +291,7 @@ async def _setup_message_execution(
await self._push_config_store.set_info(
task_id,
params.configuration.task_push_notification_config,
context or ServerCallContext(),
context,
)

queue = await self._queue_manager.create_or_tap(task_id)
Expand Down Expand Up @@ -504,7 +505,7 @@ async def on_create_task_push_notification_config(
await self._push_config_store.set_info(
task_id,
params,
context or ServerCallContext(),
context,
)

return params
Expand All @@ -529,10 +530,7 @@ async def on_get_task_push_notification_config(
raise TaskNotFoundError

push_notification_configs: list[TaskPushNotificationConfig] = (
await self._push_config_store.get_info(
task_id, context or ServerCallContext()
)
or []
await self._push_config_store.get_info(task_id, context) or []
)

for config in push_notification_configs:
Expand Down Expand Up @@ -603,7 +601,7 @@ async def on_list_task_push_notification_configs(
raise TaskNotFoundError

push_notification_config_list = await self._push_config_store.get_info(
task_id, context or ServerCallContext()
task_id, context
)

return ListTaskPushNotificationConfigsResponse(
Expand All @@ -629,6 +627,4 @@ async def on_delete_task_push_notification_config(
if not task:
raise TaskNotFoundError

await self._push_config_store.delete_info(
task_id, context or ServerCallContext(), config_id
)
await self._push_config_store.delete_info(task_id, context, config_id)
12 changes: 4 additions & 8 deletions src/a2a/server/tasks/copying_task_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,14 @@ class CopyingTaskStoreAdapter(TaskStore):
def __init__(self, underlying_store: TaskStore):
self._store = underlying_store

async def save(
self, task: Task, context: ServerCallContext | None = None
) -> None:
async def save(self, task: Task, context: ServerCallContext) -> None:
"""Saves a copy of the task to the underlying store."""
task_copy = Task()
task_copy.CopyFrom(task)
await self._store.save(task_copy, context)

async def get(
self, task_id: str, context: ServerCallContext | None = None
self, task_id: str, context: ServerCallContext
) -> Task | None:
"""Retrieves a task from the underlying store and returns a copy."""
task = await self._store.get(task_id, context)
Expand All @@ -46,16 +44,14 @@ async def get(
async def list(
self,
params: ListTasksRequest,
context: ServerCallContext | None = None,
context: ServerCallContext,
) -> ListTasksResponse:
"""Retrieves a list of tasks from the underlying store and returns a copy."""
response = await self._store.list(params, context)
response_copy = ListTasksResponse()
response_copy.CopyFrom(response)
return response_copy

async def delete(
self, task_id: str, context: ServerCallContext | None = None
) -> None:
async def delete(self, task_id: str, context: ServerCallContext) -> None:
"""Deletes a task from the underlying store."""
await self._store.delete(task_id, context)
12 changes: 4 additions & 8 deletions src/a2a/server/tasks/database_task_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,7 @@ def _from_orm(self, task_model: TaskModel) -> Task:
# Legacy conversion
return compat_task_model_to_core(task_model)

async def save(
self, task: Task, context: ServerCallContext | None = None
) -> None:
async def save(self, task: Task, context: ServerCallContext) -> None:
"""Saves or updates a task in the database for the resolved owner."""
await self._ensure_initialized()
owner = self.owner_resolver(context)
Expand All @@ -185,7 +183,7 @@ async def save(
)

async def get(
self, task_id: str, context: ServerCallContext | None = None
self, task_id: str, context: ServerCallContext
) -> Task | None:
"""Retrieves a task from the database by ID, for the given owner."""
await self._ensure_initialized()
Expand Down Expand Up @@ -216,7 +214,7 @@ async def get(
async def list(
self,
params: a2a_pb2.ListTasksRequest,
context: ServerCallContext | None = None,
context: ServerCallContext,
) -> a2a_pb2.ListTasksResponse:
"""Retrieves tasks from the database based on provided parameters, for the given owner."""
await self._ensure_initialized()
Expand Down Expand Up @@ -315,9 +313,7 @@ async def list(
page_size=page_size,
)

async def delete(
self, task_id: str, context: ServerCallContext | None = None
) -> None:
async def delete(self, task_id: str, context: ServerCallContext) -> None:
"""Deletes a task from the database by ID, for the given owner."""
await self._ensure_initialized()
owner = self.owner_resolver(context)
Expand Down
24 changes: 8 additions & 16 deletions src/a2a/server/tasks/inmemory_task_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ def __init__(
def _get_owner_tasks(self, owner: str) -> dict[str, Task]:
return self.tasks.get(owner, {})

async def save(
self, task: Task, context: ServerCallContext | None = None
) -> None:
async def save(self, task: Task, context: ServerCallContext) -> None:
"""Saves or updates a task in the in-memory store for the resolved owner."""
owner = self.owner_resolver(context)
if owner not in self.tasks:
Expand All @@ -50,7 +48,7 @@ async def save(
)

async def get(
self, task_id: str, context: ServerCallContext | None = None
self, task_id: str, context: ServerCallContext
) -> Task | None:
"""Retrieves a task from the in-memory store by ID, for the given owner."""
owner = self.owner_resolver(context)
Expand All @@ -77,7 +75,7 @@ async def get(
async def list(
self,
params: a2a_pb2.ListTasksRequest,
context: ServerCallContext | None = None,
context: ServerCallContext,
) -> a2a_pb2.ListTasksResponse:
"""Retrieves a list of tasks from the store, for the given owner."""
owner = self.owner_resolver(context)
Expand Down Expand Up @@ -156,9 +154,7 @@ async def list(
page_size=page_size,
)

async def delete(
self, task_id: str, context: ServerCallContext | None = None
) -> None:
async def delete(self, task_id: str, context: ServerCallContext) -> None:
"""Deletes a task from the in-memory store by ID, for the given owner."""
owner = self.owner_resolver(context)
async with self.lock:
Expand Down Expand Up @@ -211,28 +207,24 @@ def __init__(
CopyingTaskStoreAdapter(self._impl) if use_copying else self._impl
)

async def save(
self, task: Task, context: ServerCallContext | None = None
) -> None:
async def save(self, task: Task, context: ServerCallContext) -> None:
"""Saves or updates a task in the store."""
await self._store.save(task, context)

async def get(
self, task_id: str, context: ServerCallContext | None = None
self, task_id: str, context: ServerCallContext
) -> Task | None:
"""Retrieves a task from the store by ID."""
return await self._store.get(task_id, context)

async def list(
self,
params: a2a_pb2.ListTasksRequest,
context: ServerCallContext | None = None,
context: ServerCallContext,
) -> a2a_pb2.ListTasksResponse:
"""Retrieves a list of tasks from the store."""
return await self._store.list(params, context)

async def delete(
self, task_id: str, context: ServerCallContext | None = None
) -> None:
async def delete(self, task_id: str, context: ServerCallContext) -> None:
"""Deletes a task from the store by ID."""
await self._store.delete(task_id, context)
Loading
Loading