-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
123 lines (95 loc) · 3.63 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
"""
Maintainer: Gabriel Dias ([email protected])
Mateus Oliveira ([email protected])
Marcio Almeida ([email protected])
"""
import os
from typing import Any, List

import h5py
import numpy as np
import torch
import yaml
from scipy import signal
def set_device():
    """Pick the compute device: the first CUDA GPU if one is visible, else the CPU.

    The chosen device is printed before being returned.

    Returns:
        torch.device: ``cuda:0`` when ``torch.cuda.is_available()``, otherwise ``cpu``.
    """
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    chosen = torch.device(device_name)
    print('Using {}'.format(chosen))
    return chosen
def clean_directory(dir_path):
    """Recursively delete everything inside ``dir_path``; the directory itself is kept.

    Regular files and symbolic links are unlinked. This includes links that point
    to directories and dangling links. Real subdirectories are emptied
    recursively and then removed.

    Args:
        dir_path: Path of the directory to empty.

    Raises:
        OSError: If an entry cannot be listed or removed (e.g. permissions).
    """
    for file_name in os.listdir(dir_path):
        file_absolute_path = os.path.join(dir_path, file_name)
        # Test for links first. os.path.isdir/isfile follow symlinks, so a link
        # to a directory would otherwise be recursed into, deleting the *target's*
        # contents. os.rmdir would then fail on the link itself.
        if os.path.islink(file_absolute_path) or os.path.isfile(file_absolute_path):
            os.remove(file_absolute_path)
        elif os.path.isdir(file_absolute_path):
            clean_directory(file_absolute_path)
            os.rmdir(file_absolute_path)
def read_yaml(file: str) -> Any:
    """Load a YAML file and return the parsed document.

    Args:
        file: Path to the YAML file.

    Returns:
        The parsed top-level YAML value. For configuration files this is usually
        a ``dict``, but any YAML value (list, scalar, ``None``) is possible.
    """
    # NOTE(review): FullLoader is kept for compatibility with existing config
    # files. Switch to yaml.safe_load if this may ever read untrusted input.
    # The encoding is explicit so results do not depend on the platform locale.
    with open(file, "r", encoding="utf-8") as yaml_file:
        return yaml.load(yaml_file, Loader=yaml.FullLoader)
class ReadDatasets:
    """Static helpers for reading and writing HDF5 files with h5py."""

    @staticmethod
    def read_h5_pred_fit(filename: str) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Read the ``input_spec``, ``ground``, ``pred`` and ``ppm`` datasets from an HDF5 file.

        The file is opened explicitly read-only. h5py releases before 3.0 chose a
        writable mode by default, so a mistyped path could silently create a new
        empty file instead of failing.

        Args:
            filename: Path to an HDF5 file, e.g. one written by ``write_h5_pred_fit``.

        Returns:
            ``(input_spec, ground, pred, ppm)``. Each dataset is read fully into
            memory via ``[()]``.

        Raises:
            KeyError: If one of the four datasets is missing.
        """
        with h5py.File(filename, "r") as hf:
            input_spec = hf["input_spec"][()]
            ground = hf["ground"][()]
            pred = hf["pred"][()]
            ppm = hf["ppm"][()]
        return input_spec, ground, pred, ppm

    @staticmethod
    def write_h5_pred_fit(save_file_path: str,
                          input_spec: np.ndarray,
                          ground: np.ndarray,
                          pred: np.ndarray,
                          ppm: np.ndarray) -> None:
        """Write the four arrays to ``save_file_path`` as separate HDF5 datasets.

        The file is opened with mode ``'w'``, so an existing file is truncated.
        The dataset names match those read by ``read_h5_pred_fit``.
        """
        with h5py.File(save_file_path, 'w') as hf:
            hf.create_dataset('input_spec', data=input_spec)
            hf.create_dataset('ground', data=ground)
            hf.create_dataset('pred', data=pred)
            hf.create_dataset('ppm', data=ppm)

    @staticmethod
    def read_h5(filename: str) -> List[np.ndarray]:
        """Return the contents of the first top-level object in an HDF5 file as a list.

        "First" is the first key reported by ``f.keys()``. Iterating an h5py
        dataset yields its entries along the first axis.
        NOTE(review): if the first key is a group, the list holds member names
        instead of arrays. Confirm callers only pass single-dataset files.

        Raises:
            IndexError: If the file contains no top-level objects.
        """
        with h5py.File(filename, "r") as f:
            a_group_key = list(f.keys())[0]
            data = list(f[a_group_key])
        return data
def get_fid_params(params: dict) -> np.ndarray:
    """Collect the values of ``params`` into an array, skipping the ``"A_Ace"`` entry.

    Values appear in the dict's iteration (insertion) order. The caller's dict is
    left unmodified; only the returned array excludes ``"A_Ace"``.

    Args:
        params: Mapping of parameter name to value.

    Returns:
        ``np.ndarray`` of the remaining values. It is empty if nothing remains.
    """
    return np.asarray([value for key, value in params.items() if key != "A_Ace"])
def calculate_spectrogram(FID, t, window_size=256, hope_size=64, window='hann', nfft=None):
    """Compute the two-sided short-time Fourier transform (STFT) of a signal.

    Args:
        FID: Time-domain samples, real or complex. Presumably 1-D; TODO confirm.
        t: Sampling interval. The sampling rate passed to the STFT is ``1 / t``.
        window_size: Samples per segment (``nperseg``).
        hope_size: Hop between consecutive segments, in samples. The name is
            kept for backward compatibility; it means "hop size".
        window: Window spec understood by ``scipy.signal.get_window``. It is used
            for both the NOLA check and the STFT itself.
        nfft: FFT length per segment. ``None`` uses ``window_size``.

    Returns:
        The complex STFT matrix ``Zxx`` returned by ``scipy.signal.stft``.

    Raises:
        ValueError: If ``window`` with this overlap violates the NOLA
            constraint, in which case the STFT would not be invertible.
    """
    noverlap = window_size - hope_size
    if not signal.check_NOLA(window, window_size, noverlap):
        raise ValueError("signal windowing fails Non-zero Overlap Add (NOLA) criterion; "
                         "STFT not invertible")
    fs = 1 / t
    # Pass the same window that was validated above. Without it, stft silently
    # falls back to its default 'hann' window.
    _, _, Zxx = signal.stft(FID, fs=fs, window=window, nperseg=window_size,
                            noverlap=noverlap, return_onesided=False, nfft=nfft)
    return Zxx
def calculate_fqn(spec, residual, ppm, noise_ppm_min=9.8, noise_ppm_max=10.8):
    """Return the FQN: the variance of ``residual`` divided by the noise variance of ``spec``.

    The noise variance is taken from the samples of ``spec`` whose ``ppm`` value
    lies in the inclusive window ``[noise_ppm_min, noise_ppm_max]``. By default
    this is 9.8–10.8. The window is selected with a boolean mask, so it works
    whether ``ppm`` is ascending or descending.

    Args:
        spec: Spectrum values. Indexed by a mask built from ``ppm``, so it must
            have the same length as ``ppm`` along its first axis.
        residual: Fit residual; its variance is the numerator.
        ppm: Axis values corresponding to ``spec``.
        noise_ppm_min: Lower (inclusive) bound of the noise window.
        noise_ppm_max: Upper (inclusive) bound of the noise window.

    Returns:
        ``var(residual) / var(spec[noise window])``.

    Raises:
        ValueError: If no ``ppm`` sample falls inside the noise window.
    """
    ppm = np.asarray(ppm)
    noise_mask = (ppm >= noise_ppm_min) & (ppm <= noise_ppm_max)
    if not np.any(noise_mask):
        raise ValueError(
            f"no ppm samples in noise window [{noise_ppm_min}, {noise_ppm_max}]")
    noise_var = np.var(np.asarray(spec)[noise_mask])
    residual_var = np.var(residual)
    return residual_var / noise_var
class NormalizeData:
    """Scale arrays using min-max or z-score normalization."""

    def normalize(self, arr, method):
        """Normalize ``arr`` with the named method.

        Args:
            arr: Array-like of numbers.
            method: ``"min-max"`` to map onto [0, 1], or ``"z_norm"`` for zero
                mean and unit standard deviation.

        Returns:
            The normalized array.

        Raises:
            ValueError: If ``method`` is not recognised. Previously this
                silently returned ``None``.
        """
        if method == "min-max":
            return self.min_max_normalize(arr)
        if method == "z_norm":
            return self.z_score_normalize(arr)
        raise ValueError(
            f"unknown normalization method {method!r}; expected 'min-max' or 'z_norm'")

    def min_max_normalize(self, arr):
        """Linearly rescale ``arr`` so its minimum maps to 0 and its maximum to 1.

        A constant array has zero range and is returned as zeros, which avoids a
        0/0 division that would produce NaNs.
        """
        arr = np.asarray(arr)
        min_val = np.min(arr)
        max_val = np.max(arr)
        value_range = max_val - min_val
        if value_range == 0:
            return np.zeros_like(arr, dtype=float)
        return (arr - min_val) / value_range

    def z_score_normalize(self, arr):
        """Shift and scale ``arr`` to zero mean and unit (population) standard deviation.

        A constant array has zero standard deviation and is returned as zeros,
        which avoids a 0/0 division that would produce NaNs.
        """
        arr = np.asarray(arr)
        mean = np.mean(arr)
        std_dev = np.std(arr)
        if std_dev == 0:
            return np.zeros_like(arr, dtype=float)
        return (arr - mean) / std_dev