diff --git a/changes/8772.misc.md b/changes/8772.misc.md new file mode 100644 index 00000000000..44f581100bc --- /dev/null +++ b/changes/8772.misc.md @@ -0,0 +1 @@ +Extract `_observe` helper closure in `GQLMetricMiddleware.resolve()` to reduce code duplication diff --git a/src/ai/backend/manager/api/gql_legacy/schema.py b/src/ai/backend/manager/api/gql_legacy/schema.py index 30c6e7c15c5..f4fdc73165c 100644 --- a/src/ai/backend/manager/api/gql_legacy/schema.py +++ b/src/ai/backend/manager/api/gql_legacy/schema.py @@ -3357,40 +3357,32 @@ def resolve( operation_name = ( info.operation.name.value if info.operation.name is not None else "anonymous" ) - start = time.perf_counter() - try: - res = next(root, info, **args) - graph_ctx.metric_observer.observe_request( - operation_type=operation_type, - field_name=field_name, - parent_type=parent_type, - operation_name=operation_name, - error_code=None, - success=True, - duration=time.perf_counter() - start, - ) - except BackendAIError as e: + + def _observe(*, duration: float, error: BaseException | None = None) -> None: + match error: + case None: + error_code = None + case BackendAIError(): + error_code = error.error_code() + case _: + error_code = ErrorCode.default() graph_ctx.metric_observer.observe_request( operation_type=operation_type, field_name=field_name, parent_type=parent_type, operation_name=operation_name, - error_code=e.error_code(), - success=False, - duration=time.perf_counter() - start, + error_code=error_code, + success=error is None, + duration=duration, ) - raise e + + start = time.perf_counter() + try: + res = next(root, info, **args) + _observe(duration=time.perf_counter() - start) except BaseException as e: - graph_ctx.metric_observer.observe_request( - operation_type=operation_type, - field_name=field_name, - parent_type=parent_type, - operation_name=operation_name, - error_code=ErrorCode.default(), - success=False, - duration=time.perf_counter() - start, - ) - raise e + _observe(duration=time.perf_counter() - start, error=e) + raise return res diff --git a/tests/unit/manager/api/test_gql_metric_middleware.py b/tests/unit/manager/api/test_gql_metric_middleware.py new file mode 100644 index 00000000000..9328c709ed4 --- /dev/null +++ b/tests/unit/manager/api/test_gql_metric_middleware.py @@ -0,0 +1,135 @@ +"""Tests for GQLMetricMiddleware _observe helper.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +import pytest +from graphql import OperationType + +from ai.backend.common.exception import ( + ErrorCode, + InvalidAPIParameters, +) +from ai.backend.manager.api.gql_legacy.schema import GQLMetricMiddleware + + +@pytest.fixture +def metric_observer() -> MagicMock: + observer = MagicMock() + observer.observe_request = MagicMock() + return observer + + +@pytest.fixture +def resolve_info(metric_observer: MagicMock) -> MagicMock: + info = MagicMock() + info.context.metric_observer = metric_observer + info.operation.operation = OperationType.QUERY + info.operation.name.value = "TestQuery" + info.field_name = "test_field" + info.parent_type.name = "Query" + return info + + +@pytest.fixture +def middleware() -> GQLMetricMiddleware: + return GQLMetricMiddleware() + + +class TestGQLMetricMiddlewareSyncResolver: + """Tests for sync resolver timing in GQLMetricMiddleware.""" + + def test_sync_resolver_records_timing( + self, + middleware: GQLMetricMiddleware, + resolve_info: MagicMock, + metric_observer: MagicMock, + ) -> None: + def sync_resolver(root: Any, info: Any, **kwargs: Any) -> str: + return "sync_result" + + result = middleware.resolve(sync_resolver, None, resolve_info) + + assert result == "sync_result" + metric_observer.observe_request.assert_called_once() + call_kwargs = metric_observer.observe_request.call_args.kwargs + assert call_kwargs["success"] is True + assert call_kwargs["error_code"] is None + assert call_kwargs["duration"] >= 0 + assert call_kwargs["field_name"] == "test_field" + assert call_kwargs["parent_type"] == "Query" + assert call_kwargs["operation_name"] == "TestQuery" + + def test_sync_resolver_backend_ai_error( + self, + middleware: GQLMetricMiddleware, + resolve_info: MagicMock, + metric_observer: MagicMock, + ) -> None: + error = InvalidAPIParameters("test error") + + def sync_resolver(root: Any, info: Any, **kwargs: Any) -> str: + raise error + + with pytest.raises(InvalidAPIParameters): + middleware.resolve(sync_resolver, None, resolve_info) + + metric_observer.observe_request.assert_called_once() + call_kwargs = metric_observer.observe_request.call_args.kwargs + assert call_kwargs["success"] is False + assert call_kwargs["error_code"] == error.error_code() + + def test_sync_resolver_generic_exception( + self, + middleware: GQLMetricMiddleware, + resolve_info: MagicMock, + metric_observer: MagicMock, + ) -> None: + def sync_resolver(root: Any, info: Any, **kwargs: Any) -> str: + raise RuntimeError("unexpected") + + with pytest.raises(RuntimeError): + middleware.resolve(sync_resolver, None, resolve_info) + + metric_observer.observe_request.assert_called_once() + call_kwargs = metric_observer.observe_request.call_args.kwargs + assert call_kwargs["success"] is False + assert call_kwargs["error_code"] == ErrorCode.default() + + def test_sync_resolver_anonymous_operation( + self, + middleware: GQLMetricMiddleware, + resolve_info: MagicMock, + metric_observer: MagicMock, + ) -> None: + resolve_info.operation.name = None + + def sync_resolver(root: Any, info: Any, **kwargs: Any) -> str: + return "result" + + middleware.resolve(sync_resolver, None, resolve_info) + + call_kwargs = metric_observer.observe_request.call_args.kwargs + assert call_kwargs["operation_name"] == "anonymous" + + +class TestGQLMetricMiddlewareAsyncAnonymousOperation: + """Tests for anonymous operation handling with async resolvers.""" + + async def test_async_resolver_anonymous_operation( + self, + middleware: GQLMetricMiddleware, + resolve_info: MagicMock, + metric_observer: MagicMock, + ) -> None: + resolve_info.operation.name = None + + async def async_resolver(root: Any, info: Any, **kwargs: Any) -> str: + return "result" + + await middleware.resolve(async_resolver, None, resolve_info) + + call_kwargs = metric_observer.observe_request.call_args.kwargs + assert call_kwargs["operation_name"] == "anonymous"