feat: filter GET /v1/model by task type
Fedir Zadniprovskyi authored and fedirz committed Mar 2, 2025
1 parent 66f9635 commit 337679d
Showing 6 changed files with 52 additions and 51 deletions.
31 changes: 5 additions & 26 deletions src/speaches/api_types.py
@@ -4,7 +4,7 @@
 from typing import Literal

 import faster_whisper.transcribe
-from pydantic import BaseModel, ConfigDict, Field, computed_field
+from pydantic import BaseModel, Field, computed_field

 from speaches.text_utils import segments_to_text

@@ -130,6 +130,9 @@ class ListModelsResponse(BaseModel):
     object: Literal["list"] = "list"


+ModelTask = Literal["automatic-speech-recognition", "text-to-speech"] # TODO: add "voice-activity-detection"
+
+
 # https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L11146
 class Model(BaseModel):
     id: str
@@ -143,31 +146,7 @@ class Model(BaseModel):
     language: list[str] | None = None
     """List of ISO 639-3 supported by the model. It's possible that the list will be empty. This field is not a part of the OpenAI API spec and is added for convenience."""

-    model_config = ConfigDict(
-        populate_by_name=True,
-        json_schema_extra={
-            "examples": [
-                {
-                    "id": "Systran/faster-whisper-large-v3",
-                    "created": 1700732060,
-                    "object": "model",
-                    "owned_by": "Systran",
-                },
-                {
-                    "id": "Systran/faster-distil-whisper-large-v3",
-                    "created": 1711378296,
-                    "object": "model",
-                    "owned_by": "Systran",
-                },
-                {
-                    "id": "bofenghuang/whisper-large-v2-cv11-french-ct2",
-                    "created": 1687968011,
-                    "object": "model",
-                    "owned_by": "bofenghuang",
-                },
-            ]
-        },
-    )
+    task: ModelTask # TODO: make a list?


 # https://github.com/openai/openai-openapi/blob/master/openapi.yaml#L10909
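For orientation, a minimal sketch of what a Model entry looks like once it carries the new task field. The field values below are illustrative, not taken from the commit, and the Kokoro id is hypothetical; the shape follows the schema in the diff above.

from speaches.api_types import Model

# Illustrative only: made-up values, but the fields match the new schema.
asr_entry = Model(
    id="Systran/faster-whisper-large-v3",
    created=1700732060,
    owned_by="Systran",
    language=["eng", "fra"],  # ISO 639-3 codes, per the docstring above
    task="automatic-speech-recognition",
)
tts_entry = Model(
    id="hexgrad/Kokoro-82M",  # hypothetical id; the real MODEL_ID lives in kokoro_utils
    owned_by="hexgrad",
    task="text-to-speech",
)
print(asr_entry.model_dump_json())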
31 changes: 29 additions & 2 deletions src/speaches/kokoro_utils.py
@@ -9,7 +9,7 @@
 from kokoro_onnx import Kokoro
 import numpy as np

-from speaches.api_types import Model
+from speaches.api_types import Model, Voice
 from speaches.audio import resample_audio
 from speaches.hf_utils import list_model_files

@@ -40,7 +40,7 @@


 def get_kokoro_models() -> list[Model]:
-    model = Model(id=MODEL_ID, owned_by=MODEL_ID.split("/")[0])
+    model = Model(id=MODEL_ID, owned_by=MODEL_ID.split("/")[0], task="text-to-speech")
     return [model]


@@ -66,6 +66,33 @@ def download_kokoro_model() -> None:
     voices_path.write_bytes(res.content)


+def list_kokoro_voice_names() -> list[str]:
+    model_path = get_kokoro_model_path()
+    voices_path = model_path.parent / "voices.bin"
+    voices_npz = np.load(voices_path)
+    return list(voices_npz.keys())
+
+
+def list_kokoro_voices() -> list[Voice]:
+    model_path = get_kokoro_model_path()
+    voices_path = model_path.parent / "voices.bin"
+    voices_npz = np.load(voices_path)
+    voice_names: list[str] = list(voices_npz.keys())
+
+    voices = [
+        Voice(
+            model_id=MODEL_ID,
+            voice_id=voice_name,
+            created=int(voices_path.stat().st_mtime),
+            owned_by=MODEL_ID.split("/")[0],
+            sample_rate=24000,
+            model_path=model_path, # HACK: not applicable for Kokoro
+        )
+        for voice_name in voice_names
+    ]
+    return voices
+
+
 async def generate_audio(
     kokoro_tts: Kokoro,
     text: str,
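A rough usage sketch of the two new helpers. It assumes the Kokoro model and its voices.bin have already been fetched via download_kokoro_model(); otherwise the np.load call above would fail.

from speaches import kokoro_utils

# Raw voice names as stored in voices.bin, e.g. for request validation.
names = kokoro_utils.list_kokoro_voice_names()
print(len(names), names[:3])

# Full Voice records, as served by the voice-listing endpoint.
for voice in kokoro_utils.list_kokoro_voices():
    print(voice.voice_id, voice.sample_rate, voice.owned_by)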
6 changes: 3 additions & 3 deletions src/speaches/piper_utils.py
@@ -3,7 +3,6 @@
 from functools import lru_cache
 import json
 import logging
-from pathlib import Path
 import time
 from typing import TYPE_CHECKING, Any, Literal

@@ -15,6 +14,7 @@

 if TYPE_CHECKING:
     from collections.abc import Generator
+    from pathlib import Path

     from piper.voice import PiperVoice

@@ -31,11 +31,11 @@


 def get_piper_models() -> list[Model]:
-    model = Model(id=MODEL_ID, owned_by=MODEL_ID.split("/")[0])
+    model = Model(id=MODEL_ID, owned_by=MODEL_ID.split("/")[0], task="text-to-speech")
     return [model]


-def list_piper_models() -> Generator[Voice, None, None]:
+def list_piper_voices() -> Generator[Voice, None, None]:
     model_weights_files = list_model_files(MODEL_ID, glob_pattern="**/*.onnx")
     for model_weights_file in model_weights_files:
         yield Voice(
17 changes: 10 additions & 7 deletions src/speaches/routers/models.py
@@ -9,6 +9,7 @@
 from speaches.api_types import (
     ListModelsResponse,
     Model,
+    ModelTask,
 )
 from speaches.model_aliases import ModelId
 from speaches.whisper_utils import list_local_whisper_models, list_whisper_models
@@ -19,14 +20,16 @@


 @router.get("/v1/models")
-def get_models() -> ListModelsResponse:
+def get_models(task: ModelTask | None = None) -> ListModelsResponse:
     models: list[Model] = []
-    models.extend(kokoro_utils.get_kokoro_models())
-    models.extend(piper_utils.get_piper_models())
-    if os.getenv("HF_HUB_OFFLINE") is not None:
-        models.extend(list(list_local_whisper_models()))
-    else:
-        models.extend(list(list_whisper_models()))
+    if task is None or task == "text-to-speech":
+        models.extend(kokoro_utils.get_kokoro_models())
+        models.extend(piper_utils.get_piper_models())
+    if task is None or task == "automatic-speech-recognition":
+        if os.getenv("HF_HUB_OFFLINE") is not None:
+            models.extend(list(list_local_whisper_models()))
+        else:
+            models.extend(list(list_whisper_models()))
     return ListModelsResponse(data=models)


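With this change GET /v1/models accepts an optional task query parameter, so clients can ask for only speech-to-text or only text-to-speech models. A minimal client-side sketch; the base URL is an assumption and any HTTP client works the same way.

import httpx

client = httpx.Client(base_url="http://localhost:8000")  # assumed local speaches instance

all_models = client.get("/v1/models").json()["data"]
tts_models = client.get("/v1/models", params={"task": "text-to-speech"}).json()["data"]
asr_models = client.get("/v1/models", params={"task": "automatic-speech-recognition"}).json()["data"]

# Values outside the ModelTask literal are rejected by FastAPI's validation with a 422.
print(len(all_models), len(tts_models), len(asr_models))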
16 changes: 3 additions & 13 deletions src/speaches/routers/speech.py
@@ -86,7 +86,7 @@ class CreateSpeechRequestBody(BaseModel):
     @model_validator(mode="after")
     def verify_voice_is_valid(self) -> Self:
         if self.model == kokoro_utils.MODEL_ID:
-            assert self.voice in kokoro_utils.VOICE_IDS
+            assert self.voice in kokoro_utils.list_kokoro_voice_names()
         elif self.model == piper_utils.MODEL_ID:
             assert self.voice in piper_utils.read_piper_voices_config()
         return self
@@ -147,18 +147,8 @@ async def synthesize(
 def list_voices(model_id: ModelId | None = None) -> list[Voice]:
     voices: list[Voice] = []
     if model_id == kokoro_utils.MODEL_ID or model_id is None:
-        kokoro_model_path = kokoro_utils.get_kokoro_model_path()
-        for voice_id in kokoro_utils.VOICE_IDS:
-            voice = Voice(
-                created=0,
-                model_path=kokoro_model_path,
-                model_id=kokoro_utils.MODEL_ID,
-                owned_by=kokoro_utils.MODEL_ID.split("/")[0],
-                sample_rate=kokoro_utils.SAMPLE_RATE,
-                voice_id=voice_id,
-            )
-            voices.append(voice)
+        voices.extend(list(kokoro_utils.list_kokoro_voices()))
     elif model_id == piper_utils.MODEL_ID or model_id is None:
-        voices.extend(list(piper_utils.list_piper_models()))
+        voices.extend(list(piper_utils.list_piper_voices()))

     return voices
2 changes: 2 additions & 0 deletions src/speaches/whisper_utils.py
@@ -32,6 +32,7 @@ def list_whisper_models() -> Generator[Model, None, None]:
             created=int(model.created_at.timestamp()),
             owned_by=model.id.split("/")[0],
             language=language,
+            task=TASK_NAME_TAG,
         )
         yield transformed_model

@@ -67,5 +68,6 @@ def list_local_whisper_models() -> Generator[Model, None, None]:
             created=int(model.last_modified),
             owned_by=model.repo_id.split("/")[0],
             language=language,
+            task=TASK_NAME_TAG,
         )
         yield transformed_model
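TASK_NAME_TAG itself is not shown in this diff; presumably it is an existing module-level constant in whisper_utils.py along the lines of the sketch below. This is an assumption, not part of the commit.

# Assumed, not shown in the diff: the tag applied to all Whisper models,
# matching the "automatic-speech-recognition" member of the ModelTask literal.
TASK_NAME_TAG = "automatic-speech-recognition"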
