Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/magentic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ._streamed_response import StreamedResponse as StreamedResponse
from .chat_model.message import AnyMessage as AnyMessage
from .chat_model.message import AssistantMessage as AssistantMessage
from .chat_model.message import AudioBytes as AudioBytes
from .chat_model.message import DocumentBytes as DocumentBytes
from .chat_model.message import FunctionResultMessage as FunctionResultMessage
from .chat_model.message import ImageBytes as ImageBytes
Expand Down
4 changes: 2 additions & 2 deletions src/magentic/chat_model/anthropic_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ def complete(
) -> AssistantMessage[OutputT]:
"""Request an LLM message."""
if output_types is None:
output_types = [] if functions else cast(list[type[OutputT]], [str])
output_types = [] if functions else cast("list[type[OutputT]]", [str])

function_schemas = get_function_schemas(functions, output_types)
tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas]
Expand Down Expand Up @@ -478,7 +478,7 @@ async def acomplete(
) -> AssistantMessage[OutputT]:
"""Async version of `complete`."""
if output_types is None:
output_types = [] if functions else cast(list[type[OutputT]], [str])
output_types = [] if functions else cast("list[type[OutputT]]", [str])

function_schemas = get_async_function_schemas(functions, output_types)
tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas]
Expand Down
28 changes: 14 additions & 14 deletions src/magentic/chat_model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,23 +116,23 @@ def parse_stream(
obj = next(stream)
if isinstance(obj, StreamedStr):
if StreamedResponse in output_type_origins:
return cast(OutputT, StreamedResponse(chain([obj], stream)))
return cast("OutputT", StreamedResponse(chain([obj], stream)))
if StreamedStr in output_type_origins:
return cast(OutputT, obj)
return cast("OutputT", obj)
if str in output_type_origins:
return cast(OutputT, str(obj))
return cast("OutputT", str(obj))
raise StringNotAllowedError(obj.truncate(100))
if isinstance(obj, FunctionCall):
if StreamedResponse in output_type_origins:
return cast(OutputT, StreamedResponse(chain([obj], stream)))
return cast("OutputT", StreamedResponse(chain([obj], stream)))
if ParallelFunctionCall in output_type_origins:
return cast(OutputT, ParallelFunctionCall(chain([obj], stream)))
return cast("OutputT", ParallelFunctionCall(chain([obj], stream)))
if FunctionCall in output_type_origins:
# TODO: Check that FunctionCall type matches ?
return cast(OutputT, obj)
return cast("OutputT", obj)
raise FunctionCallNotAllowedError(obj)
if isinstance(obj, tuple(output_type_origins)):
return cast(OutputT, obj)
return cast("OutputT", obj)
raise ObjectNotAllowedError(obj)


Expand All @@ -145,27 +145,27 @@ async def aparse_stream(
if isinstance(obj, AsyncStreamedStr):
if AsyncStreamedResponse in output_type_origins:
return cast(
OutputT, AsyncStreamedResponse(achain(async_iter([obj]), stream))
"OutputT", AsyncStreamedResponse(achain(async_iter([obj]), stream))
)
if AsyncStreamedStr in output_type_origins:
return cast(OutputT, obj)
return cast("OutputT", obj)
if str in output_type_origins:
return cast(OutputT, await obj.to_string())
return cast("OutputT", await obj.to_string())
raise StringNotAllowedError(await obj.truncate(100))
if isinstance(obj, FunctionCall):
if AsyncStreamedResponse in output_type_origins:
return cast(
OutputT, AsyncStreamedResponse(achain(async_iter([obj]), stream))
"OutputT", AsyncStreamedResponse(achain(async_iter([obj]), stream))
)
if AsyncParallelFunctionCall in output_type_origins:
return cast(
OutputT, AsyncParallelFunctionCall(achain(async_iter([obj]), stream))
"OutputT", AsyncParallelFunctionCall(achain(async_iter([obj]), stream))
)
if FunctionCall in output_type_origins:
return cast(OutputT, obj)
return cast("OutputT", obj)
raise FunctionCallNotAllowedError(obj)
if isinstance(obj, tuple(output_type_origins)):
return cast(OutputT, obj)
return cast("OutputT", obj)
raise ObjectNotAllowedError(obj)


Expand Down
20 changes: 12 additions & 8 deletions src/magentic/chat_model/function_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,11 @@ def parameters(self) -> dict[str, Any]:

@property
def strict(self) -> bool | None:
return cast(ConfigDict, self._model.model_config).get("openai_strict")
return cast("ConfigDict", self._model.model_config).get("openai_strict")

def parse_args(self, chunks: Iterable[str]) -> T:
args_json = "".join(chunks)
return cast(T, self._model.model_validate_json(args_json).value) # type: ignore[attr-defined]
return cast("T", self._model.model_validate_json(args_json).value) # type: ignore[attr-defined]

def serialize_args(self, value: T) -> str:
return self._model.model_construct(value=value).model_dump_json()
Expand All @@ -211,6 +211,7 @@ def __init__(self, output_type: type[IterableT]):
"Output",
__config__=get_pydantic_config(output_type),
value=(output_type, ...),
__module__="pydantic.main",
)

@property
Expand All @@ -223,14 +224,17 @@ def parameters(self) -> dict[str, Any]:

@property
def strict(self) -> bool | None:
return cast(ConfigDict, self._model.model_config).get("openai_strict")
return cast("ConfigDict", self._model.model_config).get("openai_strict")

def parse_args(self, chunks: Iterable[str]) -> IterableT:
iter_items = (
self._item_type_adapter.validate_json(item)
for item in iter_streamed_json_array(chunks)
)
return cast(IterableT, self._model.model_validate({"value": iter_items}).value) # type: ignore[attr-defined]
# Use a type annotation to tell mypy what's going on
validated = self._model.model_validate({"value": iter_items})
# Pydantic model will have a value field based on how we constructed it
return cast("IterableT", validated.value) # type: ignore[attr-defined]

def serialize_args(self, value: IterableT) -> str:
return self._model.model_construct(value=value).model_dump_json()
Expand Down Expand Up @@ -266,7 +270,7 @@ def parameters(self) -> dict[str, Any]:

@property
def strict(self) -> bool | None:
return cast(ConfigDict, self._model.model_config).get("openai_strict")
return cast("ConfigDict", self._model.model_config).get("openai_strict")

async def aparse_args(self, chunks: AsyncIterable[str]) -> AsyncIterableT:
aiter_items = (
Expand All @@ -277,7 +281,7 @@ async def aparse_args(self, chunks: AsyncIterable[str]) -> AsyncIterableT:
typing.AsyncIterable,
typing.AsyncIterator,
) or is_origin_abstract(self._output_type):
return cast(AsyncIterableT, aiter_items)
return cast("AsyncIterableT", aiter_items)

raise NotImplementedError

Expand Down Expand Up @@ -335,7 +339,7 @@ def parameters(self) -> dict[str, Any]:

@property
def strict(self) -> bool | None:
return cast(ConfigDict, self._model.model_config).get("openai_strict")
return cast("ConfigDict", self._model.model_config).get("openai_strict")

def parse_args(self, chunks: Iterable[str]) -> BaseModelT:
args_json = "".join(chunks)
Expand Down Expand Up @@ -400,7 +404,7 @@ def parameters(self) -> dict[str, Any]:

@property
def strict(self) -> bool | None:
return cast(ConfigDict, self._model.model_config).get("openai_strict")
return cast("ConfigDict", self._model.model_config).get("openai_strict")

def parse_args(self, chunks: Iterable[str]) -> FunctionCall[T]:
# Anthropic message stream returns empty string for function call with no arguments
Expand Down
6 changes: 3 additions & 3 deletions src/magentic/chat_model/litellm_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def update(self, item: ModelResponse) -> None:
assert isinstance(item.choices[0], StreamingChoices)
item.choices[0].delta.refusal = None # type: ignore[attr-defined]
self._chat_completion_stream_state.handle_chunk(item) # type: ignore[arg-type]
usage = cast(litellm.Usage, item.usage) # type: ignore[attr-defined,name-defined]
usage = cast("litellm.Usage", item.usage) # type: ignore[attr-defined,name-defined]
# Ignore usages with 0 tokens
if usage and usage.prompt_tokens and usage.completion_tokens:
assert not self.usage_ref
Expand Down Expand Up @@ -172,7 +172,7 @@ def complete(
) -> AssistantMessage[OutputT]:
"""Request an LLM message."""
if output_types is None:
output_types = cast(Iterable[type[OutputT]], [] if functions else [str])
output_types = cast("Iterable[type[OutputT]]", [] if functions else [str])

function_schemas = get_function_schemas(functions, output_types)
tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas]
Expand Down Expand Up @@ -213,7 +213,7 @@ async def acomplete(
) -> AssistantMessage[OutputT]:
"""Async version of `complete`."""
if output_types is None:
output_types = cast(Iterable[type[OutputT]], [] if functions else [str])
output_types = cast("Iterable[type[OutputT]]", [] if functions else [str])

function_schemas = get_async_function_schemas(functions, output_types)
tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas]
Expand Down
36 changes: 33 additions & 3 deletions src/magentic/chat_model/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class DocumentBytes(RootModel[bytes]):
def mime_type(self) -> DocumentMimeType:
mimetype: str | None = filetype.guess_mime(self.root)
assert mimetype in _DOCUMENT_MIME_TYPES
return cast(DocumentMimeType, mimetype)
return cast("DocumentMimeType", mimetype)

def __init__(self, root: bytes, **data: Any):
super().__init__(root=root, **data)
Expand Down Expand Up @@ -175,7 +175,7 @@ class ImageBytes(RootModel[bytes]):
def mime_type(self) -> ImageMimeType:
mimetype: str | None = filetype.guess_mime(self.root)
assert mimetype in _IMAGE_MIME_TYPES
return cast(ImageMimeType, mimetype)
return cast("ImageMimeType", mimetype)

def as_base64(self) -> str:
return base64.b64encode(self.root).decode("utf-8")
Expand All @@ -193,6 +193,36 @@ def _is_image_bytes(self) -> Self:
return self


# OpenAI supports audio inputs in WAV format
AudioMimeType = Literal["audio/wav", "audio/x-wav"]
_AUDIO_MIME_TYPES: tuple[AudioMimeType, ...] = get_args(AudioMimeType)


class AudioBytes(RootModel[bytes]):
"""Bytes representing an audio file."""

@cached_property
def mime_type(self) -> AudioMimeType:
mimetype: str | None = filetype.guess_mime(self.root)
assert mimetype in _AUDIO_MIME_TYPES
return cast("AudioMimeType", mimetype)

def as_base64(self) -> str:
return base64.b64encode(self.root).decode("utf-8")

def format(self, **kwargs: Any) -> Self:
del kwargs
return self

@model_validator(mode="after")
def _is_audio_bytes(self) -> Self:
mimetype: str | None = filetype.guess_mime(self.root)
if mimetype not in _AUDIO_MIME_TYPES:
msg = f"Unsupported audio MIME type: {mimetype!r}"
raise ValueError(msg)
return self


class ImageUrl(RootModel[str]):
"""String representing a URL to an image."""

Expand All @@ -201,7 +231,7 @@ def format(self, **kwargs: Any) -> Self:
return self


UserMessageContentBlock: TypeAlias = DocumentBytes | ImageBytes | ImageUrl
UserMessageContentBlock: TypeAlias = DocumentBytes | ImageBytes | ImageUrl | AudioBytes
UserMessageContentT = TypeVar(
"UserMessageContentT",
bound=str
Expand Down
14 changes: 11 additions & 3 deletions src/magentic/chat_model/openai_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
)
from magentic.chat_model.message import (
AssistantMessage,
AudioBytes,
ImageBytes,
ImageUrl,
Message,
Expand Down Expand Up @@ -82,7 +83,7 @@ def _(message: _RawMessage[Any]) -> ChatCompletionMessageParam:
assert isinstance(message.content, dict)
assert "role" in message.content
assert "content" in message.content
return cast(ChatCompletionMessageParam, message.content)
return cast("ChatCompletionMessageParam", message.content)


@message_to_openai_message.register
Expand Down Expand Up @@ -110,6 +111,13 @@ def _(message: UserMessage[Any]) -> ChatCompletionUserMessageParam:
)
elif isinstance(block, ImageUrl):
content.append({"type": "image_url", "image_url": {"url": block.root}})
elif isinstance(block, AudioBytes):
content.append(
{
"type": "input_audio",
"input_audio": {"data": block.as_base64(), "format": "wav"},
}
)
else:
msg = f"Invalid block type: {type(block)}"
raise TypeError(msg)
Expand Down Expand Up @@ -479,7 +487,7 @@ def complete(
) -> AssistantMessage[OutputT]:
"""Request an LLM message."""
if output_types is None:
output_types = cast(Iterable[type[OutputT]], [] if functions else [str])
output_types = cast("Iterable[type[OutputT]]", [] if functions else [str])

function_schemas = get_function_schemas(functions, output_types)
tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas]
Expand Down Expand Up @@ -523,7 +531,7 @@ async def acomplete(
) -> AssistantMessage[OutputT]:
"""Async version of `complete`."""
if output_types is None:
output_types = [] if functions else cast(list[type[OutputT]], [str])
output_types = [] if functions else cast("list[type[OutputT]]", [str])

function_schemas = get_async_function_schemas(functions, output_types)
tool_schemas = [BaseFunctionToolSchema(schema) for schema in function_schemas]
Expand Down
6 changes: 3 additions & 3 deletions src/magentic/chatprompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def decorator(
model=model,
)
return cast(
AsyncChatPromptFunction[P, R],
"AsyncChatPromptFunction[P, R]",
update_wrapper(async_prompt_function, func),
)

Expand All @@ -205,6 +205,6 @@ def decorator(
max_retries=max_retries,
model=model,
)
return cast(ChatPromptFunction[P, R], update_wrapper(prompt_function, func))
return cast("ChatPromptFunction[P, R]", update_wrapper(prompt_function, func))

return cast(ChatPromptDecorator, decorator)
return cast("ChatPromptDecorator", decorator)
2 changes: 1 addition & 1 deletion src/magentic/function_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ async def __call__(self) -> tuple[T, ...]:
if inspect.iscoroutine(result):
tasks_and_results.append(asyncio.create_task(result))
else:
result = cast(T, result)
result = cast("T", result)
tasks_and_results.append(result)

tasks = [
Expand Down
4 changes: 2 additions & 2 deletions src/magentic/prompt_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ async def awrapper(*args: P.args, **kwargs: P.kwargs) -> Any:
num_calls += 1
return chat.last_message.content

return cast(Callable[P, R], awrapper)
return cast("Callable[P, R]", awrapper)

prompt_function = ChatPromptFunction[P, R](
name=func.__name__,
Expand Down Expand Up @@ -113,7 +113,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
raise MaxFunctionCallsError(msg)
chat = chat.exec_function_call().submit()
num_calls += 1
return cast(R, chat.last_message.content)
return cast("R", chat.last_message.content)

return wrapper

Expand Down
7 changes: 3 additions & 4 deletions src/magentic/prompt_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,8 +165,7 @@ def decorator(
model=model,
)
return cast(
AsyncPromptFunction[P, R],
update_wrapper(async_prompt_function, func),
"AsyncPromptFunction[P, R]", update_wrapper(async_prompt_function, func)
)

prompt_function = PromptFunction[P, R](
Expand All @@ -179,6 +178,6 @@ def decorator(
max_retries=max_retries,
model=model,
)
return cast(PromptFunction[P, R], update_wrapper(prompt_function, func))
return cast("PromptFunction[P, R]", update_wrapper(prompt_function, func))

return cast(PromptDecorator, decorator)
return cast("PromptDecorator", decorator)
Loading