Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fixes
  • Loading branch information
RobertCraigie committed Sep 27, 2024
commit 80aba9ff3d237fd62ca50a75d6de8215a3970d39
5 changes: 4 additions & 1 deletion src/openai/_utils/_reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def assert_signatures_in_sync(
check_func: Callable[..., Any],
*,
exclude_params: set[str] = set(),
description: str = "",
) -> None:
"""Ensure that the signature of the second function matches the first."""

Expand All @@ -39,4 +40,6 @@ def assert_signatures_in_sync(
continue

if errors:
raise AssertionError(f"{len(errors)} errors encountered when comparing signatures:\n\n" + "\n\n".join(errors))
raise AssertionError(
f"{len(errors)} errors encountered when comparing signatures{description}:\n\n" + "\n\n".join(errors)
)
12 changes: 6 additions & 6 deletions src/openai/resources/audio/transcriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from __future__ import annotations

from typing import TYPE_CHECKING, List, Union, Mapping, cast, overload
from typing_extensions import Literal, assert_never
from typing import TYPE_CHECKING, List, Union, Mapping, cast
from typing_extensions import Literal, overload, assert_never

import httpx

Expand Down Expand Up @@ -55,7 +55,7 @@ def create(
*,
file: FileTypes,
model: Union[str, AudioModel],
response_format: Literal["json"] | NotGiven = NOT_GIVEN,
response_format: Union[Literal["json"], NotGiven] = NOT_GIVEN,
language: str | NotGiven = NOT_GIVEN,
prompt: str | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
Expand Down Expand Up @@ -113,7 +113,7 @@ def create(
model: Union[str, AudioModel],
language: str | NotGiven = NOT_GIVEN,
prompt: str | NotGiven = NOT_GIVEN,
response_format: AudioResponseFormat | NotGiven = NOT_GIVEN,
response_format: Union[AudioResponseFormat, NotGiven] = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
timestamp_granularities: List[Literal["word", "segment"]] | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
Expand Down Expand Up @@ -219,7 +219,7 @@ async def create(
*,
file: FileTypes,
model: Union[str, AudioModel],
response_format: Literal["json"] | NotGiven = NOT_GIVEN,
response_format: Union[Literal["json"], NotGiven] = NOT_GIVEN,
language: str | NotGiven = NOT_GIVEN,
prompt: str | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
Expand Down Expand Up @@ -277,7 +277,7 @@ async def create(
model: Union[str, AudioModel],
language: str | NotGiven = NOT_GIVEN,
prompt: str | NotGiven = NOT_GIVEN,
response_format: AudioResponseFormat | NotGiven = NOT_GIVEN,
response_format: Union[AudioResponseFormat, NotGiven] = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
timestamp_granularities: List[Literal["word", "segment"]] | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
Expand Down
8 changes: 4 additions & 4 deletions src/openai/resources/audio/translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def create(
*,
file: FileTypes,
model: Union[str, AudioModel],
response_format: Literal["json"] | NotGiven = NOT_GIVEN,
response_format: Union[Literal["json"], NotGiven] = NOT_GIVEN,
prompt: str | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
Expand Down Expand Up @@ -106,7 +106,7 @@ def create(
file: FileTypes,
model: Union[str, AudioModel],
prompt: str | NotGiven = NOT_GIVEN,
response_format: AudioResponseFormat | NotGiven = NOT_GIVEN,
response_format: Union[AudioResponseFormat, NotGiven] = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
Expand Down Expand Up @@ -198,7 +198,7 @@ async def create(
*,
file: FileTypes,
model: Union[str, AudioModel],
response_format: Literal["json"] | NotGiven = NOT_GIVEN,
response_format: Union[Literal["json"], NotGiven] = NOT_GIVEN,
prompt: str | NotGiven = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
Expand Down Expand Up @@ -249,7 +249,7 @@ async def create(
file: FileTypes,
model: Union[str, AudioModel],
prompt: str | NotGiven = NOT_GIVEN,
response_format: AudioResponseFormat | NotGiven = NOT_GIVEN,
response_format: Union[AudioResponseFormat, NotGiven] = NOT_GIVEN,
temperature: float | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
Expand Down
83 changes: 83 additions & 0 deletions tests/lib/test_audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from __future__ import annotations

import sys
import inspect
import typing_extensions
from typing import get_args

import pytest

from openai import OpenAI, AsyncOpenAI
from tests.utils import evaluate_forwardref
from openai._utils import assert_signatures_in_sync
from openai._compat import is_literal_type
from openai._utils._typing import is_union_type
from openai.types.audio_response_format import AudioResponseFormat


@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
def test_translation_create_overloads_in_sync(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
checking_client: OpenAI | AsyncOpenAI = client if sync else async_client

fn = checking_client.audio.translations.create
overload_response_formats: set[str] = set()

for i, overload in enumerate(typing_extensions.get_overloads(fn)):
assert_signatures_in_sync(
fn,
overload,
exclude_params={"response_format"},
description=f" for overload {i}",
)

sig = inspect.signature(overload)
typ = evaluate_forwardref(
sig.parameters["response_format"].annotation,
globalns=sys.modules[fn.__module__].__dict__,
)
if is_union_type(typ):
for arg in get_args(typ):
if not is_literal_type(arg):
continue

overload_response_formats.update(get_args(arg))
elif is_literal_type(typ):
overload_response_formats.update(get_args(typ))

src_response_formats: set[str] = set(get_args(AudioResponseFormat))
diff = src_response_formats.difference(overload_response_formats)
assert len(diff) == 0, f"some response format options don't have overloads"


@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
def test_transcription_create_overloads_in_sync(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
checking_client: OpenAI | AsyncOpenAI = client if sync else async_client

fn = checking_client.audio.transcriptions.create
overload_response_formats: set[str] = set()

for i, overload in enumerate(typing_extensions.get_overloads(fn)):
assert_signatures_in_sync(
fn,
overload,
exclude_params={"response_format"},
description=f" for overload {i}",
)

sig = inspect.signature(overload)
typ = evaluate_forwardref(
sig.parameters["response_format"].annotation,
globalns=sys.modules[fn.__module__].__dict__,
)
if is_union_type(typ):
for arg in get_args(typ):
if not is_literal_type(arg):
continue

overload_response_formats.update(get_args(arg))
elif is_literal_type(typ):
overload_response_formats.update(get_args(typ))

src_response_formats: set[str] = set(get_args(AudioResponseFormat))
diff = src_response_formats.difference(overload_response_formats)
assert len(diff) == 0, f"some response format options don't have overloads"
6 changes: 5 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import inspect
import traceback
import contextlib
from typing import Any, TypeVar, Iterator, cast
from typing import Any, ForwardRef, TypeVar, Iterator, cast
from datetime import date, datetime
from typing_extensions import Literal, get_args, get_origin, assert_type

Expand All @@ -26,6 +26,10 @@
BaseModelT = TypeVar("BaseModelT", bound=BaseModel)


def evaluate_forwardref(forwardref: ForwardRef, globalns: dict[str, Any]) -> type:
return eval(str(forwardref), globalns)


def assert_matches_model(model: type[BaseModelT], value: BaseModelT, *, path: list[str]) -> bool:
for name, field in get_model_fields(model).items():
field_value = getattr(value, name)
Expand Down