Skip to content

Commit

Permalink
test: capture openai's param handling
Browse files Browse the repository at this point in the history
  • Loading branch information
Fedir Zadniprovskyi authored and fedirz committed Sep 21, 2024
1 parent 2a099cf commit 4cb006f
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 3 deletions.
5 changes: 3 additions & 2 deletions src/faster_whisper_server/routers/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
from io import BytesIO
import logging
from typing import TYPE_CHECKING, Annotated, Literal
from typing import TYPE_CHECKING, Annotated

from fastapi import (
APIRouter,
Expand All @@ -30,6 +30,7 @@
from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
from faster_whisper_server.dependencies import ConfigDependency, ModelManagerDependency, get_config
from faster_whisper_server.server_models import (
TimestampGranularities,
TranscriptionJsonResponse,
TranscriptionVerboseJsonResponse,
)
Expand Down Expand Up @@ -165,7 +166,7 @@ def transcribe_file(
response_format: Annotated[ResponseFormat | None, Form()] = None,
temperature: Annotated[float, Form()] = 0.0,
timestamp_granularities: Annotated[
list[Literal["segment", "word"]],
TimestampGranularities,
Form(alias="timestamp_granularities[]"),
] = ["segment"],
stream: Annotated[bool, Form()] = False,
Expand Down
12 changes: 12 additions & 0 deletions src/faster_whisper_server/server_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,15 @@ class ModelObject(BaseModel):
]
},
)


TimestampGranularities = list[Literal["segment", "word"]]


TIMESTAMP_GRANULARITIES_COMBINATIONS: list[TimestampGranularities] = [
[], # should be treated as ["segment"]. https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities
["segment"],
["word"],
["word", "segment"],
["segment", "word"], # same as ["word", "segment"] but order is different
]
9 changes: 8 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from fastapi.testclient import TestClient
from faster_whisper_server.main import create_app
from httpx import ASGITransport, AsyncClient
from openai import OpenAI
from openai import AsyncOpenAI, OpenAI
import pytest
import pytest_asyncio

Expand Down Expand Up @@ -35,3 +35,10 @@ async def aclient() -> AsyncGenerator[AsyncClient, None]:
@pytest.fixture()
def openai_client(client: TestClient) -> OpenAI:
return OpenAI(api_key="cant-be-empty", http_client=client)


@pytest.fixture()
def actual_openai_client() -> AsyncOpenAI:
return AsyncOpenAI(
base_url="https://api.openai.com/v1"
) # `base_url` is provided in case `OPENAI_API_BASE_URL` is set to a different value
56 changes: 56 additions & 0 deletions tests/openai_timestamp_granularities_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""OpenAI's handling of `response_format` and `timestamp_granularities` is a bit confusing and inconsistent. This test module exists to capture the OpenAI API's behavior with respect to these parameters.""" # noqa: E501

from faster_whisper_server.server_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities
from openai import AsyncOpenAI, BadRequestError
import pytest


@pytest.mark.asyncio()
@pytest.mark.parametrize("timestamp_granularities", TIMESTAMP_GRANULARITIES_COMBINATIONS)
async def test_openai_json_response_format_and_timestamp_granularities_combinations(
actual_openai_client: AsyncOpenAI,
timestamp_granularities: TimestampGranularities,
) -> None:
audio_file = open("audio.wav", "rb") # noqa: SIM115, ASYNC230

if "word" in timestamp_granularities:
with pytest.raises(BadRequestError):
await actual_openai_client.audio.transcriptions.create(
file=audio_file,
model="whisper-1",
response_format="json",
timestamp_granularities=timestamp_granularities,
)
else:
await actual_openai_client.audio.transcriptions.create(
file=audio_file, model="whisper-1", response_format="json", timestamp_granularities=timestamp_granularities
)


@pytest.mark.asyncio()
@pytest.mark.parametrize("timestamp_granularities", TIMESTAMP_GRANULARITIES_COMBINATIONS)
async def test_openai_verbose_json_response_format_and_timestamp_granularities_combinations(
actual_openai_client: AsyncOpenAI,
timestamp_granularities: TimestampGranularities,
) -> None:
audio_file = open("audio.wav", "rb") # noqa: SIM115, ASYNC230

transcription = await actual_openai_client.audio.transcriptions.create(
file=audio_file,
model="whisper-1",
response_format="verbose_json",
timestamp_granularities=timestamp_granularities,
)

assert transcription.__pydantic_extra__
if timestamp_granularities == ["word"]:
# This is an exception where segments are not present
assert transcription.__pydantic_extra__.get("segments") is None
assert transcription.__pydantic_extra__.get("words") is not None
elif "word" in timestamp_granularities:
assert transcription.__pydantic_extra__.get("segments") is not None
assert transcription.__pydantic_extra__.get("words") is not None
else:
# Unless explicitly requested, words are not present
assert transcription.__pydantic_extra__.get("segments") is not None
assert transcription.__pydantic_extra__.get("words") is None

0 comments on commit 4cb006f

Please sign in to comment.