From 3dbf4e3fbaab8b5b85d3b276d9b8a87c84f22bab Mon Sep 17 00:00:00 2001 From: Fedir Zadniprovskyi Date: Sat, 21 Sep 2024 10:47:15 -0700 Subject: [PATCH 1/5] fix: pytest-asyncio fixture scope warnings --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index c1839b3f..106b0d7d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,3 +98,6 @@ pythonPlatform = "Linux" # https://github.com/DetachHead/basedpyright?tab=readme-ov-file#pre-commit-hook venvPath = "." venv = ".venv" + +[tool.pytest.ini_options] +asyncio_default_fixture_loop_scope = "function" # this fixes pytest warning From fdaf37de48fbe69ca184a07ca2c685abe17fcc06 Mon Sep 17 00:00:00 2001 From: Fedir Zadniprovskyi Date: Sat, 21 Sep 2024 10:47:31 -0700 Subject: [PATCH 2/5] test: capture openai's param handling --- src/faster_whisper_server/routers/stt.py | 5 +- src/faster_whisper_server/server_models.py | 12 +++++ tests/conftest.py | 9 +++- tests/openai_timestamp_granularities_test.py | 56 ++++++++++++++++++++ 4 files changed, 79 insertions(+), 3 deletions(-) create mode 100644 tests/openai_timestamp_granularities_test.py diff --git a/src/faster_whisper_server/routers/stt.py b/src/faster_whisper_server/routers/stt.py index 9a3f5a3a..f55bf0d9 100644 --- a/src/faster_whisper_server/routers/stt.py +++ b/src/faster_whisper_server/routers/stt.py @@ -3,7 +3,7 @@ import asyncio from io import BytesIO import logging -from typing import TYPE_CHECKING, Annotated, Literal +from typing import TYPE_CHECKING, Annotated from fastapi import ( APIRouter, @@ -30,6 +30,7 @@ from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt from faster_whisper_server.dependencies import ConfigDependency, ModelManagerDependency, get_config from faster_whisper_server.server_models import ( + TimestampGranularities, TranscriptionJsonResponse, TranscriptionVerboseJsonResponse, ) @@ -165,7 +166,7 @@ def transcribe_file( response_format: 
Annotated[ResponseFormat | None, Form()] = None, temperature: Annotated[float, Form()] = 0.0, timestamp_granularities: Annotated[ - list[Literal["segment", "word"]], + TimestampGranularities, Form(alias="timestamp_granularities[]"), ] = ["segment"], stream: Annotated[bool, Form()] = False, diff --git a/src/faster_whisper_server/server_models.py b/src/faster_whisper_server/server_models.py index 2b23bfba..33d17206 100644 --- a/src/faster_whisper_server/server_models.py +++ b/src/faster_whisper_server/server_models.py @@ -107,3 +107,15 @@ class ModelObject(BaseModel): ] }, ) + + +TimestampGranularities = list[Literal["segment", "word"]] + + +TIMESTAMP_GRANULARITIES_COMBINATIONS: list[TimestampGranularities] = [ + [], # should be treated as ["segment"]. https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities + ["segment"], + ["word"], + ["word", "segment"], + ["segment", "word"], # same as ["word", "segment"] but order is different +] diff --git a/tests/conftest.py b/tests/conftest.py index dc796737..658f20bd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ from fastapi.testclient import TestClient from faster_whisper_server.main import create_app from httpx import ASGITransport, AsyncClient -from openai import OpenAI +from openai import AsyncOpenAI, OpenAI import pytest import pytest_asyncio @@ -35,3 +35,10 @@ async def aclient() -> AsyncGenerator[AsyncClient, None]: @pytest.fixture() def openai_client(client: TestClient) -> OpenAI: return OpenAI(api_key="cant-be-empty", http_client=client) + + +@pytest.fixture() +def actual_openai_client() -> AsyncOpenAI: + return AsyncOpenAI( + base_url="https://api.openai.com/v1" + ) # `base_url` is provided in case `OPENAI_BASE_URL` is set to a different value diff --git a/tests/openai_timestamp_granularities_test.py b/tests/openai_timestamp_granularities_test.py new file mode 100644 index 00000000..65643365 --- /dev/null +++ 
b/tests/openai_timestamp_granularities_test.py @@ -0,0 +1,56 @@ +"""OpenAI's handling of `response_format` and `timestamp_granularities` is a bit confusing and inconsistent. This test module exists to capture the OpenAI API's behavior with respect to these parameters.""" # noqa: E501 + +from faster_whisper_server.server_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities +from openai import AsyncOpenAI, BadRequestError +import pytest + + +@pytest.mark.asyncio() +@pytest.mark.parametrize("timestamp_granularities", TIMESTAMP_GRANULARITIES_COMBINATIONS) +async def test_openai_json_response_format_and_timestamp_granularities_combinations( + actual_openai_client: AsyncOpenAI, + timestamp_granularities: TimestampGranularities, +) -> None: + audio_file = open("audio.wav", "rb") # noqa: SIM115, ASYNC230 + + if "word" in timestamp_granularities: + with pytest.raises(BadRequestError): + await actual_openai_client.audio.transcriptions.create( + file=audio_file, + model="whisper-1", + response_format="json", + timestamp_granularities=timestamp_granularities, + ) + else: + await actual_openai_client.audio.transcriptions.create( + file=audio_file, model="whisper-1", response_format="json", timestamp_granularities=timestamp_granularities + ) + + +@pytest.mark.asyncio() +@pytest.mark.parametrize("timestamp_granularities", TIMESTAMP_GRANULARITIES_COMBINATIONS) +async def test_openai_verbose_json_response_format_and_timestamp_granularities_combinations( + actual_openai_client: AsyncOpenAI, + timestamp_granularities: TimestampGranularities, +) -> None: + audio_file = open("audio.wav", "rb") # noqa: SIM115, ASYNC230 + + transcription = await actual_openai_client.audio.transcriptions.create( + file=audio_file, + model="whisper-1", + response_format="verbose_json", + timestamp_granularities=timestamp_granularities, + ) + + assert transcription.__pydantic_extra__ + if timestamp_granularities == ["word"]: + # This is an exception where segments are not present + 
assert transcription.__pydantic_extra__.get("segments") is None + assert transcription.__pydantic_extra__.get("words") is not None + elif "word" in timestamp_granularities: + assert transcription.__pydantic_extra__.get("segments") is not None + assert transcription.__pydantic_extra__.get("words") is not None + else: + # Unless explicitly requested, words are not present + assert transcription.__pydantic_extra__.get("segments") is not None + assert transcription.__pydantic_extra__.get("words") is None From 0c1fc004ae033ca590262e2116d7ac7b16a66a16 Mon Sep 17 00:00:00 2001 From: Fedir Zadniprovskyi Date: Sat, 21 Sep 2024 10:47:38 -0700 Subject: [PATCH 3/5] test: switch to async openai client --- tests/api_model_test.py | 17 ++++++++++------- tests/conftest.py | 8 ++++---- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/tests/api_model_test.py b/tests/api_model_test.py index fcb822db..c66659dd 100644 --- a/tests/api_model_test.py +++ b/tests/api_model_test.py @@ -1,5 +1,5 @@ import openai -from openai import OpenAI +from openai import AsyncOpenAI import pytest MODEL_THAT_EXISTS = "Systran/faster-whisper-tiny.en" @@ -7,16 +7,19 @@ MIN_EXPECTED_NUMBER_OF_MODELS = 70 # At the time of the test creation there are 89 models -def test_list_models(openai_client: OpenAI) -> None: - models = openai_client.models.list().data +@pytest.mark.asyncio() +async def test_list_models(openai_client: AsyncOpenAI) -> None: + models = (await openai_client.models.list()).data assert len(models) > MIN_EXPECTED_NUMBER_OF_MODELS -def test_model_exists(openai_client: OpenAI) -> None: - model = openai_client.models.retrieve(MODEL_THAT_EXISTS) +@pytest.mark.asyncio() +async def test_model_exists(openai_client: AsyncOpenAI) -> None: + model = await openai_client.models.retrieve(MODEL_THAT_EXISTS) assert model.id == MODEL_THAT_EXISTS -def test_model_does_not_exist(openai_client: OpenAI) -> None: +@pytest.mark.asyncio() +async def test_model_does_not_exist(openai_client: AsyncOpenAI) 
-> None: with pytest.raises(openai.NotFoundError): - openai_client.models.retrieve(MODEL_THAT_DOES_NOT_EXIST) + await openai_client.models.retrieve(MODEL_THAT_DOES_NOT_EXIST) diff --git a/tests/conftest.py b/tests/conftest.py index 658f20bd..316220a2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ from fastapi.testclient import TestClient from faster_whisper_server.main import create_app from httpx import ASGITransport, AsyncClient -from openai import AsyncOpenAI, OpenAI +from openai import AsyncOpenAI import pytest import pytest_asyncio @@ -32,9 +32,9 @@ async def aclient() -> AsyncGenerator[AsyncClient, None]: yield aclient -@pytest.fixture() -def openai_client(client: TestClient) -> OpenAI: - return OpenAI(api_key="cant-be-empty", http_client=client) +@pytest_asyncio.fixture() +def openai_client(aclient: AsyncClient) -> AsyncOpenAI: + return AsyncOpenAI(api_key="cant-be-empty", http_client=aclient) @pytest.fixture() From 2fd0726946285df9df7e5b28223b1277e874ade9 Mon Sep 17 00:00:00 2001 From: Fedir Zadniprovskyi Date: Sat, 21 Sep 2024 11:25:48 -0700 Subject: [PATCH 4/5] fix: `timestamp_granularities[]` handling (#28, #58, #81) --- src/faster_whisper_server/routers/stt.py | 22 +++++++++++ src/faster_whisper_server/server_models.py | 7 ++-- tests/api_timestamp_granularities_test.py | 43 ++++++++++++++++++++++ 3 files changed, 69 insertions(+), 3 deletions(-) create mode 100644 tests/api_timestamp_granularities_test.py diff --git a/src/faster_whisper_server/routers/stt.py b/src/faster_whisper_server/routers/stt.py index f55bf0d9..a39b5686 100644 --- a/src/faster_whisper_server/routers/stt.py +++ b/src/faster_whisper_server/routers/stt.py @@ -9,6 +9,7 @@ APIRouter, Form, Query, + Request, Response, UploadFile, WebSocket, @@ -30,6 +31,8 @@ from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt from faster_whisper_server.dependencies import ConfigDependency, ModelManagerDependency, get_config from 
faster_whisper_server.server_models import ( + DEFAULT_TIMESTAMP_GRANULARITIES, + TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities, TranscriptionJsonResponse, TranscriptionVerboseJsonResponse, @@ -150,6 +153,18 @@ def translate_file( return segments_to_response(segments, transcription_info, response_format) +# HACK: Since Form() doesn't support `alias`, we need to use a workaround. +async def get_timestamp_granularities(request: Request) -> TimestampGranularities: + form = await request.form() + if form.get("timestamp_granularities[]") is None: + return DEFAULT_TIMESTAMP_GRANULARITIES + timestamp_granularities = form.getlist("timestamp_granularities[]") + assert ( + timestamp_granularities in TIMESTAMP_GRANULARITIES_COMBINATIONS + ), f"{timestamp_granularities} is not a valid value for `timestamp_granularities[]`." + return timestamp_granularities + + # https://platform.openai.com/docs/api-reference/audio/createTranscription # https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8915 @router.post( @@ -159,6 +174,7 @@ def translate_file( def transcribe_file( config: ConfigDependency, model_manager: ModelManagerDependency, + request: Request, file: Annotated[UploadFile, Form()], model: Annotated[ModelName | None, Form()] = None, language: Annotated[Language | None, Form()] = None, @@ -167,6 +183,7 @@ def transcribe_file( temperature: Annotated[float, Form()] = 0.0, timestamp_granularities: Annotated[ TimestampGranularities, + # WARN: `alias` doesn't actually work. 
Form(alias="timestamp_granularities[]"), ] = ["segment"], stream: Annotated[bool, Form()] = False, @@ -178,6 +195,11 @@ def transcribe_file( language = config.default_language if response_format is None: response_format = config.default_response_format + timestamp_granularities = asyncio.run(get_timestamp_granularities(request)) + if timestamp_granularities != DEFAULT_TIMESTAMP_GRANULARITIES and response_format != ResponseFormat.VERBOSE_JSON: + logger.warning( + "It only makes sense to provide `timestamp_granularities[]` when `response_format` is set to `verbose_json`. See https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities." # noqa: E501 + ) whisper = model_manager.load_model(model) segments, transcription_info = whisper.transcribe( file.file, diff --git a/src/faster_whisper_server/server_models.py b/src/faster_whisper_server/server_models.py index 33d17206..b6cf5aa4 100644 --- a/src/faster_whisper_server/server_models.py +++ b/src/faster_whisper_server/server_models.py @@ -29,7 +29,7 @@ class TranscriptionVerboseJsonResponse(BaseModel): language: str duration: float text: str - words: list[Word] + words: list[Word] | None segments: list[Segment] @classmethod @@ -38,7 +38,7 @@ def from_segment(cls, segment: Segment, transcription_info: TranscriptionInfo) - language=transcription_info.language, duration=segment.end - segment.start, text=segment.text, - words=(segment.words if isinstance(segment.words, list) else []), + words=segment.words if transcription_info.transcription_options.word_timestamps else None, segments=[segment], ) @@ -51,7 +51,7 @@ def from_segments( duration=transcription_info.duration, text=segments_to_text(segments), segments=segments, - words=Word.from_segments(segments), + words=Word.from_segments(segments) if transcription_info.transcription_options.word_timestamps else None, ) @classmethod @@ -112,6 +112,7 @@ class ModelObject(BaseModel): TimestampGranularities = 
list[Literal["segment", "word"]] +DEFAULT_TIMESTAMP_GRANULARITIES: TimestampGranularities = ["segment"] TIMESTAMP_GRANULARITIES_COMBINATIONS: list[TimestampGranularities] = [ [], # should be treated as ["segment"]. https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities ["segment"], diff --git a/tests/api_timestamp_granularities_test.py b/tests/api_timestamp_granularities_test.py new file mode 100644 index 00000000..79a556b3 --- /dev/null +++ b/tests/api_timestamp_granularities_test.py @@ -0,0 +1,43 @@ +"""See `tests/openai_timestamp_granularities_test.py` to understand how OpenAI handles `response_format` and `timestamp_granularities`.""" # noqa: E501 + +from faster_whisper_server.server_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities +from openai import AsyncOpenAI +import pytest + + +@pytest.mark.asyncio() +@pytest.mark.parametrize("timestamp_granularities", TIMESTAMP_GRANULARITIES_COMBINATIONS) +async def test_api_json_response_format_and_timestamp_granularities_combinations( + openai_client: AsyncOpenAI, + timestamp_granularities: TimestampGranularities, +) -> None: + audio_file = open("audio.wav", "rb") # noqa: SIM115, ASYNC230 + + await openai_client.audio.transcriptions.create( + file=audio_file, model="whisper-1", response_format="json", timestamp_granularities=timestamp_granularities + ) + + +@pytest.mark.asyncio() +@pytest.mark.parametrize("timestamp_granularities", TIMESTAMP_GRANULARITIES_COMBINATIONS) +async def test_api_verbose_json_response_format_and_timestamp_granularities_combinations( + openai_client: AsyncOpenAI, + timestamp_granularities: TimestampGranularities, +) -> None: + audio_file = open("audio.wav", "rb") # noqa: SIM115, ASYNC230 + + transcription = await openai_client.audio.transcriptions.create( + file=audio_file, + model="whisper-1", + response_format="verbose_json", + timestamp_granularities=timestamp_granularities, + ) + + assert 
transcription.__pydantic_extra__ + if "word" in timestamp_granularities: + assert transcription.__pydantic_extra__.get("segments") is not None + assert transcription.__pydantic_extra__.get("words") is not None + else: + # Unless explicitly requested, words are not present + assert transcription.__pydantic_extra__.get("segments") is not None + assert transcription.__pydantic_extra__.get("words") is None From 5b962c60ea295e740de285a4a8ee21895cb03856 Mon Sep 17 00:00:00 2001 From: Fedir Zadniprovskyi Date: Sat, 21 Sep 2024 11:43:03 -0700 Subject: [PATCH 5/5] ci: do not run tests which require OpenAI API key --- .github/workflows/test.yaml | 2 +- tests/openai_timestamp_granularities_test.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 4d9d0772..78b81fce 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -20,4 +20,4 @@ jobs: - run: uv python install 3.12 - run: uv sync --extra dev # TODO: figure out why `pytest` doesn't discover tests in `faster_whisper_server` directory by itself - - run: uv run pytest src/faster_whisper_server/* tests + - run: uv run pytest -m "not requires_openai" src/faster_whisper_server/* tests diff --git a/tests/openai_timestamp_granularities_test.py b/tests/openai_timestamp_granularities_test.py index 65643365..e88129a5 100644 --- a/tests/openai_timestamp_granularities_test.py +++ b/tests/openai_timestamp_granularities_test.py @@ -6,6 +6,7 @@ @pytest.mark.asyncio() +@pytest.mark.requires_openai() @pytest.mark.parametrize("timestamp_granularities", TIMESTAMP_GRANULARITIES_COMBINATIONS) async def test_openai_json_response_format_and_timestamp_granularities_combinations( actual_openai_client: AsyncOpenAI, @@ -28,6 +29,7 @@ async def test_openai_json_response_format_and_timestamp_granularities_combinati @pytest.mark.asyncio() +@pytest.mark.requires_openai() @pytest.mark.parametrize("timestamp_granularities", 
TIMESTAMP_GRANULARITIES_COMBINATIONS) async def test_openai_verbose_json_response_format_and_timestamp_granularities_combinations( actual_openai_client: AsyncOpenAI,