-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathinference.py
48 lines (35 loc) · 1.57 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import argparse
import torchaudio as ta
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchaudio.transforms as T
from train import SRTrain
from utils import *
import yaml
def _synthesize_to_file(model, path_src, path_dst, sample_rate):
    """Run ``model.synth_one_sample`` on one wav file and save the result.

    Args:
        model: Object exposing ``synth_one_sample(waveform)``.
        path_src: Path of the input wav file.
        path_dst: Path the synthesized waveform is written to.
        sample_rate: Sample rate passed to ``torchaudio.save``.
    """
    # The sample rate stored in the input file is discarded on purpose,
    # matching the original behavior of this script.
    wav_lr, _ = ta.load(path_src)
    # Pure inference: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        wav_sr = model.synth_one_sample(wav_lr)
    # Move to CPU before saving in case the model produced a device tensor.
    ta.save(path_dst, wav_sr.cpu(), sample_rate)


def inference(config, args, sample_rate=16000):
    """Synthesize output audio for a single wav file or a directory of them.

    Args:
        config: Configuration object forwarded to
            ``SRTrain.load_from_checkpoint``.
        args: Namespace with the attributes ``path_ckpt``, ``mode``
            (``"wav"`` or ``"dir"``), ``path_in`` and ``path_out``.
            ``path_out`` is only used in ``"dir"`` mode.
        sample_rate: Sample rate written to the output files. Defaults to
            16000, the value that was previously hard-coded.

    Raises:
        ValueError: If ``args.mode`` is neither ``"wav"`` nor ``"dir"``.

    In ``"wav"`` mode the result is written next to the input file, with
    ``"_sr"`` inserted between the two parts returned by ``get_filename``
    (presumably stem and extension -- confirm against ``utils``).
    In ``"dir"`` mode every path returned by ``get_wav_paths(args.path_in)``
    is processed and saved under ``args.path_out`` with its original basename.
    """
    # Imported locally so this function does not depend on `os` leaking into
    # the module namespace through `from utils import *`.
    import os

    # Validate before loading the checkpoint so a bad mode fails fast instead
    # of silently producing no output.
    if args.mode not in ("wav", "dir"):
        raise ValueError(
            f"unsupported mode {args.mode!r}; expected 'wav' or 'dir'"
        )

    sr_train = SRTrain.load_from_checkpoint(args.path_ckpt, config=config)
    sr_train.aecnn.eval()

    if args.mode == "wav":
        filename = get_filename(args.path_in)
        path_dst = os.path.join(
            os.path.dirname(args.path_in),
            filename[0] + "_sr" + filename[1],
        )
        _synthesize_to_file(sr_train, args.path_in, path_dst, sample_rate)
    else:
        # NOTE(review): check_dir_exist presumably creates the output
        # directory when it is missing -- confirm against `utils`.
        check_dir_exist(args.path_out)
        for path_wav in get_wav_paths(args.path_in):
            path_dst = os.path.join(args.path_out, os.path.basename(path_wav))
            _synthesize_to_file(sr_train, path_wav, path_dst, sample_rate)
if __name__ == "__main__":
    # Command-line entry point: parse arguments, load ./config.yaml and run
    # inference on a single wav file or on a directory of wav files.
    parser = argparse.ArgumentParser()
    parser.add_argument("--path_ckpt", type=str, required=True)
    parser.add_argument("--mode", type=str, help='wav/dir', default='wav',
                        choices=("wav", "dir"))
    parser.add_argument("--path_in", type=str, required=True,
                        help="path of input wav file or directory")
    parser.add_argument("--path_out", type=str,
                        help="path of directory of output file")
    args = parser.parse_args()

    # The output directory is only consumed in directory mode; report its
    # absence as a usage error rather than failing later on a None path.
    if args.mode == "dir" and args.path_out is None:
        parser.error("--path_out is required when --mode is 'dir'")

    # Use a context manager so the config file handle is closed promptly.
    # NOTE(review): FullLoader is kept for compatibility with the existing
    # config; consider yaml.safe_load if the file only holds plain data.
    with open("./config.yaml", 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    inference(config, args)