-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplot_train.py
49 lines (35 loc) · 1.3 KB
/
plot_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from __future__ import print_function
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Each entry: (CSV training-log path, subplot title, y-axis label).
LOG_SPECS = (
    ('prd/train.log', 'probability of disruption', 'binary cross-entropy'),
    ('ttd/train.log', 'time to disruption', 'mean absolute error (s)'),
)


def find_min_val_loss(epoch, val_loss):
    """Return ``(min_val_loss, min_val_epoch)`` for the best validation epoch.

    Non-finite entries (NaN/inf, e.g. from a diverged epoch) are ignored.
    Ties resolve to the earliest epoch, matching ``np.argmin``.

    Args:
        epoch: sequence of epoch numbers, aligned with ``val_loss``.
        val_loss: sequence of validation-loss values.

    Returns:
        A ``(min_val_loss, min_val_epoch)`` tuple, or ``None`` when
        ``val_loss`` is empty or contains no finite value.
    """
    values = np.asarray(val_loss, dtype=float)
    finite = np.isfinite(values)
    if not finite.any():
        return None
    # Replace non-finite entries with +inf so argmin never selects them.
    i = int(np.argmin(np.where(finite, values, np.inf)))
    return values[i], epoch[i]


def plot_log(ax, fname, title, ylabel):
    """Plot loss and val_loss from one CSV training log onto ``ax``.

    The log is read with ``pandas.read_csv`` and must contain the columns
    ``epoch``, ``loss`` and ``val_loss``. Dashed guide lines mark the epoch
    with the lowest finite validation loss, and that minimum is printed.
    """
    print('Reading:', fname)
    df = pd.read_csv(fname)
    # Shift the logged epoch index by one so the first epoch plots at x=1.
    epoch = df['epoch'].values + 1
    loss = df['loss'].values
    val_loss = df['val_loss'].values

    ax.plot(epoch, loss, label='loss')
    ax.plot(epoch, val_loss, label='val_loss')
    ax.set_xlabel('epoch')
    ax.set_title(title, fontsize='medium')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid()

    best = find_min_val_loss(epoch, val_loss)
    if best is None:
        print('min_val_loss: no finite val_loss values')
        return
    min_val_loss, min_val_epoch = best
    print('min_val_loss: %10.6f' % min_val_loss)
    print('min_val_epoch: %d' % min_val_epoch)

    # Capture the autoscaled limits before drawing the guides, then restore
    # them so the guide lines do not widen the view.
    _, x_max = ax.get_xlim()
    y_min, y_max = ax.get_ylim()
    # The x-axis is pinned to start at 0, so the horizontal guide starts
    # there too (the autoscaled x_min may be > 0 for short runs).
    ax.plot([0, min_val_epoch], [min_val_loss, min_val_loss], 'k--')
    ax.plot([min_val_epoch, min_val_epoch], [y_min, min_val_loss], 'k--')
    ax.set_xlim(0, x_max)
    ax.set_ylim(y_min, y_max)


def main():
    """Plot every log in LOG_SPECS side by side and show the figure."""
    # squeeze=False keeps a 2-D axes array even when there is only one log.
    _, axes = plt.subplots(nrows=1, ncols=len(LOG_SPECS), squeeze=False)
    for ax, (fname, title, ylabel) in zip(axes[0], LOG_SPECS):
        plot_log(ax, fname, title, ylabel)
    plt.tight_layout()
    plt.show()


if __name__ == '__main__':
    main()