-
Notifications
You must be signed in to change notification settings - Fork 685
/
Copy pathkaldi_utils.py
38 lines (31 loc) · 1.27 KB
/
kaldi_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import subprocess
import torch
def convert_args(**kwargs):
    """Translate keyword arguments into Kaldi-style ``--option=value`` flags.

    Each keyword becomes one flag:
    - Underscores in the option name are replaced by dashes.
    - ``sample_rate`` is renamed to ``sample_frequency``.
    - Values that compare equal to ``True`` or ``False`` are rendered in
      lowercase (``true`` / ``false``). Ints 0 and 1 also compare equal, but
      lowercasing their text is a no-op.
    - All other values are rendered with ``str()``.

    Returns:
        list of str: One ``--option=value`` string per keyword, in call order.
    """

    def _flag(name, raw):
        # Kaldi names this option "sample-frequency" rather than "sample-rate".
        if name == "sample_rate":
            name = "sample_frequency"
        if raw in (True, False):
            rendered = str(raw).lower()
        else:
            rendered = str(raw)
        return f"--{name.replace('_', '-')}={rendered}"

    return [_flag(name, raw) for name, raw in kwargs.items()]
def run_kaldi(command, input_type, input_value):
    """Run provided Kaldi command on one input and return the resulting tensor.

    The input is written to the command's stdin under the utterance key
    ``"foo"``, either as a one-entry Kaldi archive (``'ark'``) or as a single
    ``"<key> <path>"`` script line (``'scp'``). The command's stdout is parsed
    as a matrix archive, and the matrix stored under the same key is returned.

    Args:
        command (list of str): The command with arguments. It must read its
            input from stdin and write a matrix archive to stdout.
        input_type (str): 'ark' or 'scp'.
        input_value (Tensor for 'ark', string for 'scp'): The input to pass.
            Must be a path to an audio file for 'scp'.

    Returns:
        torch.Tensor: The matrix the command produced for the utterance.

    Raises:
        NotImplementedError: If ``input_type`` is neither 'ark' nor 'scp'.
        subprocess.CalledProcessError: If the command exits with a non-zero
            status.
        KeyError: If the command succeeded but wrote no matrix for the key.
    """
    # Validate before importing kaldi_io or spawning anything. A bad argument
    # must not leave an orphaned child process blocked on an open stdin.
    if input_type not in ("ark", "scp"):
        raise NotImplementedError(f"Unexpected type: {input_type!r}")

    import kaldi_io

    key = "foo"
    # The context manager closes both pipes and reaps the child even if
    # writing the input or parsing the output raises.
    with subprocess.Popen(
        command, stdin=subprocess.PIPE, stdout=subprocess.PIPE
    ) as process:
        if input_type == "ark":
            # detach() so tensors that require grad can also be converted.
            kaldi_io.write_mat(
                process.stdin, input_value.detach().cpu().numpy(), key=key
            )
        else:
            # A single "<key> <path>" script line, newline-terminated.
            process.stdin.write(f"{key} {input_value}\n".encode("utf8"))
        # Closing stdin signals EOF so the command can finish its input.
        process.stdin.close()
        outputs = dict(kaldi_io.read_mat_ark(process.stdout))
        returncode = process.wait()

    # Report a failed command as such, rather than as a missing-key error.
    if returncode != 0:
        raise subprocess.CalledProcessError(returncode, command)
    result = outputs[key]
    # Copy before handing the array to torch. torch.from_numpy warns on
    # non-writable arrays, which the one from kaldi_io presumably is.
    return torch.from_numpy(result.copy())