fix/timestamp granularities handling #89

Merged · 5 commits · Sep 21, 2024

Changes from all commits
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
@@ -20,4 +20,4 @@ jobs:
- run: uv python install 3.12
- run: uv sync --extra dev
# TODO: figure out why `pytest` doesn't discover tests in `faster_whisper_server` directory by itself
- run: uv run pytest src/faster_whisper_server/* tests
- run: uv run pytest -m "not requires_openai" src/faster_whisper_server/* tests
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -98,3 +98,6 @@ pythonPlatform = "Linux"
# https://github.com/DetachHead/basedpyright?tab=readme-ov-file#pre-commit-hook
venvPath = "."
venv = ".venv"

[tool.pytest.ini_options]
asyncio_default_fixture_loop_scope = "function" # this fixes pytest warning
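The new `-m "not requires_openai"` filter in the workflow change above deselects the tests marked with `@pytest.mark.requires_openai()` further down in this diff. Registering that marker keeps pytest from warning about an unknown mark; a minimal sketch of such a registration follows (the actual registration is not part of this diff, so the hook below is only an illustrative assumption):

import pytest


# Hypothetical registration of the `requires_openai` marker; where (or whether) the
# project registers it is not shown in this diff.
def pytest_configure(config: pytest.Config) -> None:
    config.addinivalue_line(
        "markers", "requires_openai: test talks to the real OpenAI API and needs an API key"
    )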
27 changes: 25 additions & 2 deletions src/faster_whisper_server/routers/stt.py
@@ -3,12 +3,13 @@
import asyncio
from io import BytesIO
import logging
from typing import TYPE_CHECKING, Annotated, Literal
from typing import TYPE_CHECKING, Annotated

from fastapi import (
    APIRouter,
    Form,
    Query,
    Request,
    Response,
    UploadFile,
    WebSocket,
@@ -30,6 +31,9 @@
from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
from faster_whisper_server.dependencies import ConfigDependency, ModelManagerDependency, get_config
from faster_whisper_server.server_models import (
    DEFAULT_TIMESTAMP_GRANULARITIES,
    TIMESTAMP_GRANULARITIES_COMBINATIONS,
    TimestampGranularities,
    TranscriptionJsonResponse,
    TranscriptionVerboseJsonResponse,
)
@@ -149,6 +153,18 @@ def translate_file(
    return segments_to_response(segments, transcription_info, response_format)


# HACK: Since Form() doesn't support `alias`, we need to use a workaround.
async def get_timestamp_granularities(request: Request) -> TimestampGranularities:
    form = await request.form()
    if form.get("timestamp_granularities[]") is None:
        return DEFAULT_TIMESTAMP_GRANULARITIES
    timestamp_granularities = form.getlist("timestamp_granularities[]")
    assert (
        timestamp_granularities in TIMESTAMP_GRANULARITIES_COMBINATIONS
    ), f"{timestamp_granularities} is not a valid value for `timestamp_granularities[]`."
    return timestamp_granularities


# https://platform.openai.com/docs/api-reference/audio/createTranscription
# https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8915
@router.post(
@@ -158,14 +174,16 @@ def translate_file(
def transcribe_file(
    config: ConfigDependency,
    model_manager: ModelManagerDependency,
    request: Request,
    file: Annotated[UploadFile, Form()],
    model: Annotated[ModelName | None, Form()] = None,
    language: Annotated[Language | None, Form()] = None,
    prompt: Annotated[str | None, Form()] = None,
    response_format: Annotated[ResponseFormat | None, Form()] = None,
    temperature: Annotated[float, Form()] = 0.0,
    timestamp_granularities: Annotated[
        list[Literal["segment", "word"]],
        TimestampGranularities,
        # WARN: `alias` doesn't actually work.
        Form(alias="timestamp_granularities[]"),
    ] = ["segment"],
    stream: Annotated[bool, Form()] = False,
@@ -177,6 +195,11 @@ def transcribe_file(
        language = config.default_language
    if response_format is None:
        response_format = config.default_response_format
    timestamp_granularities = asyncio.run(get_timestamp_granularities(request))
    if timestamp_granularities != DEFAULT_TIMESTAMP_GRANULARITIES and response_format != ResponseFormat.VERBOSE_JSON:
        logger.warning(
            "It only makes sense to provide `timestamp_granularities[]` when `response_format` is set to `verbose_json`. See https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities."  # noqa: E501
        )
    whisper = model_manager.load_model(model)
    segments, transcription_info = whisper.transcribe(
        file.file,
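Because the `Form(alias="timestamp_granularities[]")` alias is not honored (see the WARN comment above), the endpoint re-reads the raw form through `get_timestamp_granularities()` and `form.getlist(...)`. A hedged sketch of the client side, assuming the route is mounted at the OpenAI-compatible `/v1/audio/transcriptions` path and the server runs locally on port 8000 (neither is visible in this hunk):

import httpx

# The repeated `timestamp_granularities[]` field below is exactly what
# `request.form().getlist("timestamp_granularities[]")` picks up server-side.
with open("audio.wav", "rb") as audio_file:
    response = httpx.post(
        "http://localhost:8000/v1/audio/transcriptions",  # assumed URL, not taken from this diff
        files={"file": ("audio.wav", audio_file, "audio/wav")},
        data={
            "response_format": "verbose_json",
            "timestamp_granularities[]": ["word", "segment"],  # httpx encodes a list as repeated fields
        },
    )
response.raise_for_status()
print(response.json().get("words"))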
19 changes: 16 additions & 3 deletions src/faster_whisper_server/server_models.py
@@ -29,7 +29,7 @@ class TranscriptionVerboseJsonResponse(BaseModel):
    language: str
    duration: float
    text: str
    words: list[Word]
    words: list[Word] | None
    segments: list[Segment]

    @classmethod
@@ -38,7 +38,7 @@ def from_segment(cls, segment: Segment, transcription_info: TranscriptionInfo) -
            language=transcription_info.language,
            duration=segment.end - segment.start,
            text=segment.text,
            words=(segment.words if isinstance(segment.words, list) else []),
            words=segment.words if transcription_info.transcription_options.word_timestamps else None,
            segments=[segment],
        )

@@ -51,7 +51,7 @@ def from_segments(
            duration=transcription_info.duration,
            text=segments_to_text(segments),
            segments=segments,
            words=Word.from_segments(segments),
            words=Word.from_segments(segments) if transcription_info.transcription_options.word_timestamps else None,
        )

    @classmethod
@@ -107,3 +107,16 @@ class ModelObject(BaseModel):
]
},
)


TimestampGranularities = list[Literal["segment", "word"]]


DEFAULT_TIMESTAMP_GRANULARITIES: TimestampGranularities = ["segment"]
TIMESTAMP_GRANULARITIES_COMBINATIONS: list[TimestampGranularities] = [
    [],  # should be treated as ["segment"]. https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities
    ["segment"],
    ["word"],
    ["word", "segment"],
    ["segment", "word"],  # same as ["word", "segment"] but order is different
]
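The combinations are spelled out as plain lists because `get_timestamp_granularities()` checks membership with `in`, and list equality is element-wise and order-sensitive; that is also why both orderings of `["word", "segment"]` are listed. A small illustration:

from faster_whisper_server.server_models import TIMESTAMP_GRANULARITIES_COMBINATIONS

assert ["segment", "word"] in TIMESTAMP_GRANULARITIES_COMBINATIONS
assert ["word", "segment"] in TIMESTAMP_GRANULARITIES_COMBINATIONS
assert [] in TIMESTAMP_GRANULARITIES_COMBINATIONS  # treated like the ["segment"] default
assert ["word", "word"] not in TIMESTAMP_GRANULARITIES_COMBINATIONS  # fails the assert in get_timestamp_granularities()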
17 changes: 10 additions & 7 deletions tests/api_model_test.py
@@ -1,22 +1,25 @@
import openai
from openai import OpenAI
from openai import AsyncOpenAI
import pytest

MODEL_THAT_EXISTS = "Systran/faster-whisper-tiny.en"
MODEL_THAT_DOES_NOT_EXIST = "i-do-not-exist"
MIN_EXPECTED_NUMBER_OF_MODELS = 70 # At the time of the test creation there are 89 models


def test_list_models(openai_client: OpenAI) -> None:
    models = openai_client.models.list().data
@pytest.mark.asyncio()
async def test_list_models(openai_client: AsyncOpenAI) -> None:
    models = (await openai_client.models.list()).data
    assert len(models) > MIN_EXPECTED_NUMBER_OF_MODELS


def test_model_exists(openai_client: OpenAI) -> None:
    model = openai_client.models.retrieve(MODEL_THAT_EXISTS)
@pytest.mark.asyncio()
async def test_model_exists(openai_client: AsyncOpenAI) -> None:
    model = await openai_client.models.retrieve(MODEL_THAT_EXISTS)
    assert model.id == MODEL_THAT_EXISTS


def test_model_does_not_exist(openai_client: OpenAI) -> None:
@pytest.mark.asyncio()
async def test_model_does_not_exist(openai_client: AsyncOpenAI) -> None:
    with pytest.raises(openai.NotFoundError):
        openai_client.models.retrieve(MODEL_THAT_DOES_NOT_EXIST)
        await openai_client.models.retrieve(MODEL_THAT_DOES_NOT_EXIST)
43 changes: 43 additions & 0 deletions tests/api_timestamp_granularities_test.py
@@ -0,0 +1,43 @@
"""See `tests/openai_timestamp_granularities_test.py` to understand how OpenAI handles `response_type` and `timestamp_granularities`.""" # noqa: E501

from faster_whisper_server.server_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities
from openai import AsyncOpenAI
import pytest


@pytest.mark.asyncio()
@pytest.mark.parametrize("timestamp_granularities", TIMESTAMP_GRANULARITIES_COMBINATIONS)
async def test_api_json_response_format_and_timestamp_granularities_combinations(
    openai_client: AsyncOpenAI,
    timestamp_granularities: TimestampGranularities,
) -> None:
    audio_file = open("audio.wav", "rb")  # noqa: SIM115, ASYNC230

    await openai_client.audio.transcriptions.create(
        file=audio_file, model="whisper-1", response_format="json", timestamp_granularities=timestamp_granularities
    )


@pytest.mark.asyncio()
@pytest.mark.parametrize("timestamp_granularities", TIMESTAMP_GRANULARITIES_COMBINATIONS)
async def test_api_verbose_json_response_format_and_timestamp_granularities_combinations(
    openai_client: AsyncOpenAI,
    timestamp_granularities: TimestampGranularities,
) -> None:
    audio_file = open("audio.wav", "rb")  # noqa: SIM115, ASYNC230

    transcription = await openai_client.audio.transcriptions.create(
        file=audio_file,
        model="whisper-1",
        response_format="verbose_json",
        timestamp_granularities=timestamp_granularities,
    )

    assert transcription.__pydantic_extra__
    if "word" in timestamp_granularities:
        assert transcription.__pydantic_extra__.get("segments") is not None
        assert transcription.__pydantic_extra__.get("words") is not None
    else:
        # Unless explicitly requested, words are not present
        assert transcription.__pydantic_extra__.get("segments") is not None
        assert transcription.__pydantic_extra__.get("words") is None
13 changes: 10 additions & 3 deletions tests/conftest.py
@@ -5,7 +5,7 @@
from fastapi.testclient import TestClient
from faster_whisper_server.main import create_app
from httpx import ASGITransport, AsyncClient
from openai import OpenAI
from openai import AsyncOpenAI
import pytest
import pytest_asyncio

@@ -32,6 +32,13 @@ async def aclient() -> AsyncGenerator[AsyncClient, None]:
yield aclient


@pytest_asyncio.fixture()
def openai_client(aclient: AsyncClient) -> AsyncOpenAI:
    return AsyncOpenAI(api_key="cant-be-empty", http_client=aclient)


@pytest.fixture()
def openai_client(client: TestClient) -> OpenAI:
    return OpenAI(api_key="cant-be-empty", http_client=client)
def actual_openai_client() -> AsyncOpenAI:
    return AsyncOpenAI(
        base_url="https://api.openai.com/v1"
    )  # `base_url` is provided in case `OPENAI_API_BASE_URL` is set to a different value
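The `openai_client` fixture points the `AsyncOpenAI` SDK at the in-process app through the `aclient` fixture, whose body is collapsed in this hunk. A sketch of what that wiring typically looks like, assuming `ASGITransport` dispatches every request made through the client to the FastAPI app (the fixture body, base URL, and app wiring below are assumptions, not taken from this diff):

from collections.abc import AsyncGenerator

from faster_whisper_server.main import create_app
from httpx import ASGITransport, AsyncClient
import pytest_asyncio


# Assumed shape of the collapsed `aclient` fixture: with ASGITransport, AsyncOpenAI
# calls such as `openai_client.audio.transcriptions.create(...)` never leave the process.
@pytest_asyncio.fixture()
async def aclient() -> AsyncGenerator[AsyncClient, None]:
    async with AsyncClient(transport=ASGITransport(app=create_app()), base_url="http://test") as client:
        yield client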
58 changes: 58 additions & 0 deletions tests/openai_timestamp_granularities_test.py
@@ -0,0 +1,58 @@
"""OpenAI's handling of `response_format` and `timestamp_granularities` is a bit confusing and inconsistent. This test module exists to capture the OpenAI API's behavior with respect to these parameters.""" # noqa: E501

from faster_whisper_server.server_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities
from openai import AsyncOpenAI, BadRequestError
import pytest


@pytest.mark.asyncio()
@pytest.mark.requires_openai()
@pytest.mark.parametrize("timestamp_granularities", TIMESTAMP_GRANULARITIES_COMBINATIONS)
async def test_openai_json_response_format_and_timestamp_granularities_combinations(
    actual_openai_client: AsyncOpenAI,
    timestamp_granularities: TimestampGranularities,
) -> None:
    audio_file = open("audio.wav", "rb")  # noqa: SIM115, ASYNC230

    if "word" in timestamp_granularities:
        with pytest.raises(BadRequestError):
            await actual_openai_client.audio.transcriptions.create(
                file=audio_file,
                model="whisper-1",
                response_format="json",
                timestamp_granularities=timestamp_granularities,
            )
    else:
        await actual_openai_client.audio.transcriptions.create(
            file=audio_file, model="whisper-1", response_format="json", timestamp_granularities=timestamp_granularities
        )


@pytest.mark.asyncio()
@pytest.mark.requires_openai()
@pytest.mark.parametrize("timestamp_granularities", TIMESTAMP_GRANULARITIES_COMBINATIONS)
async def test_openai_verbose_json_response_format_and_timestamp_granularities_combinations(
    actual_openai_client: AsyncOpenAI,
    timestamp_granularities: TimestampGranularities,
) -> None:
    audio_file = open("audio.wav", "rb")  # noqa: SIM115, ASYNC230

    transcription = await actual_openai_client.audio.transcriptions.create(
        file=audio_file,
        model="whisper-1",
        response_format="verbose_json",
        timestamp_granularities=timestamp_granularities,
    )

    assert transcription.__pydantic_extra__
    if timestamp_granularities == ["word"]:
        # This is an exception where segments are not present
        assert transcription.__pydantic_extra__.get("segments") is None
        assert transcription.__pydantic_extra__.get("words") is not None
    elif "word" in timestamp_granularities:
        assert transcription.__pydantic_extra__.get("segments") is not None
        assert transcription.__pydantic_extra__.get("words") is not None
    else:
        # Unless explicitly requested, words are not present
        assert transcription.__pydantic_extra__.get("segments") is not None
        assert transcription.__pydantic_extra__.get("words") is None