Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/8772.misc.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Extract `_observe` helper closure in `GQLMetricMiddleware.resolve()` to reduce code duplication
46 changes: 19 additions & 27 deletions src/ai/backend/manager/api/gql_legacy/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3357,40 +3357,32 @@ def resolve(
operation_name = (
info.operation.name.value if info.operation.name is not None else "anonymous"
)
start = time.perf_counter()
try:
res = next(root, info, **args)
graph_ctx.metric_observer.observe_request(
operation_type=operation_type,
field_name=field_name,
parent_type=parent_type,
operation_name=operation_name,
error_code=None,
success=True,
duration=time.perf_counter() - start,
)
except BackendAIError as e:

def _observe(*, duration: float, error: BaseException | None = None) -> None:
match error:
case None:
error_code = None
case BackendAIError():
error_code = error.error_code()
case _:
error_code = ErrorCode.default()
graph_ctx.metric_observer.observe_request(
operation_type=operation_type,
field_name=field_name,
parent_type=parent_type,
operation_name=operation_name,
error_code=e.error_code(),
success=False,
duration=time.perf_counter() - start,
error_code=error_code,
success=error is None,
duration=duration,
)
raise e

start = time.perf_counter()
try:
res = next(root, info, **args)
_observe(duration=time.perf_counter() - start)
except BaseException as e:
graph_ctx.metric_observer.observe_request(
operation_type=operation_type,
field_name=field_name,
parent_type=parent_type,
operation_name=operation_name,
error_code=ErrorCode.default(),
success=False,
duration=time.perf_counter() - start,
)
raise e
_observe(duration=time.perf_counter() - start, error=e)
raise
return res


Expand Down
135 changes: 135 additions & 0 deletions tests/unit/manager/api/test_gql_metric_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""Tests for GQLMetricMiddleware _observe helper."""

from __future__ import annotations

from typing import Any
from unittest.mock import MagicMock

import pytest
from graphql import OperationType

from ai.backend.common.exception import (
ErrorCode,
InvalidAPIParameters,
)
from ai.backend.manager.api.gql_legacy.schema import GQLMetricMiddleware


@pytest.fixture
def metric_observer() -> MagicMock:
observer = MagicMock()
observer.observe_request = MagicMock()
return observer


@pytest.fixture
def resolve_info(metric_observer: MagicMock) -> MagicMock:
info = MagicMock()
info.context.metric_observer = metric_observer
info.operation.operation = OperationType.QUERY
info.operation.name.value = "TestQuery"
info.field_name = "test_field"
info.parent_type.name = "Query"
return info


@pytest.fixture
def middleware() -> GQLMetricMiddleware:
return GQLMetricMiddleware()


class TestGQLMetricMiddlewareSyncResolver:
"""Tests for sync resolver timing in GQLMetricMiddleware."""

def test_sync_resolver_records_timing(
self,
middleware: GQLMetricMiddleware,
resolve_info: MagicMock,
metric_observer: MagicMock,
) -> None:
def sync_resolver(root: Any, info: Any, **kwargs: Any) -> str:
return "sync_result"

result = middleware.resolve(sync_resolver, None, resolve_info)

assert result == "sync_result"
metric_observer.observe_request.assert_called_once()
call_kwargs = metric_observer.observe_request.call_args.kwargs
assert call_kwargs["success"] is True
assert call_kwargs["error_code"] is None
assert call_kwargs["duration"] >= 0
assert call_kwargs["field_name"] == "test_field"
assert call_kwargs["parent_type"] == "Query"
assert call_kwargs["operation_name"] == "TestQuery"

def test_sync_resolver_backend_ai_error(
self,
middleware: GQLMetricMiddleware,
resolve_info: MagicMock,
metric_observer: MagicMock,
) -> None:
error = InvalidAPIParameters("test error")

def sync_resolver(root: Any, info: Any, **kwargs: Any) -> str:
raise error

with pytest.raises(InvalidAPIParameters):
middleware.resolve(sync_resolver, None, resolve_info)

metric_observer.observe_request.assert_called_once()
call_kwargs = metric_observer.observe_request.call_args.kwargs
assert call_kwargs["success"] is False
assert call_kwargs["error_code"] == error.error_code()

def test_sync_resolver_generic_exception(
self,
middleware: GQLMetricMiddleware,
resolve_info: MagicMock,
metric_observer: MagicMock,
) -> None:
def sync_resolver(root: Any, info: Any, **kwargs: Any) -> str:
raise RuntimeError("unexpected")

with pytest.raises(RuntimeError):
middleware.resolve(sync_resolver, None, resolve_info)

metric_observer.observe_request.assert_called_once()
call_kwargs = metric_observer.observe_request.call_args.kwargs
assert call_kwargs["success"] is False
assert call_kwargs["error_code"] == ErrorCode.default()

def test_sync_resolver_anonymous_operation(
self,
middleware: GQLMetricMiddleware,
resolve_info: MagicMock,
metric_observer: MagicMock,
) -> None:
resolve_info.operation.name = None

def sync_resolver(root: Any, info: Any, **kwargs: Any) -> str:
return "result"

middleware.resolve(sync_resolver, None, resolve_info)

call_kwargs = metric_observer.observe_request.call_args.kwargs
assert call_kwargs["operation_name"] == "anonymous"


class TestGQLMetricMiddlewareAsyncAnonymousOperation:
"""Tests for anonymous operation handling with async resolvers."""

async def test_async_resolver_anonymous_operation(
self,
middleware: GQLMetricMiddleware,
resolve_info: MagicMock,
metric_observer: MagicMock,
) -> None:
resolve_info.operation.name = None

async def async_resolver(root: Any, info: Any, **kwargs: Any) -> str:
return "result"

await middleware.resolve(async_resolver, None, resolve_info)

call_kwargs = metric_observer.observe_request.call_args.kwargs
assert call_kwargs["operation_name"] == "anonymous"
Loading