Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 9 additions & 13 deletions python/packages/a2a/tests/test_a2a_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1284,13 +1284,11 @@ async def test_streaming_artifact_update_event_does_not_duplicate_terminal_task_
final=True,
)

mock_a2a_client.responses.extend(
[
(working_task, first_chunk),
(working_task, second_chunk),
(terminal_task, terminal_event),
]
)
mock_a2a_client.responses.extend([
(working_task, first_chunk),
(working_task, second_chunk),
(terminal_task, terminal_event),
])

stream = a2a_agent.run("Hello", stream=True)
updates: list[AgentResponseUpdate] = []
Expand Down Expand Up @@ -1371,12 +1369,10 @@ async def test_streaming_terminal_task_only_emits_unstreamed_artifacts(
final=True,
)

mock_a2a_client.responses.extend(
[
(working_task, streamed_chunk),
(terminal_task, terminal_event),
]
)
mock_a2a_client.responses.extend([
(working_task, streamed_chunk),
(terminal_task, terminal_event),
])

stream = a2a_agent.run("Hello", stream=True)
updates: list[AgentResponseUpdate] = []
Expand Down
37 changes: 37 additions & 0 deletions python/packages/openai/agent_framework_openai/_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,46 @@ def load_openai_service_settings(
"Azure OpenAI client requires either an API key or an Azure AD token provider."
" This can be provided either as a callable api_key or via the credential parameter."
)

# The /openai/v1 endpoint exposes an OpenAI-compatible API surface.
# AsyncAzureOpenAI rewrites certain request paths (e.g. /embeddings,
# /chat/completions) by inserting /deployments/{model}/, which produces
# 404s on this endpoint. Use AsyncOpenAI instead so request URLs are
# sent as-is. responses_mode is excluded because the Responses API path
# (/responses) is not rewritten by the Azure SDK.
resolved_base_url = client_args.get("base_url", "")
if not responses_mode and resolved_base_url and resolved_base_url.rstrip("/").endswith("/openai/v1"):
openai_args: dict[str, Any] = {
"base_url": resolved_base_url,
"default_headers": client_args.get("default_headers"),
}
if "azure_ad_token_provider" in client_args:
openai_args["api_key"] = _ensure_async_token_provider(client_args["azure_ad_token_provider"])
elif "api_key" in client_args:
openai_args["api_key"] = client_args["api_key"]
return azure_settings, AsyncOpenAI(**openai_args), True # type: ignore[return-value]

return azure_settings, AsyncAzureOpenAI(**client_args), True # type: ignore[return-value]


def _ensure_async_token_provider(
provider: AzureTokenProvider,
) -> Callable[[], Awaitable[str]]:
"""Wrap a (possibly synchronous) token provider so it always returns an awaitable.

``AsyncOpenAI`` requires callable ``api_key`` values to return ``Awaitable[str]``.
Azure token providers may return a plain ``str``, so this normalises them.
"""

async def _wrapper() -> str:
result = provider()
if isinstance(result, str):
return result
return await result

return _wrapper


def _resolve_azure_credential_to_token_provider(
credential: AzureCredentialTypes | AzureTokenProvider,
) -> AzureTokenProvider:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from agent_framework.exceptions import SettingNotFoundError
from azure.core.credentials_async import AsyncTokenCredential
from azure.identity.aio import AzureCliCredential
from openai import AsyncAzureOpenAI
from openai import AsyncAzureOpenAI, AsyncOpenAI

from agent_framework_openai import OpenAIEmbeddingClient, OpenAIEmbeddingOptions

Expand Down Expand Up @@ -196,6 +196,78 @@ def test_openai_base_url_wins_over_azure_aliases(monkeypatch, azure_openai_unit_
assert client.azure_endpoint is None


def test_init_with_openai_v1_base_url_and_credential_uses_openai_client(monkeypatch) -> None:
    """An explicit /openai/v1 base_url plus a credential selects AsyncOpenAI."""
    # Clear every OpenAI/Azure env var so only the explicit kwargs take effect.
    cleared_env_vars = (
        "OPENAI_API_KEY",
        "OPENAI_ORG_ID",
        "OPENAI_MODEL",
        "OPENAI_EMBEDDING_MODEL",
        "OPENAI_BASE_URL",
        "AZURE_OPENAI_ENDPOINT",
        "AZURE_OPENAI_BASE_URL",
        "AZURE_OPENAI_API_KEY",
        "AZURE_OPENAI_EMBEDDING_MODEL",
        "AZURE_OPENAI_MODEL",
        "AZURE_OPENAI_API_VERSION",
        "AZURE_OPENAI_CHAT_MODEL",
        "AZURE_OPENAI_CHAT_COMPLETION_MODEL",
    )
    for name in cleared_env_vars:
        monkeypatch.delenv(name, raising=False)

    client = OpenAIEmbeddingClient(
        base_url="https://myproject.openai.azure.com/openai/v1/",
        model="text-embedding-3-large",
        credential=lambda: "fake-token",
    )

    assert client.model == "text-embedding-3-large"
    # Must be the plain OpenAI client, not the Azure subclass/sibling.
    assert not isinstance(client.client, AsyncAzureOpenAI)
    assert isinstance(client.client, AsyncOpenAI)
    assert client.OTEL_PROVIDER_NAME == "azure.ai.openai"
    assert str(client.client.base_url).rstrip("/").endswith("/openai/v1")


def test_init_with_openai_v1_base_url_and_api_key_uses_openai_client(monkeypatch) -> None:
    """The /openai/v1 suffix from env vars alone also selects AsyncOpenAI."""
    # Start from a clean environment so only the vars set below matter.
    cleared_env_vars = (
        "OPENAI_API_KEY",
        "OPENAI_ORG_ID",
        "OPENAI_MODEL",
        "OPENAI_EMBEDDING_MODEL",
        "OPENAI_BASE_URL",
        "AZURE_OPENAI_ENDPOINT",
        "AZURE_OPENAI_BASE_URL",
        "AZURE_OPENAI_API_KEY",
        "AZURE_OPENAI_EMBEDDING_MODEL",
        "AZURE_OPENAI_MODEL",
        "AZURE_OPENAI_API_VERSION",
        "AZURE_OPENAI_CHAT_MODEL",
        "AZURE_OPENAI_CHAT_COMPLETION_MODEL",
    )
    for name in cleared_env_vars:
        monkeypatch.delenv(name, raising=False)

    # AZURE_OPENAI_BASE_URL + AZURE_OPENAI_API_KEY enter the Azure settings
    # path without an explicit endpoint parameter; the /openai/v1 suffix
    # should still produce AsyncOpenAI (not AsyncAzureOpenAI).
    monkeypatch.setenv("AZURE_OPENAI_BASE_URL", "https://myproject.openai.azure.com/openai/v1/")
    monkeypatch.setenv("AZURE_OPENAI_API_KEY", "test-api-key")

    client = OpenAIEmbeddingClient(model="text-embedding-3-large")

    assert client.model == "text-embedding-3-large"
    assert not isinstance(client.client, AsyncAzureOpenAI)
    assert isinstance(client.client, AsyncOpenAI)
    assert str(client.client.base_url).rstrip("/").endswith("/openai/v1")


def test_init_with_azure_endpoint_still_uses_azure_client(azure_openai_unit_test_env: dict[str, str]) -> None:
    """A classic azure_endpoint (no /openai/v1 suffix) keeps AsyncAzureOpenAI."""
    env = azure_openai_unit_test_env
    client = OpenAIEmbeddingClient(
        azure_endpoint=env["AZURE_OPENAI_ENDPOINT"],
        api_key=env["AZURE_OPENAI_API_KEY"],
    )

    assert isinstance(client.client, AsyncAzureOpenAI)


@pytest.mark.flaky
@pytest.mark.integration
@skip_if_azure_openai_integration_tests_disabled
Expand Down
26 changes: 25 additions & 1 deletion python/packages/openai/tests/openai/test_openai_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
from azure.core.credentials import TokenCredential
from azure.core.credentials_async import AsyncTokenCredential

from agent_framework_openai._shared import AZURE_OPENAI_TOKEN_SCOPE, _resolve_azure_credential_to_token_provider
from agent_framework_openai._shared import (
AZURE_OPENAI_TOKEN_SCOPE,
_ensure_async_token_provider,
_resolve_azure_credential_to_token_provider,
)


class _AsyncTokenCredentialStub(AsyncTokenCredential):
Expand Down Expand Up @@ -52,3 +56,23 @@ def test_resolve_azure_callable_token_provider_passthrough() -> None:
def test_resolve_azure_invalid_credential_raises() -> None:
with pytest.raises(ValueError, match="credential"):
_resolve_azure_credential_to_token_provider(object()) # type: ignore[arg-type]


async def test_ensure_async_token_provider_wraps_sync_provider() -> None:
    """A synchronous provider's plain-str token is returned unchanged."""

    def fake_sync_provider() -> str:
        return "sync-token"

    wrapped = _ensure_async_token_provider(fake_sync_provider)

    assert await wrapped() == "sync-token"


async def test_ensure_async_token_provider_wraps_async_provider() -> None:
    """An async provider's awaited token is passed through unchanged."""

    async def fake_async_provider() -> str:
        return "async-token"

    wrapped = _ensure_async_token_provider(fake_async_provider)

    assert await wrapped() == "async-token"
Loading