-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcli.py
237 lines (197 loc) · 14.5 KB
/
cli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import argparse
import os
import pathlib
from urllib.parse import urlparse
import warnings
import numpy as np
import threading
import torch
from app import VadOptions, WhisperTranscriber
from src.config import VAD_INITIAL_PROMPT_MODE_VALUES, ApplicationConfig, VadInitialPromptMode
from src.diarization.diarization import Diarization
from src.download import download_url
from src.languages import get_language_names
from src.utils import optional_float, optional_int, str2bool
from src.whisper.whisperFactory import create_whisper_container
def cli():
    """Command-line entry point: parse arguments and transcribe every input.

    Defaults for the options come from ``ApplicationConfig.create_default()``;
    the ``WHISPER_IMPLEMENTATION`` environment variable overrides the default
    Whisper implementation. Each ``audio`` argument is either a local path or
    a URL (detected with ``uri_validator`` and fetched with ``download_url``).
    Every source is processed once per (model, task) pair and the result is
    written to the output directory through ``transcriber.write_result``.

    Model/task pairing:
      * ``--task both`` runs "transcribe" and "translate" for every model;
      * otherwise ``--model`` and ``--task`` must have the same number of
        values and are paired positionally.

    Any parsed option that is not consumed explicitly below is forwarded
    unchanged as a keyword argument to ``transcriber.transcribe_file``.
    """
    app_config = ApplicationConfig.create_default()
    whisper_models = app_config.get_model_names()

    # For the CLI, we fallback to saving the output to the current directory
    output_dir = app_config.output_dir if app_config.output_dir is not None else "."

    # Environment variable overrides
    default_whisper_implementation = os.environ.get("WHISPER_IMPLEMENTATION", app_config.whisper_implementation)

    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("audio", nargs="+", type=str,
                        help="audio file(s) to transcribe")
    parser.add_argument("--model", nargs="+", default=app_config.default_model_name, choices=whisper_models,
                        help="name of the Whisper model to use")  # medium
    parser.add_argument("--model_dir", type=str, default=app_config.model_dir,
                        help="the path to save model files; uses ~/.cache/whisper by default")
    parser.add_argument("--device", default=app_config.device,
                        help="device to use for PyTorch inference")
    parser.add_argument("--output_dir", "-o", type=str, default=output_dir,
                        help="directory to save the outputs")
    parser.add_argument("--verbose", type=str2bool, default=app_config.verbose,
                        help="whether to print out the progress and debug messages")
    parser.add_argument("--whisper_implementation", type=str, default=default_whisper_implementation, choices=["whisper", "faster-whisper"],
                        help="the Whisper implementation to use")

    parser.add_argument("--task", nargs="+", type=str, default=app_config.task, choices=["transcribe", "translate", "both"],
                        help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
    parser.add_argument("--language", type=str, default=app_config.language, choices=sorted(get_language_names()),
                        help="language spoken in the audio, specify None to perform language detection")

    parser.add_argument("--vad", type=str, default=app_config.default_vad, choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"],
                        help="The voice activity detection algorithm to use")  # silero-vad
    parser.add_argument("--vad_initial_prompt_mode", type=str, default=app_config.vad_initial_prompt_mode, choices=VAD_INITIAL_PROMPT_MODE_VALUES,
                        help="Whether or not to prepend the initial prompt to each VAD segment (prepend_all_segments), or just the first segment (prepend_first_segment)")  # prepend_first_segment
    parser.add_argument("--vad_merge_window", type=optional_float, default=app_config.vad_merge_window,
                        help="The window size (in seconds) to merge voice segments")
    parser.add_argument("--vad_max_merge_size", type=optional_float, default=app_config.vad_max_merge_size,
                        help="The maximum size (in seconds) of a voice segment")
    parser.add_argument("--vad_padding", type=optional_float, default=app_config.vad_padding,
                        help="The padding (in seconds) to add to each voice segment")
    parser.add_argument("--vad_prompt_window", type=optional_float, default=app_config.vad_prompt_window,
                        help="The window size of the prompt to pass to Whisper")
    parser.add_argument("--vad_cpu_cores", type=int, default=app_config.vad_cpu_cores,
                        help="The number of CPU cores to use for VAD pre-processing.")  # 1
    parser.add_argument("--vad_parallel_devices", type=str, default=app_config.vad_parallel_devices,
                        help="A commma delimited list of CUDA devices to use for parallel processing. If None, disable parallel processing.")  # ""
    # str2bool, not bool: bool("False") is True, so type=bool would turn any
    # non-empty value (including "False") into True.
    parser.add_argument("--auto_parallel", type=str2bool, default=app_config.auto_parallel,
                        help="True to use all available GPUs and CPU cores for processing. Use vad_cpu_cores/vad_parallel_devices to specify the number of CPU cores/GPUs to use.")  # False

    parser.add_argument("--temperature", type=float, default=app_config.temperature,
                        help="temperature to use for sampling")
    parser.add_argument("--best_of", type=optional_int, default=app_config.best_of,
                        help="number of candidates when sampling with non-zero temperature")
    parser.add_argument("--beam_size", type=optional_int, default=app_config.beam_size,
                        help="number of beams in beam search, only applicable when temperature is zero")
    parser.add_argument("--patience", type=float, default=app_config.patience,
                        help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
    parser.add_argument("--length_penalty", type=float, default=app_config.length_penalty,
                        help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple lengt normalization by default")

    parser.add_argument("--suppress_tokens", type=str, default=app_config.suppress_tokens,
                        help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
    parser.add_argument("--initial_prompt", type=str, default=app_config.initial_prompt,
                        help="optional text to provide as a prompt for the first window.")
    parser.add_argument("--condition_on_previous_text", type=str2bool, default=app_config.condition_on_previous_text,
                        help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
    # parser.add_argument("--fp16", type=str2bool, default=app_config.fp16, \
    # help="whether to perform inference in fp16; True by default")
    parser.add_argument("--compute_type", type=str, default=app_config.compute_type, choices=["default", "auto", "int8", "int8_float16", "int16", "float16", "float32"],
                        help="the compute type to use for inference")

    parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=app_config.temperature_increment_on_fallback,
                        help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=app_config.compression_ratio_threshold,
                        help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
    parser.add_argument("--logprob_threshold", type=optional_float, default=app_config.logprob_threshold,
                        help="if the average log probability is lower than this value, treat the decoding as failed")
    parser.add_argument("--no_speech_threshold", type=optional_float, default=app_config.no_speech_threshold,
                        help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")

    parser.add_argument("--word_timestamps", type=str2bool, default=app_config.word_timestamps,
                        help="(experimental) extract word-level timestamps and refine the results based on them")
    parser.add_argument("--prepend_punctuations", type=str, default=app_config.prepend_punctuations,
                        help="if word_timestamps is True, merge these punctuation symbols with the next word")
    parser.add_argument("--append_punctuations", type=str, default=app_config.append_punctuations,
                        help="if word_timestamps is True, merge these punctuation symbols with the previous word")
    parser.add_argument("--highlight_words", type=str2bool, default=app_config.highlight_words,
                        help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
    parser.add_argument("--threads", type=optional_int, default=0,
                        help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")

    # Diarization
    parser.add_argument('--auth_token', type=str, default=app_config.auth_token, help='HuggingFace API Token (optional)')
    parser.add_argument("--diarization", type=str2bool, default=app_config.diarization,
                        help="whether to perform speaker diarization")
    parser.add_argument("--diarization_num_speakers", type=int, default=app_config.diarization_speakers, help="Number of speakers")
    parser.add_argument("--diarization_min_speakers", type=int, default=app_config.diarization_min_speakers, help="Minimum number of speakers")
    parser.add_argument("--diarization_max_speakers", type=int, default=app_config.diarization_max_speakers, help="Maximum number of speakers")

    args = parser.parse_args().__dict__

    # With nargs="+", argparse yields a list only when the option is given on
    # the command line; an unparsed default is returned as-is. A plain-string
    # default would otherwise be indexed/iterated character by character.
    model_names = args.pop("model")
    if isinstance(model_names, str):
        model_names = [model_names]

    model_dir = args.pop("model_dir")
    output_dir = args.pop("output_dir")
    device = args.pop("device")
    os.makedirs(output_dir, exist_ok=True)

    # optional_int lets the user pass "None"; guard before comparing.
    threads = args.pop("threads")
    if threads is not None and threads > 0:
        torch.set_num_threads(threads)

    whisper_implementation = args.pop("whisper_implementation")
    print(f"Using {whisper_implementation} for Whisper")

    # Only the first model is checked for the English-only suffix.
    if model_names[0].endswith(".en") and args["language"] not in {"en", "English"}:
        warnings.warn(f"{model_names[0]} is an English-only model but receipted '{args['language']}'; using English instead.")
        args["language"] = "en"

    # Build the temperature schedule: a ladder from the base temperature up
    # to 1.0 when a fallback increment is set, otherwise a single value.
    temperature = args.pop("temperature")
    temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
    if temperature_increment_on_fallback is not None:
        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
    else:
        temperature = [temperature]

    vad = args.pop("vad")
    vad_initial_prompt_mode = args.pop("vad_initial_prompt_mode")
    vad_merge_window = args.pop("vad_merge_window")
    vad_max_merge_size = args.pop("vad_max_merge_size")
    vad_padding = args.pop("vad_padding")
    vad_prompt_window = args.pop("vad_prompt_window")
    vad_cpu_cores = args.pop("vad_cpu_cores")
    auto_parallel = args.pop("auto_parallel")

    compute_type = args.pop("compute_type")

    # Same nargs="+" default normalisation as for --model above.
    tasks = args.pop("task")
    if isinstance(tasks, str):
        tasks = [tasks]

    if "both" in tasks:
        if len(tasks) > 1:
            print("both is not supported with nargs, assuming only both")
        tasks = ["both"]

    # Expand the model/task options into the concrete list of runs.
    model_task_list = []
    if tasks[0] == "both":
        for model_name in model_names:
            model_task_list.append({"model": model_name, "task": "transcribe"})
            model_task_list.append({"model": model_name, "task": "translate"})
    elif len(model_names) == len(tasks):
        for model_name, task in zip(model_names, tasks):
            model_task_list.append({"model": model_name, "task": task})
    else:
        print("bad model task combination")
        return

    # Create one container per distinct model name, even if it is listed
    # several times.
    model_cache = {}
    for model_name in model_names:
        if model_name in model_cache:
            continue
        model_cache[model_name] = create_whisper_container(whisper_implementation=whisper_implementation, model_name=model_name,
                                                           device=device, compute_type=compute_type, download_root=model_dir, models=app_config.models)

    highlight_words = args.pop("highlight_words")

    auth_token = args.pop("auth_token")
    diarization = args.pop("diarization")
    num_speakers = args.pop("diarization_num_speakers")
    min_speakers = args.pop("diarization_min_speakers")
    max_speakers = args.pop("diarization_max_speakers")

    transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
    # From here on the transcriber must be closed even if a download or a
    # transcription raises, hence the try/finally.
    try:
        transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
        transcriber.set_auto_parallel(auto_parallel)

        if diarization:
            transcriber.set_diarization(auth_token=auth_token, enable_daemon_process=False, num_speakers=num_speakers, min_speakers=min_speakers, max_speakers=max_speakers)

        if transcriber._has_parallel_devices():
            print("Using parallel devices:", transcriber.parallel_device_list)

        for audio_path in args.pop("audio"):
            sources = []

            # Detect URL and download the audio
            if uri_validator(audio_path):
                # Download from YouTube/URL directly
                for source_path in download_url(audio_path, maxDuration=-1, destinationDirectory=output_dir, playlistItems=None):
                    source_name = os.path.basename(source_path)
                    sources.append({"path": source_path, "name": source_name})
            else:
                sources.append({"path": audio_path, "name": os.path.basename(audio_path)})

            for source in sources:
                source_path = source["path"]
                source_name = source["name"]

                for model_task in model_task_list:
                    model = model_cache[model_task["model"]]
                    vadOptions = VadOptions(vad, vad_merge_window, vad_max_merge_size, vad_padding, vad_prompt_window,
                                            VadInitialPromptMode.from_string(vad_initial_prompt_mode))

                    # Copy the remaining options so each run gets its own
                    # "task" entry without mutating the shared dict.
                    taskArgs = dict(args)
                    taskArgs["task"] = model_task["task"]

                    result = transcriber.transcribe_file(model, source_path, temperature=temperature, vadOptions=vadOptions, **taskArgs)

                    transcriber.write_result(result, source_name, output_dir, highlight_words, extras="_" + model.model_name + "_" + taskArgs["task"])
    finally:
        transcriber.close()
def uri_validator(x):
    """Return True when *x* parses as a URL with both a scheme and a netloc.

    Plain file names and local paths (no scheme or no network location)
    yield False, as does any value that ``urlparse`` rejects.
    """
    try:
        result = urlparse(x)
    except (ValueError, AttributeError, TypeError):
        # ValueError: malformed URL (e.g. an unterminated IPv6 literal);
        # AttributeError/TypeError: a non-string value was passed in.
        # A bare ``except`` would also swallow KeyboardInterrupt/SystemExit.
        return False
    return all([result.scheme, result.netloc])
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    cli()