From d837657a7f61fe131a9f27b7855d5d06b75422f8 Mon Sep 17 00:00:00 2001 From: Fedir Zadniprovskyi Date: Tue, 22 Oct 2024 22:18:05 -0700 Subject: [PATCH] feat: tts --- src/faster_whisper_server/dependencies.py | 17 +- src/faster_whisper_server/hf_utils.py | 163 ++++++++++++++++- src/faster_whisper_server/main.py | 4 + src/faster_whisper_server/model_manager.py | 114 ++++++++---- .../routers/list_models.py | 28 +-- src/faster_whisper_server/routers/speech.py | 164 ++++++++++++++++++ tests/speech_test.py | 147 ++++++++++++++++ 7 files changed, 571 insertions(+), 66 deletions(-) create mode 100644 src/faster_whisper_server/routers/speech.py create mode 100644 tests/speech_test.py diff --git a/src/faster_whisper_server/dependencies.py b/src/faster_whisper_server/dependencies.py index ade976fe..985585c0 100644 --- a/src/faster_whisper_server/dependencies.py +++ b/src/faster_whisper_server/dependencies.py @@ -4,7 +4,7 @@ from fastapi import Depends from faster_whisper_server.config import Config -from faster_whisper_server.model_manager import ModelManager +from faster_whisper_server.model_manager import PiperModelManager, WhisperModelManager @lru_cache @@ -16,9 +16,18 @@ def get_config() -> Config: @lru_cache -def get_model_manager() -> ModelManager: +def get_model_manager() -> WhisperModelManager: config = get_config() # HACK - return ModelManager(config.whisper) + return WhisperModelManager(config.whisper) -ModelManagerDependency = Annotated[ModelManager, Depends(get_model_manager)] +ModelManagerDependency = Annotated[WhisperModelManager, Depends(get_model_manager)] + + +@lru_cache +def get_piper_model_manager() -> PiperModelManager: + config = get_config() # HACK + return PiperModelManager(config.whisper.ttl) # HACK + + +PiperModelManagerDependency = Annotated[PiperModelManager, Depends(get_piper_model_manager)] diff --git a/src/faster_whisper_server/hf_utils.py b/src/faster_whisper_server/hf_utils.py index 94a1642f..14e00da4 100644 --- a/src/faster_whisper_server/hf_utils.py +++ b/src/faster_whisper_server/hf_utils.py @@ -1,9 +1,16 @@ from collections.abc import Generator +from functools import lru_cache +import json import logging from pathlib import Path import typing +from typing import Any, Literal import huggingface_hub +from huggingface_hub.constants import HF_HUB_CACHE +from pydantic import BaseModel + +from faster_whisper_server.api_models import Model logger = logging.getLogger(__name__) @@ -12,10 +19,36 @@ def does_local_model_exist(model_id: str) -> bool: - return any(model_id == model.repo_id for model, _ in list_local_models()) + return any(model_id == model.repo_id for model, _ in list_local_whisper_models()) + + +def list_whisper_models() -> Generator[Model, None, None]: + models = huggingface_hub.list_models(library="ctranslate2", tags="automatic-speech-recognition", cardData=True) + models = list(models) + models.sort(key=lambda model: model.downloads or -1, reverse=True) + for model in models: + assert model.created_at is not None + assert model.card_data is not None + assert model.card_data.language is None or isinstance(model.card_data.language, str | list) + if model.card_data.language is None: + language = [] + elif isinstance(model.card_data.language, str): + language = [model.card_data.language] + else: + language = model.card_data.language + transformed_model = Model( + id=model.id, + created=int(model.created_at.timestamp()), + object_="model", + owned_by=model.id.split("/")[0], + language=language, + ) + yield transformed_model -def list_local_models() -> Generator[tuple[huggingface_hub.CachedRepoInfo, huggingface_hub.ModelCardData], None, None]: +def list_local_whisper_models() -> ( + Generator[tuple[huggingface_hub.CachedRepoInfo, huggingface_hub.ModelCardData], None, None] +): hf_cache = huggingface_hub.scan_cache_dir() hf_models = [repo for repo in list(hf_cache.repos) if repo.repo_type == "model"] for model in hf_models: @@ -36,3 +69,129 @@ def list_local_models() -> Generator[tuple[huggingface_hub.CachedRepoInfo, huggi and TASK_NAME in model_card_data.tags ): yield model, model_card_data + + +def get_whisper_models() -> Generator[Model, None, None]: + models = huggingface_hub.list_models(library="ctranslate2", tags="automatic-speech-recognition", cardData=True) + models = list(models) + models.sort(key=lambda model: model.downloads or -1, reverse=True) + for model in models: + assert model.created_at is not None + assert model.card_data is not None + assert model.card_data.language is None or isinstance(model.card_data.language, str | list) + if model.card_data.language is None: + language = [] + elif isinstance(model.card_data.language, str): + language = [model.card_data.language] + else: + language = model.card_data.language + transformed_model = Model( + id=model.id, + created=int(model.created_at.timestamp()), + object_="model", + owned_by=model.id.split("/")[0], + language=language, + ) + yield transformed_model + + +class PiperModel(BaseModel): + id: str + object: Literal["model"] = "model" + created: int + owned_by: Literal["rhasspy"] = "rhasspy" + path: Path + config_path: Path + + +def get_model_path(model_id: str, *, cache_dir: str | Path | None = None) -> Path | None: + if cache_dir is None: + cache_dir = HF_HUB_CACHE + + cache_dir = Path(cache_dir).expanduser().resolve() + if not cache_dir.exists(): + raise huggingface_hub.CacheNotFound( + f"Cache directory not found: {cache_dir}. Please use `cache_dir` argument or set `HF_HUB_CACHE` environment variable.", # noqa: E501 + cache_dir=cache_dir, + ) + + if cache_dir.is_file(): + raise ValueError( + f"Scan cache expects a directory but found a file: {cache_dir}. Please use `cache_dir` argument or set `HF_HUB_CACHE` environment variable." # noqa: E501 + ) + + for repo_path in cache_dir.iterdir(): + if not repo_path.is_dir(): + continue + if repo_path.name == ".locks": # skip './.locks/' folder + continue + repo_type, repo_id = repo_path.name.split("--", maxsplit=1) + repo_type = repo_type[:-1] # "models" -> "model" + repo_id = repo_id.replace("--", "/") # google--fleurs -> "google/fleurs" + if repo_type != "model": + continue + if model_id == repo_id: + return repo_path + + return None + + +def list_model_files( + model_id: str, glob_pattern: str = "**/*", *, cache_dir: str | Path | None = None +) -> Generator[Path, None, None]: + repo_path = get_model_path(model_id, cache_dir=cache_dir) + if repo_path is None: + return None + snapshots_path = repo_path / "snapshots" + if not snapshots_path.exists(): + return None + yield from list(snapshots_path.glob(glob_pattern)) + + +def list_piper_models() -> Generator[PiperModel, None, None]: + model_weights_files = list_model_files("rhasspy/piper-voices", glob_pattern="**/*.onnx") + for model_weights_file in model_weights_files: + model_config_file = model_weights_file.with_suffix(".json") + yield PiperModel( + id=model_weights_file.name, + created=int(model_weights_file.stat().st_mtime), + path=model_weights_file, + config_path=model_config_file, + ) + + +# NOTE: It's debatable whether caching should be done here or by the caller. Should be revisited. + + +@lru_cache +def read_piper_voices_config() -> dict[str, Any]: + voices_file = next(list_model_files("rhasspy/piper-voices", glob_pattern="**/voices.json"), None) + if voices_file is None: + raise FileNotFoundError("Could not find voices.json file") + return json.loads(voices_file.read_text()) + + +@lru_cache +def get_piper_voice_model_file(voice: str) -> Path: + model_file = next(list_model_files("rhasspy/piper-voices", glob_pattern=f"**/{voice}.onnx"), None) + if model_file is None: + raise FileNotFoundError(f"Could not find model file for '{voice}' voice") + return model_file + + +class PiperVoiceConfigAudio(BaseModel): + sample_rate: int + quality: int + + +class PiperVoiceConfig(BaseModel): + audio: PiperVoiceConfigAudio + # NOTE: there are more fields in the config, but we don't care about them + + +@lru_cache +def read_piper_voice_config(voice: str) -> PiperVoiceConfig: + model_config_file = next(list_model_files("rhasspy/piper-voices", glob_pattern=f"**/{voice}.onnx.json"), None) + if model_config_file is None: + raise FileNotFoundError(f"Could not find config file for '{voice}' voice") + return PiperVoiceConfig.model_validate_json(model_config_file.read_text()) diff --git a/src/faster_whisper_server/main.py b/src/faster_whisper_server/main.py index 4ff9f003..8f8351b4 100644 --- a/src/faster_whisper_server/main.py +++ b/src/faster_whisper_server/main.py @@ -17,6 +17,9 @@ from faster_whisper_server.routers.misc import ( router as misc_router, ) +from faster_whisper_server.routers.speech import ( + router as speech_router, +) from faster_whisper_server.routers.stt import ( router as stt_router, ) @@ -46,6 +49,7 @@ async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]: app.include_router(stt_router) app.include_router(list_models_router) app.include_router(misc_router) + app.include_router(speech_router) if config.allow_origins is not None: app.add_middleware( diff --git a/src/faster_whisper_server/model_manager.py b/src/faster_whisper_server/model_manager.py index 43de3e1c..56bc86f5 100644 --- a/src/faster_whisper_server/model_manager.py +++ b/src/faster_whisper_server/model_manager.py @@ -8,6 +8,9 @@ from typing import TYPE_CHECKING from faster_whisper import WhisperModel +from piper.voice import PiperVoice + +from faster_whisper_server.hf_utils import get_piper_voice_model_file if TYPE_CHECKING: from collections.abc import Callable @@ -21,51 +24,41 @@ # TODO: enable concurrent model downloads -class SelfDisposingWhisperModel: +class SelfDisposingModel[T]: def __init__( - self, - model_id: str, - whisper_config: WhisperConfig, - *, - on_unload: Callable[[str], None] | None = None, + self, model_id: str, load_fn: Callable[[], T], ttl: int, unload_fn: Callable[[str], None] | None = None ) -> None: self.model_id = model_id - self.whisper_config = whisper_config - self.on_unload = on_unload + self.load_fn = load_fn + self.ttl = ttl + self.unload_fn = unload_fn self.ref_count: int = 0 self.rlock = threading.RLock() self.expire_timer: threading.Timer | None = None - self.whisper: WhisperModel | None = None + self.model: T | None = None def unload(self) -> None: with self.rlock: - if self.whisper is None: + if self.model is None: raise ValueError(f"Model {self.model_id} is not loaded. {self.ref_count=}") if self.ref_count > 0: raise ValueError(f"Model {self.model_id} is still in use. {self.ref_count=}") if self.expire_timer: self.expire_timer.cancel() - self.whisper = None + self.model = None # WARN: ~300 MB of memory will still be held by the model. See https://github.com/SYSTRAN/faster-whisper/issues/992 gc.collect() logger.info(f"Model {self.model_id} unloaded") - if self.on_unload is not None: - self.on_unload(self.model_id) + if self.unload_fn is not None: + self.unload_fn(self.model_id) def _load(self) -> None: with self.rlock: - assert self.whisper is None + assert self.model is None logger.debug(f"Loading model {self.model_id}") start = time.perf_counter() - self.whisper = WhisperModel( - self.model_id, - device=self.whisper_config.inference_device, - device_index=self.whisper_config.device_index, - compute_type=self.whisper_config.compute_type, - cpu_threads=self.whisper_config.cpu_threads, - num_workers=self.whisper_config.num_workers, - ) + self.model = self.load_fn() logger.info(f"Model {self.model_id} loaded in {time.perf_counter() - start:.2f}s") def _increment_ref(self) -> None: @@ -81,34 +74,82 @@ def _decrement_ref(self) -> None: self.ref_count -= 1 logger.debug(f"Decremented ref count for {self.model_id}, {self.ref_count=}") if self.ref_count <= 0: - if self.whisper_config.ttl > 0: - logger.info(f"Model {self.model_id} is idle, scheduling offload in {self.whisper_config.ttl}s") - self.expire_timer = threading.Timer(self.whisper_config.ttl, self.unload) + if self.ttl > 0: + logger.info(f"Model {self.model_id} is idle, scheduling offload in {self.ttl}s") + self.expire_timer = threading.Timer(self.ttl, self.unload) self.expire_timer.start() - elif self.whisper_config.ttl == 0: + elif self.ttl == 0: logger.info(f"Model {self.model_id} is idle, unloading immediately") self.unload() else: logger.info(f"Model {self.model_id} is idle, not unloading") - def __enter__(self) -> WhisperModel: + def __enter__(self) -> T: with self.rlock: - if self.whisper is None: + if self.model is None: self._load() self._increment_ref() - assert self.whisper is not None - return self.whisper + assert self.model is not None + return self.model def __exit__(self, *_args) -> None: # noqa: ANN002 self._decrement_ref() -class ModelManager: +class WhisperModelManager: def __init__(self, whisper_config: WhisperConfig) -> None: self.whisper_config = whisper_config - self.loaded_models: OrderedDict[str, SelfDisposingWhisperModel] = OrderedDict() + self.loaded_models: OrderedDict[str, SelfDisposingModel[WhisperModel]] = OrderedDict() + self._lock = threading.Lock() + + def _load_fn(self, model_id: str) -> WhisperModel: + return WhisperModel( + model_id, + device=self.whisper_config.inference_device, + device_index=self.whisper_config.device_index, + compute_type=self.whisper_config.compute_type, + cpu_threads=self.whisper_config.cpu_threads, + num_workers=self.whisper_config.num_workers, + ) + + def _handle_model_unload(self, model_name: str) -> None: + with self._lock: + if model_name in self.loaded_models: + del self.loaded_models[model_name] + + def unload_model(self, model_name: str) -> None: + with self._lock: + model = self.loaded_models.get(model_name) + if model is None: + raise KeyError(f"Model {model_name} not found") + self.loaded_models[model_name].unload() + + def load_model(self, model_name: str) -> SelfDisposingModel[WhisperModel]: + logger.debug(f"Loading model {model_name}") + with self._lock: + logger.debug("Acquired lock") + if model_name in self.loaded_models: + logger.debug(f"{model_name} model already loaded") + return self.loaded_models[model_name] + self.loaded_models[model_name] = SelfDisposingModel[WhisperModel]( + model_name, + load_fn=lambda: self._load_fn(model_name), + ttl=self.whisper_config.ttl, + unload_fn=self._handle_model_unload, + ) + return self.loaded_models[model_name] + + +class PiperModelManager: + def __init__(self, ttl: int) -> None: + self.ttl = ttl + self.loaded_models: OrderedDict[str, SelfDisposingModel[PiperVoice]] = OrderedDict() self._lock = threading.Lock() + def _load_fn(self, model_id: str) -> PiperVoice: + model_path = get_piper_voice_model_file(model_id) + return PiperVoice.load(model_path) + def _handle_model_unload(self, model_name: str) -> None: with self._lock: if model_name in self.loaded_models: @@ -121,14 +162,15 @@ def unload_model(self, model_name: str) -> None: raise KeyError(f"Model {model_name} not found") self.loaded_models[model_name].unload() - def load_model(self, model_name: str) -> SelfDisposingWhisperModel: + def load_model(self, model_name: str) -> SelfDisposingModel[PiperVoice]: with self._lock: if model_name in self.loaded_models: logger.debug(f"{model_name} model already loaded") return self.loaded_models[model_name] - self.loaded_models[model_name] = SelfDisposingWhisperModel( + self.loaded_models[model_name] = SelfDisposingModel[PiperVoice]( model_name, - self.whisper_config, - on_unload=self._handle_model_unload, + load_fn=lambda: self._load_fn(model_name), + ttl=self.ttl, + unload_fn=self._handle_model_unload, ) return self.loaded_models[model_name] diff --git a/src/faster_whisper_server/routers/list_models.py b/src/faster_whisper_server/routers/list_models.py index 86ad7e65..314e5864 100644 --- a/src/faster_whisper_server/routers/list_models.py +++ b/src/faster_whisper_server/routers/list_models.py @@ -13,6 +13,7 @@ ListModelsResponse, Model, ) +from faster_whisper_server.hf_utils import list_whisper_models if TYPE_CHECKING: from huggingface_hub.hf_api import ModelInfo @@ -22,34 +23,13 @@ @router.get("/v1/models") def get_models() -> ListModelsResponse: - models = huggingface_hub.list_models(library="ctranslate2", tags="automatic-speech-recognition", cardData=True) - models = list(models) - models.sort(key=lambda model: model.downloads or -1, reverse=True) - transformed_models: list[Model] = [] - for model in models: - assert model.created_at is not None - assert model.card_data is not None - assert model.card_data.language is None or isinstance(model.card_data.language, str | list) - if model.card_data.language is None: - language = [] - elif isinstance(model.card_data.language, str): - language = [model.card_data.language] - else: - language = model.card_data.language - transformed_model = Model( - id=model.id, - created=int(model.created_at.timestamp()), - object_="model", - owned_by=model.id.split("/")[0], - language=language, - ) - transformed_models.append(transformed_model) - return ListModelsResponse(data=transformed_models) + whisper_models = list(list_whisper_models()) + return ListModelsResponse(data=whisper_models) @router.get("/v1/models/{model_name:path}") -# NOTE: `examples` doesn't work https://github.com/tiangolo/fastapi/discussions/10537 def get_model( + # NOTE: `examples` doesn't work https://github.com/tiangolo/fastapi/discussions/10537 model_name: Annotated[str, Path(example="Systran/faster-distil-whisper-large-v3")], ) -> Model: models = huggingface_hub.list_models( diff --git a/src/faster_whisper_server/routers/speech.py b/src/faster_whisper_server/routers/speech.py new file mode 100644 index 00000000..95b20608 --- /dev/null +++ b/src/faster_whisper_server/routers/speech.py @@ -0,0 +1,164 @@ +from collections.abc import Generator +import io +import logging +import time +from typing import Annotated, Literal, Self + +from fastapi import APIRouter +from fastapi.responses import StreamingResponse +import numpy as np +from piper.voice import PiperVoice +from pydantic import BaseModel, BeforeValidator, Field, ValidationError, model_validator +import soundfile as sf + +from faster_whisper_server.dependencies import PiperModelManagerDependency +from faster_whisper_server.hf_utils import read_piper_voices_config + +DEFAULT_MODEL = "piper" +# https://platform.openai.com/docs/api-reference/audio/createSpeech#audio-createspeech-response_format +DEFAULT_RESPONSE_FORMAT = "mp3" +DEFAULT_VOICE = "en_US-amy-medium" # TODO: make configurable +DEFAULT_VOICE_SAMPLE_RATE = 22050 # NOTE: Dependant on the voice + +# https://platform.openai.com/docs/api-reference/audio/createSpeech#audio-createspeech-model +# https://platform.openai.com/docs/models/tts +OPENAI_SUPPORTED_SPEECH_MODEL = ("tts-1", "tts-1-hd") + +# https://platform.openai.com/docs/api-reference/audio/createSpeech#audio-createspeech-voice +# https://platform.openai.com/docs/guides/text-to-speech/voice-options +OPENAI_SUPPORTED_SPEECH_VOICE_NAMES = ("alloy", "echo", "fable", "onyx", "nova", "shimmer") + +# https://platform.openai.com/docs/guides/text-to-speech/supported-output-formats +type ResponseFormat = Literal["mp3", "flac", "wav", "pcm"] +SUPPORTED_RESPONSE_FORMATS = ("mp3", "flac", "wav", "pcm") +UNSUPORTED_RESPONSE_FORMATS = ("opus", "aac") + +MIN_SAMPLE_RATE = 8000 +MAX_SAMPLE_RATE = 48000 + + +logger = logging.getLogger(__name__) + +router = APIRouter() + + +# aip 'Write a function `resample_audio` which would take in RAW PCM 16-bit signed, little-endian audio data represented as bytes (`audio_bytes`) and resample it (either downsample or upsample) from `sample_rate` to `target_sample_rate` using numpy' # noqa: E501 +def resample_audio(audio_bytes: bytes, sample_rate: int, target_sample_rate: int) -> bytes: + audio_data = np.frombuffer(audio_bytes, dtype=np.int16) + duration = len(audio_data) / sample_rate + target_length = int(duration * target_sample_rate) + resampled_data = np.interp( + np.linspace(0, len(audio_data), target_length, endpoint=False), np.arange(len(audio_data)), audio_data + ) + return resampled_data.astype(np.int16).tobytes() + + +def generate_audio( + piper_tts: PiperVoice, text: str, *, speed: float = 1.0, sample_rate: int | None = None +) -> Generator[bytes, None, None]: + if sample_rate is None: + sample_rate = piper_tts.config.sample_rate + start = time.perf_counter() + for audio_bytes in piper_tts.synthesize_stream_raw(text, length_scale=1.0 / speed): + if sample_rate != piper_tts.config.sample_rate: + audio_bytes = resample_audio(audio_bytes, piper_tts.config.sample_rate, sample_rate) # noqa: PLW2901 + yield audio_bytes + logger.info(f"Generated audio for {len(text)} characters in {time.perf_counter() - start}s") + + +def convert_audio_format( + audio_bytes: bytes, + sample_rate: int, + audio_format: ResponseFormat, + format: str = "RAW", # noqa: A002 + channels: int = 1, + subtype: str = "PCM_16", + endian: str = "LITTLE", +) -> bytes: + # NOTE: the default dtype is float64. Should something else be used? Would that improve performance? + data, _ = sf.read( + io.BytesIO(audio_bytes), + samplerate=sample_rate, + format=format, + channels=channels, + subtype=subtype, + endian=endian, + ) + converted_audio_bytes_buffer = io.BytesIO() + sf.write(converted_audio_bytes_buffer, data, samplerate=sample_rate, format=audio_format) + return converted_audio_bytes_buffer.getvalue() + + +def handle_openai_supported_model_ids(model_id: str) -> str: + if model_id in OPENAI_SUPPORTED_SPEECH_MODEL: + logger.warning(f"{model_id} is not a valid model name. Using '{DEFAULT_MODEL}' instead.") + return DEFAULT_MODEL + return model_id + + +ModelId = Annotated[ + Literal["piper"], + BeforeValidator(handle_openai_supported_model_ids), + Field( + description=f"The ID of the model. The only supported model is '{DEFAULT_MODEL}'.", + examples=[DEFAULT_MODEL], + ), +] + + +def handle_openai_supported_voices(voice: str) -> str: + if voice in OPENAI_SUPPORTED_SPEECH_VOICE_NAMES: + logger.warning(f"{voice} is not a valid voice name. Using '{DEFAULT_VOICE}' instead.") + return DEFAULT_VOICE + return voice + + +Voice = Annotated[str, BeforeValidator(handle_openai_supported_voices)] # TODO: description and examples + + +class CreateSpeechRequestBody(BaseModel): + model: ModelId = DEFAULT_MODEL + input: str = Field( + ..., + description="The text to generate audio for. ", + examples=[ + "A rainbow is an optical phenomenon caused by refraction, internal reflection and dispersion of light in water droplets resulting in a continuous spectrum of light appearing in the sky. The rainbow takes the form of a multicoloured circular arc. Rainbows caused by sunlight always appear in the section of sky directly opposite the Sun. Rainbows can be caused by many forms of airborne water. These include not only rain, but also mist, spray, and airborne dew." # noqa: E501 + ], + ) + voice: Voice = DEFAULT_VOICE + response_format: ResponseFormat = Field( + DEFAULT_RESPONSE_FORMAT, + description=f"The format to audio in. Supported formats are {", ".join(SUPPORTED_RESPONSE_FORMATS)}. {", ".join(UNSUPORTED_RESPONSE_FORMATS)} are not supported", # noqa: E501 + examples=list(SUPPORTED_RESPONSE_FORMATS), + ) + # https://platform.openai.com/docs/api-reference/audio/createSpeech#audio-createspeech-voice + speed: float = Field(1.0, ge=0.25, le=4.0) + """The speed of the generated audio. Select a value from 0.25 to 4.0. 1.0 is the default.""" + sample_rate: int | None = Field(None, ge=MIN_SAMPLE_RATE, le=MAX_SAMPLE_RATE) # TODO: document + + # TODO: move into `Voice` + @model_validator(mode="after") + def verify_voice_is_valid(self) -> Self: + valid_voices = read_piper_voices_config() + if self.voice not in valid_voices: + raise ValidationError(f"Voice '{self.voice}' is not supported. Supported voices: {valid_voices.keys()}") + return self + + +# https://platform.openai.com/docs/api-reference/audio/createSpeech +@router.post("/v1/audio/speech") +def synthesize( + piper_model_manager: PiperModelManagerDependency, + body: CreateSpeechRequestBody, +) -> StreamingResponse: + with piper_model_manager.load_model(body.voice) as piper_tts: + audio_generator = generate_audio(piper_tts, body.input, speed=body.speed, sample_rate=body.sample_rate) + if body.response_format != "pcm": + audio_generator = ( + convert_audio_format( + audio_bytes, body.sample_rate or piper_tts.config.sample_rate, body.response_format + ) + for audio_bytes in audio_generator + ) + + return StreamingResponse(audio_generator, media_type=f"audio/{body.response_format}") diff --git a/tests/speech_test.py b/tests/speech_test.py new file mode 100644 index 00000000..ca3fb61e --- /dev/null +++ b/tests/speech_test.py @@ -0,0 +1,147 @@ +import io + +from faster_whisper_server.routers.speech import ( + DEFAULT_MODEL, + DEFAULT_RESPONSE_FORMAT, + DEFAULT_VOICE, + SUPPORTED_RESPONSE_FORMATS, + ResponseFormat, +) +from openai import APIConnectionError, AsyncOpenAI, UnprocessableEntityError +import pytest +import soundfile as sf + +DEFAULT_INPUT = "Hello, world!" + + +@pytest.mark.asyncio() +@pytest.mark.parametrize("response_format", SUPPORTED_RESPONSE_FORMATS) +async def test_create_speech_formats(openai_client: AsyncOpenAI, response_format: ResponseFormat) -> None: + await openai_client.audio.speech.create( + model=DEFAULT_MODEL, + voice=DEFAULT_VOICE, # type: ignore # noqa: PGH003 + input=DEFAULT_INPUT, + response_format=response_format, + ) + + +GOOD_MODEL_VOICE_PAIRS: list[tuple[str, str]] = [ + ("tts-1", "alloy"), # OpenAI and OpenAI + ("tts-1-hd", "echo"), # OpenAI and OpenAI + ("tts-1", DEFAULT_VOICE), # OpenAI and Piper + (DEFAULT_MODEL, "echo"), # Piper and OpenAI + (DEFAULT_MODEL, DEFAULT_VOICE), # Piper and Piper +] + + +@pytest.mark.asyncio() +@pytest.mark.parametrize(("model", "voice"), GOOD_MODEL_VOICE_PAIRS) +async def test_create_speech_good_model_voice_pair(openai_client: AsyncOpenAI, model: str, voice: str) -> None: + await openai_client.audio.speech.create( + model=model, + voice=voice, # type: ignore # noqa: PGH003 + input=DEFAULT_INPUT, + response_format=DEFAULT_RESPONSE_FORMAT, + ) + + +BAD_MODEL_VOICE_PAIRS: list[tuple[str, str]] = [ + ("tts-1", "invalid"), # OpenAI and invalid + ("invalid", "echo"), # Invalid and OpenAI + (DEFAULT_MODEL, "invalid"), # Piper and invalid + ("invalid", DEFAULT_VOICE), # Invalid and Piper + ("invalid", "invalid"), # Invalid and invalid +] + + +@pytest.mark.asyncio() +@pytest.mark.parametrize(("model", "voice"), BAD_MODEL_VOICE_PAIRS) +async def test_create_speech_bad_model_voice_pair(openai_client: AsyncOpenAI, model: str, voice: str) -> None: + # NOTE: not sure why `APIConnectionError` is sometimes raised + with pytest.raises((UnprocessableEntityError, APIConnectionError)): + await openai_client.audio.speech.create( + model=model, + voice=voice, # type: ignore # noqa: PGH003 + input=DEFAULT_INPUT, + response_format=DEFAULT_RESPONSE_FORMAT, + ) + + +SUPPORTED_SPEEDS = [0.25, 0.5, 1.0, 2.0, 4.0] + + +@pytest.mark.asyncio() +async def test_create_speech_with_varying_speed(openai_client: AsyncOpenAI) -> None: + previous_size: int | None = None + for speed in SUPPORTED_SPEEDS: + res = await openai_client.audio.speech.create( + model=DEFAULT_MODEL, + voice=DEFAULT_VOICE, # type: ignore # noqa: PGH003 + input=DEFAULT_INPUT, + response_format="pcm", + speed=speed, + ) + audio_bytes = res.read() + if previous_size is not None: + assert len(audio_bytes) * 1.5 < previous_size # TODO: document magic number + previous_size = len(audio_bytes) + + +UNSUPPORTED_SPEEDS = [0.1, 4.1] + + +@pytest.mark.asyncio() +@pytest.mark.parametrize("speed", UNSUPPORTED_SPEEDS) +async def test_create_speech_with_unsupported_speed(openai_client: AsyncOpenAI, speed: float) -> None: + with pytest.raises(UnprocessableEntityError): + await openai_client.audio.speech.create( + model=DEFAULT_MODEL, + voice=DEFAULT_VOICE, # type: ignore # noqa: PGH003 + input=DEFAULT_INPUT, + response_format="pcm", + speed=speed, + ) + + +VALID_SAMPLE_RATES = [16000, 22050, 24000, 48000] + + +@pytest.mark.asyncio() +@pytest.mark.parametrize("sample_rate", VALID_SAMPLE_RATES) +async def test_speech_valid_resample(openai_client: AsyncOpenAI, sample_rate: int) -> None: + res = await openai_client.audio.speech.create( + model=DEFAULT_MODEL, + voice=DEFAULT_VOICE, # type: ignore # noqa: PGH003 + input=DEFAULT_INPUT, + response_format="wav", + extra_body={"sample_rate": sample_rate}, + ) + _, actual_sample_rate = sf.read(io.BytesIO(res.content)) + assert actual_sample_rate == sample_rate + + +INVALID_SAMPLE_RATES = [7999, 48001] + + +@pytest.mark.asyncio() +@pytest.mark.parametrize("sample_rate", INVALID_SAMPLE_RATES) +async def test_speech_invalid_resample(openai_client: AsyncOpenAI, sample_rate: int) -> None: + with pytest.raises(UnprocessableEntityError): + await openai_client.audio.speech.create( + model=DEFAULT_MODEL, + voice=DEFAULT_VOICE, # type: ignore # noqa: PGH003 + input=DEFAULT_INPUT, + response_format="wav", + extra_body={"sample_rate": sample_rate}, + ) + + +# TODO: implement the following test + +# NUMBER_OF_MODELS = 1 +# NUMBER_OF_VOICES = 124 +# +# +# @pytest.mark.asyncio +# async def test_list_tts_models(openai_client: AsyncOpenAI) -> None: +# raise NotImplementedError