Skip to content

Commit ea7d3ad

Browse files
knapgishymko
andauthored
feat(rest): update REST error handling to use google.rpc.Status (#838)
# Description This PR refactors the REST transport error handling to adhere to the `google.rpc.Status` JSON format. Both the server-side exception handlers and the client-side REST transport have been updated to utilize the new standard error envelope, ensuring consistency across A2A REST APIs. ## Summary of Changes * **Server:** * Updated `rest_error_handler` and the global `StarletteHTTPException` handler in `A2ARESTFastAPIApplication` to return errors wrapped in an `{'error': {...}}` envelope. * Payloads now correctly include the HTTP `code`, gRPC `status`, `message`, and a `details` array containing `type.googleapis.com/google.rpc.ErrorInfo` for the specific reason and metadata. * **Client:** * Modified `RestTransport._handle_http_error` to parse the new format. It now gracefully extracts the `reason` from the `ErrorInfo` detail object to map it back to the corresponding Python `A2AError` class. * **Core/Utils:** * Introduced `A2A_REST_ERROR_MAPPING` in `errors.py` to centralize the mapping of Python exceptions to their respective HTTP status codes, gRPC statuses, and string reasons. * Added a `data` attribute to the base `A2AError` to carry arbitrary error metadata. * **Tests:** * Updated REST client, server, and error handler tests to validate the new nested `{'error': {...}}` JSON payload structures. - [X] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). - [X] Make your Pull Request title in the <https://www.conventionalcommits.org/> specification. - Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. - [X] Ensure the tests and linter pass (Run `bash scripts/format.sh` from the repository root to format) - [X] Appropriate docs were updated (if necessary) Fixes #722 🦕 --------- Co-authored-by: Ivan Shymko <ishymko@google.com>
1 parent cac6f58 commit ea7d3ad

File tree

8 files changed

+265
-114
lines changed

8 files changed

+265
-114
lines changed

.github/actions/spelling/allow.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ agentic
1414
AGrpc
1515
aio
1616
aiomysql
17+
AIP
1718
alg
1819
amannn
1920
aproject

src/a2a/client/transports/rest.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,12 @@
3434
Task,
3535
TaskPushNotificationConfig,
3636
)
37-
from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP, MethodNotFoundError
37+
from a2a.utils.errors import A2A_REASON_TO_ERROR, MethodNotFoundError
3838
from a2a.utils.telemetry import SpanKind, trace_class
3939

4040

4141
logger = logging.getLogger(__name__)
4242

43-
_A2A_ERROR_NAME_TO_CLS = {
44-
error_type.__name__: error_type for error_type in JSON_RPC_ERROR_CODE_MAP
45-
}
46-
4743

4844
@trace_class(kind=SpanKind.CLIENT)
4945
class RestTransport(ClientTransport):
@@ -297,15 +293,36 @@ def _get_path(self, base_path: str, tenant: str) -> str:
297293
def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn:
298294
"""Handles HTTP status errors and raises the appropriate A2AError."""
299295
try:
300-
error_data = e.response.json()
301-
error_type = error_data.get('type')
302-
message = error_data.get('message', str(e))
296+
error_payload = e.response.json()
297+
error_data = error_payload.get('error', {})
303298

304-
if isinstance(error_type, str):
305-
# TODO(#723): Resolving imports by name is temporary until proper error handling structure is added in #723.
306-
exception_cls = _A2A_ERROR_NAME_TO_CLS.get(error_type)
299+
message = error_data.get('message', str(e))
300+
details = error_data.get('details', [])
301+
if not isinstance(details, list):
302+
details = []
303+
304+
# The `details` array can contain multiple different error objects.
305+
# We extract the first `ErrorInfo` object because it contains the
306+
# specific `reason` code needed to map this back to a Python A2AError.
307+
error_info = {}
308+
for d in details:
309+
if (
310+
isinstance(d, dict)
311+
and d.get('@type')
312+
== 'type.googleapis.com/google.rpc.ErrorInfo'
313+
):
314+
error_info = d
315+
break
316+
reason = error_info.get('reason')
317+
metadata = error_info.get('metadata') or {}
318+
319+
if isinstance(reason, str):
320+
exception_cls = A2A_REASON_TO_ERROR.get(reason)
307321
if exception_cls:
308-
raise exception_cls(message) from e
322+
exc = exception_cls(message)
323+
if metadata:
324+
exc.data = metadata
325+
raise exc from e
309326
except (json.JSONDecodeError, ValueError):
310327
pass
311328

src/a2a/server/apps/rest/fastapi_app.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,22 @@
77
if TYPE_CHECKING:
88
from fastapi import APIRouter, FastAPI, Request, Response
99
from fastapi.responses import JSONResponse
10+
from starlette.exceptions import HTTPException as StarletteHTTPException
1011

1112
_package_fastapi_installed = True
1213
else:
1314
try:
1415
from fastapi import APIRouter, FastAPI, Request, Response
1516
from fastapi.responses import JSONResponse
17+
from starlette.exceptions import HTTPException as StarletteHTTPException
1618

1719
_package_fastapi_installed = True
1820
except ImportError:
1921
APIRouter = Any
2022
FastAPI = Any
2123
Request = Any
2224
Response = Any
25+
StarletteHTTPException = Any
2326

2427
_package_fastapi_installed = False
2528

@@ -36,6 +39,23 @@
3639
logger = logging.getLogger(__name__)
3740

3841

42+
_HTTP_TO_GRPC_STATUS_MAP = {
43+
400: 'INVALID_ARGUMENT',
44+
401: 'UNAUTHENTICATED',
45+
403: 'PERMISSION_DENIED',
46+
404: 'NOT_FOUND',
47+
405: 'UNIMPLEMENTED',
48+
409: 'ALREADY_EXISTS',
49+
415: 'INVALID_ARGUMENT',
50+
422: 'INVALID_ARGUMENT',
51+
500: 'INTERNAL',
52+
501: 'UNIMPLEMENTED',
53+
502: 'INTERNAL',
54+
503: 'UNAVAILABLE',
55+
504: 'DEADLINE_EXCEEDED',
56+
}
57+
58+
3959
class A2ARESTFastAPIApplication:
4060
"""A FastAPI application implementing the A2A protocol server REST endpoints.
4161
@@ -121,6 +141,34 @@ def build(
121141
A configured FastAPI application instance.
122142
"""
123143
app = FastAPI(**kwargs)
144+
145+
@app.exception_handler(StarletteHTTPException)
146+
async def http_exception_handler(
147+
request: Request, exc: StarletteHTTPException
148+
) -> Response:
149+
"""Catches framework-level HTTP exceptions.
150+
151+
For example, 404 Not Found for bad routes, 422 Unprocessable Entity
152+
for schema validation, and formats them into the A2A standard
153+
google.rpc.Status JSON format (AIP-193).
154+
"""
155+
grpc_status = _HTTP_TO_GRPC_STATUS_MAP.get(
156+
exc.status_code, 'UNKNOWN'
157+
)
158+
return JSONResponse(
159+
status_code=exc.status_code,
160+
content={
161+
'error': {
162+
'code': exc.status_code,
163+
'status': grpc_status,
164+
'message': str(exc.detail)
165+
if hasattr(exc, 'detail')
166+
else 'HTTP Exception',
167+
}
168+
},
169+
media_type='application/json',
170+
)
171+
124172
if self.enable_v0_3_compat and self._v03_adapter:
125173
v03_adapter = self._v03_adapter
126174
v03_router = APIRouter()

src/a2a/utils/error_handlers.py

Lines changed: 59 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33

44
from collections.abc import Awaitable, Callable, Coroutine
5-
from typing import TYPE_CHECKING, Any, cast
5+
from typing import TYPE_CHECKING, Any
66

77

88
if TYPE_CHECKING:
@@ -17,70 +17,40 @@
1717

1818
from google.protobuf.json_format import ParseError
1919

20-
from a2a.server.jsonrpc_models import (
21-
InternalError as JSONRPCInternalError,
22-
)
23-
from a2a.server.jsonrpc_models import (
24-
JSONParseError,
25-
JSONRPCError,
26-
)
2720
from a2a.utils.errors import (
21+
A2A_REST_ERROR_MAPPING,
2822
A2AError,
29-
ContentTypeNotSupportedError,
30-
ExtendedAgentCardNotConfiguredError,
31-
ExtensionSupportRequiredError,
3223
InternalError,
33-
InvalidAgentResponseError,
34-
InvalidParamsError,
35-
InvalidRequestError,
36-
MethodNotFoundError,
37-
PushNotificationNotSupportedError,
38-
TaskNotCancelableError,
39-
TaskNotFoundError,
40-
UnsupportedOperationError,
41-
VersionNotSupportedError,
24+
RestErrorMap,
4225
)
4326

4427

4528
logger = logging.getLogger(__name__)
4629

47-
_A2AErrorType = (
48-
type[JSONRPCError]
49-
| type[JSONParseError]
50-
| type[InvalidRequestError]
51-
| type[MethodNotFoundError]
52-
| type[InvalidParamsError]
53-
| type[InternalError]
54-
| type[JSONRPCInternalError]
55-
| type[TaskNotFoundError]
56-
| type[TaskNotCancelableError]
57-
| type[PushNotificationNotSupportedError]
58-
| type[UnsupportedOperationError]
59-
| type[ContentTypeNotSupportedError]
60-
| type[InvalidAgentResponseError]
61-
| type[ExtendedAgentCardNotConfiguredError]
62-
| type[ExtensionSupportRequiredError]
63-
| type[VersionNotSupportedError]
64-
)
6530

66-
A2AErrorToHttpStatus: dict[_A2AErrorType, int] = {
67-
JSONRPCError: 500,
68-
JSONParseError: 400,
69-
InvalidRequestError: 400,
70-
MethodNotFoundError: 404,
71-
InvalidParamsError: 422,
72-
InternalError: 500,
73-
JSONRPCInternalError: 500,
74-
TaskNotFoundError: 404,
75-
TaskNotCancelableError: 409,
76-
PushNotificationNotSupportedError: 501,
77-
UnsupportedOperationError: 501,
78-
ContentTypeNotSupportedError: 415,
79-
InvalidAgentResponseError: 502,
80-
ExtendedAgentCardNotConfiguredError: 400,
81-
ExtensionSupportRequiredError: 400,
82-
VersionNotSupportedError: 400,
83-
}
31+
def _build_error_payload(
32+
code: int,
33+
status: str,
34+
message: str,
35+
reason: str | None = None,
36+
metadata: dict[str, Any] | None = None,
37+
) -> dict[str, Any]:
38+
"""Helper function to build the JSON error payload."""
39+
payload: dict[str, Any] = {
40+
'code': code,
41+
'status': status,
42+
'message': message,
43+
}
44+
if reason:
45+
payload['details'] = [
46+
{
47+
'@type': 'type.googleapis.com/google.rpc.ErrorInfo',
48+
'reason': reason,
49+
'domain': 'a2a-protocol.org',
50+
'metadata': metadata if metadata is not None else {},
51+
}
52+
]
53+
return {'error': payload}
8454

8555

8656
def rest_error_handler(
@@ -93,9 +63,12 @@ async def wrapper(*args: Any, **kwargs: Any) -> Response:
9363
try:
9464
return await func(*args, **kwargs)
9565
except A2AError as error:
96-
http_code = A2AErrorToHttpStatus.get(
97-
cast('_A2AErrorType', type(error)), 500
66+
mapping = A2A_REST_ERROR_MAPPING.get(
67+
type(error), RestErrorMap(500, 'INTERNAL', 'INTERNAL_ERROR')
9868
)
69+
http_code = mapping.http_code
70+
grpc_status = mapping.grpc_status
71+
reason = mapping.reason
9972

10073
log_level = (
10174
logging.ERROR
@@ -107,32 +80,46 @@ async def wrapper(*args: Any, **kwargs: Any) -> Response:
10780
"Request error: Code=%s, Message='%s'%s",
10881
getattr(error, 'code', 'N/A'),
10982
getattr(error, 'message', str(error)),
110-
', Data=' + str(getattr(error, 'data', ''))
111-
if getattr(error, 'data', None)
112-
else '',
83+
f', Data={error.data}' if error.data else '',
11384
)
114-
# TODO(#722): Standardize error response format.
85+
86+
# SECURITY WARNING: Data attached to A2AError.data is serialized unaltered and exposed publicly to the client in the REST API response.
87+
metadata = getattr(error, 'data', None) or {}
88+
11589
return JSONResponse(
116-
content={
117-
'message': getattr(error, 'message', str(error)),
118-
'type': type(error).__name__,
119-
},
90+
content=_build_error_payload(
91+
code=http_code,
92+
status=grpc_status,
93+
message=getattr(error, 'message', str(error)),
94+
reason=reason,
95+
metadata=metadata,
96+
),
12097
status_code=http_code,
98+
media_type='application/json',
12199
)
122100
except ParseError as error:
123101
logger.warning('Parse error: %s', str(error))
124102
return JSONResponse(
125-
content={
126-
'message': str(error),
127-
'type': 'ParseError',
128-
},
103+
content=_build_error_payload(
104+
code=400,
105+
status='INVALID_ARGUMENT',
106+
message=str(error),
107+
reason='INVALID_REQUEST',
108+
metadata={},
109+
),
129110
status_code=400,
111+
media_type='application/json',
130112
)
131113
except Exception:
132114
logger.exception('Unknown error occurred')
133115
return JSONResponse(
134-
content={'message': 'unknown exception', 'type': 'Exception'},
116+
content=_build_error_payload(
117+
code=500,
118+
status='INTERNAL',
119+
message='unknown exception',
120+
),
135121
status_code=500,
122+
media_type='application/json',
136123
)
137124

138125
return wrapper
@@ -158,9 +145,7 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
158145
"Request error: Code=%s, Message='%s'%s",
159146
getattr(error, 'code', 'N/A'),
160147
getattr(error, 'message', str(error)),
161-
', Data=' + str(getattr(error, 'data', ''))
162-
if getattr(error, 'data', None)
163-
else '',
148+
f', Data={error.data}' if error.data else '',
164149
)
165150
# Since the stream has started, we can't return a JSONResponse.
166151
# Instead, we run the error handling logic (provides logging)

0 commit comments

Comments
 (0)