-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathsilverbox_plot.py
60 lines (46 loc) · 1.1 KB
/
silverbox_plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
if __name__ == '__main__':
# Set seed for reproducibility
np.random.seed(0)
torch.manual_seed(0)
# Settings
add_noise = True
lr = 2e-4
num_iter = 1
test_freq = 100
n_fit = 100000
decimate = 1
n_batch = 1
n_b = 3
n_f = 3
# Column names in the dataset
COL_U = ['V1']
COL_Y = ['V2']
# Load dataset
df_X = pd.read_csv(os.path.join("data", "SNLS80mV.csv"))
# Extract data
y = np.array(df_X[COL_Y], dtype=np.float32)
u = np.array(df_X[COL_U], dtype=np.float32)
u = u-np.mean(u)
fs = 10**7/2**14
N = y.size
ts = 1/fs
t = np.arange(N)*ts
# Fit data
y_fit = y[:n_fit:decimate]
u_fit = u[:n_fit:decimate]
t_fit = t[0:n_fit:decimate]
# In[Plot]
fig, ax = plt.subplots(2, 1, sharex=True)
ax[0].plot(t_fit, y_fit, 'r', label="$y$")
ax[0].grid()
ax[0].set_ylabel('$y$')
ax[1].plot(t_fit, u_fit, 'k', label="$u$")
ax[1].grid()
ax[1].set_ylabel('$u$')
# plt.legend()
plt.show()