-
Notifications
You must be signed in to change notification settings - Fork 141
/
Copy pathlive_captions_web.py
57 lines (48 loc) · 1.66 KB
/
live_captions_web.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from fastrtc import Stream, AdditionalOutputs, audio_to_float32, ReplyOnPause
from functools import lru_cache
import gradio as gr
from typing import Generator, Literal
from numpy.typing import NDArray
import numpy as np
from moonshine_onnx import MoonshineOnnxModel, load_tokenizer
@lru_cache(maxsize=None)
def load_moonshine(
    model_name: Literal["moonshine/base", "moonshine/tiny"],
) -> MoonshineOnnxModel:
    """Return the Moonshine ONNX model for ``model_name``, built at most once.

    The ``lru_cache`` wrapper memoises on ``model_name``. Each distinct name is
    constructed the first time it is requested, and later calls reuse that
    same instance for the lifetime of the process.
    """
    model = MoonshineOnnxModel(model_name=model_name)
    return model
# Tokenizer created once at import time and shared by every stt() call, which
# uses its decode_batch() to turn generated token ids back into caption text.
tokenizer = load_tokenizer()
def stt(
    audio: tuple[int, NDArray[np.int16 | np.float32]],
    model_name: Literal["moonshine/base", "moonshine/tiny"],
) -> Generator[AdditionalOutputs, None, None]:
    """Transcribe one chunk of audio with Moonshine and emit it as a caption.

    Args:
        audio: ``(sample_rate, samples)`` pair. The sample rate is not used
            here. The samples are presumably already at the rate requested
            from ``ReplyOnPause`` (16000 in this file); confirm if that changes.
        model_name: Which Moonshine checkpoint to run (fetched via the cached
            ``load_moonshine``).

    Yields:
        Exactly one ``AdditionalOutputs`` wrapping the decoded transcription.
    """
    model = load_moonshine(model_name)
    _sample_rate, samples = audio  # type: ignore
    if samples.dtype == np.int16:
        # int16 PCM is converted to float32 before inference.
        # NOTE(review): the whole (rate, array) tuple is passed, as originally
        # written; newer fastrtc releases may prefer the bare array. Confirm
        # against the installed fastrtc version.
        samples = audio_to_float32(audio)
    # A flat 1-D signal gets a leading axis of length 1, so generate() sees a
    # batch of one. This is presumably (batch, samples); confirm against
    # moonshine_onnx.
    batch = samples if samples.ndim != 1 else samples.reshape(1, -1)
    token_ids = model.generate(batch)
    first_transcript = tokenizer.decode_batch(token_ids)[0]
    yield AdditionalOutputs(first_transcript)
def _append_caption(prev: str | None, current: str) -> str:
    """Append the newest transcription to the accumulated caption text.

    Each utterance goes on its own line. Surrounding whitespace is trimmed, so
    the first caption does not start with a blank line.

    Args:
        prev: Text currently in the "Captions" textbox. ``None`` is treated as
            empty instead of raising ``TypeError`` on concatenation.
            NOTE(review): confirm what the component sends before the first
            caption.
        current: Transcription yielded by ``stt`` for the latest utterance.

    Returns:
        The combined caption text.
    """
    return ((prev or "") + "\n" + current).strip()


# Audio is requested at 16000 Hz. ReplyOnPause presumably calls stt once per
# detected pause in speech (confirm against fastrtc docs).
stream = Stream(
    ReplyOnPause(stt, input_sample_rate=16000),
    modality="audio",
    mode="send",
    # Cosmetic UI options. "icon" is a relative path, presumably resolved from
    # the working directory; confirm when deploying.
    ui_args={
        "title": "Live Captions by Moonshine",
        "icon": "default-favicon.ico",
        "icon_button_color": "#5c5c5c",
        "pulse_color": "#a7c6fc",
        "icon_radius": 0,
    },
    # The selected value is presumably forwarded as stt's model_name argument.
    additional_inputs=[
        gr.Radio(
            choices=["moonshine/base", "moonshine/tiny"],
            value="moonshine/base",
            label="Model",
        )
    ],
    additional_outputs=[gr.Textbox(label="Captions")],
    additional_outputs_handler=_append_caption,
)

if __name__ == "__main__":
    stream.ui.launch()