From 4cb006f2e64e2bed8797820cf696d42f3f6dd3ee Mon Sep 17 00:00:00 2001 From: Fedir Zadniprovskyi Date: Sat, 21 Sep 2024 10:47:31 -0700 Subject: [PATCH] test: capture openai's param handling --- src/faster_whisper_server/routers/stt.py | 5 +- src/faster_whisper_server/server_models.py | 12 +++++ tests/conftest.py | 9 +++- tests/openai_timestamp_granularities_test.py | 56 ++++++++++++++++++++ 4 files changed, 79 insertions(+), 3 deletions(-) create mode 100644 tests/openai_timestamp_granularities_test.py diff --git a/src/faster_whisper_server/routers/stt.py b/src/faster_whisper_server/routers/stt.py index 9a3f5a3a..f55bf0d9 100644 --- a/src/faster_whisper_server/routers/stt.py +++ b/src/faster_whisper_server/routers/stt.py @@ -3,7 +3,7 @@ import asyncio from io import BytesIO import logging -from typing import TYPE_CHECKING, Annotated, Literal +from typing import TYPE_CHECKING, Annotated from fastapi import ( APIRouter, @@ -30,6 +30,7 @@ from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt from faster_whisper_server.dependencies import ConfigDependency, ModelManagerDependency, get_config from faster_whisper_server.server_models import ( + TimestampGranularities, TranscriptionJsonResponse, TranscriptionVerboseJsonResponse, ) @@ -165,7 +166,7 @@ def transcribe_file( response_format: Annotated[ResponseFormat | None, Form()] = None, temperature: Annotated[float, Form()] = 0.0, timestamp_granularities: Annotated[ - list[Literal["segment", "word"]], + TimestampGranularities, Form(alias="timestamp_granularities[]"), ] = ["segment"], stream: Annotated[bool, Form()] = False, diff --git a/src/faster_whisper_server/server_models.py b/src/faster_whisper_server/server_models.py index 2b23bfba..33d17206 100644 --- a/src/faster_whisper_server/server_models.py +++ b/src/faster_whisper_server/server_models.py @@ -107,3 +107,15 @@ class ModelObject(BaseModel): ] }, ) + + +TimestampGranularities = list[Literal["segment", "word"]] + + +TIMESTAMP_GRANULARITIES_COMBINATIONS: list[TimestampGranularities] = [ + [], # should be treated as ["segment"]. https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities + ["segment"], + ["word"], + ["word", "segment"], + ["segment", "word"], # same as ["word", "segment"] but order is different +] diff --git a/tests/conftest.py b/tests/conftest.py index dc796737..658f20bd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ from fastapi.testclient import TestClient from faster_whisper_server.main import create_app from httpx import ASGITransport, AsyncClient -from openai import OpenAI +from openai import AsyncOpenAI, OpenAI import pytest import pytest_asyncio @@ -35,3 +35,10 @@ async def aclient() -> AsyncGenerator[AsyncClient, None]: @pytest.fixture() def openai_client(client: TestClient) -> OpenAI: return OpenAI(api_key="cant-be-empty", http_client=client) + + +@pytest.fixture() +def actual_openai_client() -> AsyncOpenAI: + return AsyncOpenAI( + base_url="https://api.openai.com/v1" + ) # `base_url` is provided in case `OPENAI_API_BASE_URL` is set to a different value diff --git a/tests/openai_timestamp_granularities_test.py b/tests/openai_timestamp_granularities_test.py new file mode 100644 index 00000000..65643365 --- /dev/null +++ b/tests/openai_timestamp_granularities_test.py @@ -0,0 +1,56 @@ +"""OpenAI's handling of `response_format` and `timestamp_granularities` is a bit confusing and inconsistent. This test module exists to capture the OpenAI API's behavior with respect to these parameters.""" # noqa: E501 + +from faster_whisper_server.server_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities +from openai import AsyncOpenAI, BadRequestError +import pytest + + +@pytest.mark.asyncio() +@pytest.mark.parametrize("timestamp_granularities", TIMESTAMP_GRANULARITIES_COMBINATIONS) +async def test_openai_json_response_format_and_timestamp_granularities_combinations( + actual_openai_client: AsyncOpenAI, + timestamp_granularities: TimestampGranularities, +) -> None: + audio_file = open("audio.wav", "rb") # noqa: SIM115, ASYNC230 + + if "word" in timestamp_granularities: + with pytest.raises(BadRequestError): + await actual_openai_client.audio.transcriptions.create( + file=audio_file, + model="whisper-1", + response_format="json", + timestamp_granularities=timestamp_granularities, + ) + else: + await actual_openai_client.audio.transcriptions.create( + file=audio_file, model="whisper-1", response_format="json", timestamp_granularities=timestamp_granularities + ) + + +@pytest.mark.asyncio() +@pytest.mark.parametrize("timestamp_granularities", TIMESTAMP_GRANULARITIES_COMBINATIONS) +async def test_openai_verbose_json_response_format_and_timestamp_granularities_combinations( + actual_openai_client: AsyncOpenAI, + timestamp_granularities: TimestampGranularities, +) -> None: + audio_file = open("audio.wav", "rb") # noqa: SIM115, ASYNC230 + + transcription = await actual_openai_client.audio.transcriptions.create( + file=audio_file, + model="whisper-1", + response_format="verbose_json", + timestamp_granularities=timestamp_granularities, + ) + + assert transcription.__pydantic_extra__ + if timestamp_granularities == ["word"]: + # This is an exception where segments are not present + assert transcription.__pydantic_extra__.get("segments") is None + assert transcription.__pydantic_extra__.get("words") is not None + elif "word" in timestamp_granularities: + assert transcription.__pydantic_extra__.get("segments") is not None + assert transcription.__pydantic_extra__.get("words") is not None + else: + # Unless explicitly requested, words are not present + assert transcription.__pydantic_extra__.get("segments") is not None + assert transcription.__pydantic_extra__.get("words") is None