diff --git a/src/faster_whisper_server/routers/stt.py b/src/faster_whisper_server/routers/stt.py
index f55bf0d9..a39b5686 100644
--- a/src/faster_whisper_server/routers/stt.py
+++ b/src/faster_whisper_server/routers/stt.py
@@ -9,6 +9,7 @@
     APIRouter,
     Form,
     Query,
+    Request,
     Response,
     UploadFile,
     WebSocket,
@@ -30,6 +31,8 @@
 from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
 from faster_whisper_server.dependencies import ConfigDependency, ModelManagerDependency, get_config
 from faster_whisper_server.server_models import (
+    DEFAULT_TIMESTAMP_GRANULARITIES,
+    TIMESTAMP_GRANULARITIES_COMBINATIONS,
     TimestampGranularities,
     TranscriptionJsonResponse,
     TranscriptionVerboseJsonResponse,
@@ -150,6 +153,18 @@ def translate_file(
     return segments_to_response(segments, transcription_info, response_format)
 
 
+# HACK: Since Form() doesn't support `alias`, we need to use a workaround.
+async def get_timestamp_granularities(request: Request) -> TimestampGranularities:
+    form = await request.form()
+    if form.get("timestamp_granularities[]") is None:
+        return DEFAULT_TIMESTAMP_GRANULARITIES
+    timestamp_granularities = form.getlist("timestamp_granularities[]")
+    assert (
+        timestamp_granularities in TIMESTAMP_GRANULARITIES_COMBINATIONS
+    ), f"{timestamp_granularities} is not a valid value for `timestamp_granularities[]`."
+    return timestamp_granularities
+
+
 # https://platform.openai.com/docs/api-reference/audio/createTranscription
 # https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L8915
 @router.post(
@@ -159,6 +174,7 @@ def translate_file(
 def transcribe_file(
     config: ConfigDependency,
     model_manager: ModelManagerDependency,
+    request: Request,
     file: Annotated[UploadFile, Form()],
     model: Annotated[ModelName | None, Form()] = None,
     language: Annotated[Language | None, Form()] = None,
@@ -167,6 +183,7 @@ def transcribe_file(
     temperature: Annotated[float, Form()] = 0.0,
     timestamp_granularities: Annotated[
         TimestampGranularities,
+        # WARN: `alias` doesn't actually work.
         Form(alias="timestamp_granularities[]"),
     ] = ["segment"],
     stream: Annotated[bool, Form()] = False,
@@ -178,6 +195,11 @@ def transcribe_file(
         language = config.default_language
     if response_format is None:
         response_format = config.default_response_format
+    timestamp_granularities = asyncio.run(get_timestamp_granularities(request))
+    if timestamp_granularities != DEFAULT_TIMESTAMP_GRANULARITIES and response_format != ResponseFormat.VERBOSE_JSON:
+        logger.warning(
+            "It only makes sense to provide `timestamp_granularities[]` when `response_format` is set to `verbose_json`. See https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities."  # noqa: E501
+        )
     whisper = model_manager.load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
diff --git a/src/faster_whisper_server/server_models.py b/src/faster_whisper_server/server_models.py
index 33d17206..b6cf5aa4 100644
--- a/src/faster_whisper_server/server_models.py
+++ b/src/faster_whisper_server/server_models.py
@@ -29,7 +29,7 @@ class TranscriptionVerboseJsonResponse(BaseModel):
     language: str
     duration: float
     text: str
-    words: list[Word]
+    words: list[Word] | None
     segments: list[Segment]
 
     @classmethod
@@ -38,7 +38,7 @@ def from_segment(cls, segment: Segment, transcription_info: TranscriptionInfo) -
             language=transcription_info.language,
             duration=segment.end - segment.start,
             text=segment.text,
-            words=(segment.words if isinstance(segment.words, list) else []),
+            words=segment.words if transcription_info.transcription_options.word_timestamps else None,
             segments=[segment],
         )
 
@@ -51,7 +51,7 @@ def from_segments(
             duration=transcription_info.duration,
             text=segments_to_text(segments),
             segments=segments,
-            words=Word.from_segments(segments),
+            words=Word.from_segments(segments) if transcription_info.transcription_options.word_timestamps else None,
         )
 
     @classmethod
@@ -112,6 +112,7 @@ class ModelObject(BaseModel):
 
 
 TimestampGranularities = list[Literal["segment", "word"]]
+DEFAULT_TIMESTAMP_GRANULARITIES: TimestampGranularities = ["segment"]
 TIMESTAMP_GRANULARITIES_COMBINATIONS: list[TimestampGranularities] = [
     [],  # should be treated as ["segment"]. https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities
     ["segment"],
diff --git a/tests/api_timestamp_granularities_test.py b/tests/api_timestamp_granularities_test.py
new file mode 100644
index 00000000..79a556b3
--- /dev/null
+++ b/tests/api_timestamp_granularities_test.py
@@ -0,0 +1,43 @@
+"""See `tests/openai_timestamp_granularities_test.py` to understand how OpenAI handles `response_format` and `timestamp_granularities`."""  # noqa: E501
+
+from faster_whisper_server.server_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities
+from openai import AsyncOpenAI
+import pytest
+
+
+@pytest.mark.asyncio()
+@pytest.mark.parametrize("timestamp_granularities", TIMESTAMP_GRANULARITIES_COMBINATIONS)
+async def test_api_json_response_format_and_timestamp_granularities_combinations(
+    openai_client: AsyncOpenAI,
+    timestamp_granularities: TimestampGranularities,
+) -> None:
+    audio_file = open("audio.wav", "rb")  # noqa: SIM115, ASYNC230
+
+    await openai_client.audio.transcriptions.create(
+        file=audio_file, model="whisper-1", response_format="json", timestamp_granularities=timestamp_granularities
+    )
+
+
+@pytest.mark.asyncio()
+@pytest.mark.parametrize("timestamp_granularities", TIMESTAMP_GRANULARITIES_COMBINATIONS)
+async def test_api_verbose_json_response_format_and_timestamp_granularities_combinations(
+    openai_client: AsyncOpenAI,
+    timestamp_granularities: TimestampGranularities,
+) -> None:
+    audio_file = open("audio.wav", "rb")  # noqa: SIM115, ASYNC230
+
+    transcription = await openai_client.audio.transcriptions.create(
+        file=audio_file,
+        model="whisper-1",
+        response_format="verbose_json",
+        timestamp_granularities=timestamp_granularities,
+    )
+
+    assert transcription.__pydantic_extra__
+    if "word" in timestamp_granularities:
+        assert transcription.__pydantic_extra__.get("segments") is not None
+        assert transcription.__pydantic_extra__.get("words") is not None
+    else:
+        # Unless explicitly requested, words are not present
+        assert transcription.__pydantic_extra__.get("segments") is not None
+        assert transcription.__pydantic_extra__.get("words") is None
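
A note for reviewers on the `Form(alias="timestamp_granularities[]")` workaround, as an illustration only and not part of the patch: the OpenAI client submits the granularities as repeated `timestamp_granularities[]` multipart fields, and since the `alias` on `Form()` doesn't actually bind them to the parameter, `get_timestamp_granularities` falls back to parsing the raw form. A minimal sketch of the behaviour it relies on, using Starlette's `FormData` directly with made-up field values:

```python
# Minimal sketch (not part of the patch): repeated form fields are exposed
# through FormData.getlist(), which is how get_timestamp_granularities()
# recovers the bracketed field name sent by the OpenAI client.
from starlette.datastructures import FormData

# Simulate a multipart body that carries `timestamp_granularities[]` twice.
form = FormData(
    [("timestamp_granularities[]", "segment"), ("timestamp_granularities[]", "word")]
)

assert form.getlist("timestamp_granularities[]") == ["segment", "word"]
assert form.get("timestamp_granularities") is None  # the un-bracketed name never appears
```

The new `tests/api_timestamp_granularities_test.py` exercises the same path end to end by passing `timestamp_granularities` through the OpenAI client.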