diff --git a/python/packages/a2a/tests/test_a2a_agent.py b/python/packages/a2a/tests/test_a2a_agent.py index a0919cbda4..442960a7ee 100644 --- a/python/packages/a2a/tests/test_a2a_agent.py +++ b/python/packages/a2a/tests/test_a2a_agent.py @@ -1284,13 +1284,11 @@ async def test_streaming_artifact_update_event_does_not_duplicate_terminal_task_ final=True, ) - mock_a2a_client.responses.extend( - [ - (working_task, first_chunk), - (working_task, second_chunk), - (terminal_task, terminal_event), - ] - ) + mock_a2a_client.responses.extend([ + (working_task, first_chunk), + (working_task, second_chunk), + (terminal_task, terminal_event), + ]) stream = a2a_agent.run("Hello", stream=True) updates: list[AgentResponseUpdate] = [] @@ -1371,12 +1369,10 @@ async def test_streaming_terminal_task_only_emits_unstreamed_artifacts( final=True, ) - mock_a2a_client.responses.extend( - [ - (working_task, streamed_chunk), - (terminal_task, terminal_event), - ] - ) + mock_a2a_client.responses.extend([ + (working_task, streamed_chunk), + (terminal_task, terminal_event), + ]) stream = a2a_agent.run("Hello", stream=True) updates: list[AgentResponseUpdate] = [] diff --git a/python/packages/openai/agent_framework_openai/_shared.py b/python/packages/openai/agent_framework_openai/_shared.py index f1d0728f61..7fb12ad14e 100644 --- a/python/packages/openai/agent_framework_openai/_shared.py +++ b/python/packages/openai/agent_framework_openai/_shared.py @@ -282,9 +282,46 @@ def load_openai_service_settings( "Azure OpenAI client requires either an API key or an Azure AD token provider." " This can be provided either as a callable api_key or via the credential parameter." ) + + # The /openai/v1 endpoint exposes an OpenAI-compatible API surface. + # AsyncAzureOpenAI rewrites certain request paths (e.g. /embeddings, + # /chat/completions) by inserting /deployments/{model}/, which produces + # 404s on this endpoint. Use AsyncOpenAI instead so request URLs are + # sent as-is. responses_mode is excluded because the Responses API path + # (/responses) is not rewritten by the Azure SDK. + resolved_base_url = client_args.get("base_url", "") + if not responses_mode and resolved_base_url and resolved_base_url.rstrip("/").endswith("/openai/v1"): + openai_args: dict[str, Any] = { + "base_url": resolved_base_url, + "default_headers": client_args.get("default_headers"), + } + if "azure_ad_token_provider" in client_args: + openai_args["api_key"] = _ensure_async_token_provider(client_args["azure_ad_token_provider"]) + elif "api_key" in client_args: + openai_args["api_key"] = client_args["api_key"] + return azure_settings, AsyncOpenAI(**openai_args), True # type: ignore[return-value] + return azure_settings, AsyncAzureOpenAI(**client_args), True # type: ignore[return-value] +def _ensure_async_token_provider( + provider: AzureTokenProvider, +) -> Callable[[], Awaitable[str]]: + """Wrap a (possibly synchronous) token provider so it always returns an awaitable. + + ``AsyncOpenAI`` requires callable ``api_key`` values to return ``Awaitable[str]``. + Azure token providers may return a plain ``str``, so this normalises them. + """ + + async def _wrapper() -> str: + result = provider() + if isinstance(result, str): + return result + return await result + + return _wrapper + + def _resolve_azure_credential_to_token_provider( credential: AzureCredentialTypes | AzureTokenProvider, ) -> AzureTokenProvider: diff --git a/python/packages/openai/tests/openai/test_openai_embedding_client_azure.py b/python/packages/openai/tests/openai/test_openai_embedding_client_azure.py index 2d3c457bf6..4e7a584874 100644 --- a/python/packages/openai/tests/openai/test_openai_embedding_client_azure.py +++ b/python/packages/openai/tests/openai/test_openai_embedding_client_azure.py @@ -11,7 +11,7 @@ from agent_framework.exceptions import SettingNotFoundError from azure.core.credentials_async import AsyncTokenCredential from azure.identity.aio import AzureCliCredential -from openai import AsyncAzureOpenAI +from openai import AsyncAzureOpenAI, AsyncOpenAI from agent_framework_openai import OpenAIEmbeddingClient, OpenAIEmbeddingOptions @@ -196,6 +196,78 @@ def test_openai_base_url_wins_over_azure_aliases(monkeypatch, azure_openai_unit_ assert client.azure_endpoint is None +def test_init_with_openai_v1_base_url_and_credential_uses_openai_client(monkeypatch) -> None: + for env in [ + "OPENAI_API_KEY", + "OPENAI_ORG_ID", + "OPENAI_MODEL", + "OPENAI_EMBEDDING_MODEL", + "OPENAI_BASE_URL", + "AZURE_OPENAI_ENDPOINT", + "AZURE_OPENAI_BASE_URL", + "AZURE_OPENAI_API_KEY", + "AZURE_OPENAI_EMBEDDING_MODEL", + "AZURE_OPENAI_MODEL", + "AZURE_OPENAI_API_VERSION", + "AZURE_OPENAI_CHAT_MODEL", + "AZURE_OPENAI_CHAT_COMPLETION_MODEL", + ]: + monkeypatch.delenv(env, raising=False) + + client = OpenAIEmbeddingClient( + base_url="https://myproject.openai.azure.com/openai/v1/", + model="text-embedding-3-large", + credential=lambda: "fake-token", + ) + + assert client.model == "text-embedding-3-large" + assert not isinstance(client.client, AsyncAzureOpenAI) + assert isinstance(client.client, AsyncOpenAI) + assert client.OTEL_PROVIDER_NAME == "azure.ai.openai" + assert str(client.client.base_url).rstrip("/").endswith("/openai/v1") + + +def test_init_with_openai_v1_base_url_and_api_key_uses_openai_client(monkeypatch) -> None: + for env in [ + "OPENAI_API_KEY", + "OPENAI_ORG_ID", + "OPENAI_MODEL", + "OPENAI_EMBEDDING_MODEL", + "OPENAI_BASE_URL", + "AZURE_OPENAI_ENDPOINT", + "AZURE_OPENAI_BASE_URL", + "AZURE_OPENAI_API_KEY", + "AZURE_OPENAI_EMBEDDING_MODEL", + "AZURE_OPENAI_MODEL", + "AZURE_OPENAI_API_VERSION", + "AZURE_OPENAI_CHAT_MODEL", + "AZURE_OPENAI_CHAT_COMPLETION_MODEL", + ]: + monkeypatch.delenv(env, raising=False) + + # AZURE_OPENAI_BASE_URL + AZURE_OPENAI_API_KEY enter the Azure settings + # path without an explicit endpoint parameter; the /openai/v1 suffix + # should still produce AsyncOpenAI (not AsyncAzureOpenAI). + monkeypatch.setenv("AZURE_OPENAI_BASE_URL", "https://myproject.openai.azure.com/openai/v1/") + monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key") + + client = OpenAIEmbeddingClient(model="text-embedding-3-large") + + assert client.model == "text-embedding-3-large" + assert not isinstance(client.client, AsyncAzureOpenAI) + assert isinstance(client.client, AsyncOpenAI) + assert str(client.client.base_url).rstrip("/").endswith("/openai/v1") + + +def test_init_with_azure_endpoint_still_uses_azure_client(azure_openai_unit_test_env: dict[str, str]) -> None: + client = OpenAIEmbeddingClient( + azure_endpoint=azure_openai_unit_test_env["AZURE_OPENAI_ENDPOINT"], + api_key=azure_openai_unit_test_env["AZURE_OPENAI_API_KEY"], + ) + + assert isinstance(client.client, AsyncAzureOpenAI) + + @pytest.mark.flaky @pytest.mark.integration @skip_if_azure_openai_integration_tests_disabled diff --git a/python/packages/openai/tests/openai/test_openai_shared.py b/python/packages/openai/tests/openai/test_openai_shared.py index b69feb7314..86d43bc43b 100644 --- a/python/packages/openai/tests/openai/test_openai_shared.py +++ b/python/packages/openai/tests/openai/test_openai_shared.py @@ -8,7 +8,11 @@ from azure.core.credentials import TokenCredential from azure.core.credentials_async import AsyncTokenCredential -from agent_framework_openai._shared import AZURE_OPENAI_TOKEN_SCOPE, _resolve_azure_credential_to_token_provider +from agent_framework_openai._shared import ( + AZURE_OPENAI_TOKEN_SCOPE, + _ensure_async_token_provider, + _resolve_azure_credential_to_token_provider, +) class _AsyncTokenCredentialStub(AsyncTokenCredential): @@ -52,3 +56,23 @@ def test_resolve_azure_callable_token_provider_passthrough() -> None: def test_resolve_azure_invalid_credential_raises() -> None: with pytest.raises(ValueError, match="credential"): _resolve_azure_credential_to_token_provider(object()) # type: ignore[arg-type] + + +async def test_ensure_async_token_provider_wraps_sync_provider() -> None: + def sync_provider() -> str: + return "sync-token" + + wrapper = _ensure_async_token_provider(sync_provider) + result = await wrapper() + + assert result == "sync-token" + + +async def test_ensure_async_token_provider_wraps_async_provider() -> None: + async def async_provider() -> str: + return "async-token" + + wrapper = _ensure_async_token_provider(async_provider) + result = await wrapper() + + assert result == "async-token"