-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcnn_train_test_efficiency.py
120 lines (90 loc) · 4 KB
/
cnn_train_test_efficiency.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
from matplotlib import pyplot as plt
import numpy as np
from tensorflow import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, Reshape, GlobalAveragePooling1D, Activation, GlobalAveragePooling2D
from keras.layers import Conv2D, MaxPooling2D, Conv1D, MaxPooling1D
# Looked up eagerly: a KeyError is raised here if ARIANNA_analysis is not set.
PathToARIANNA = os.environ['ARIANNA_analysis']

######################################################################################################################
# This script trains a basic CNN with the function train_cnn and then plots
# the efficiency curve with the function efficiency_curve.
#
# Parameters to set below: the path to the input data, the noise and signal
# file names, and the output name (and directory) of the model that will be
# trained and tested.
######################################################################################################################

# Directory holding the input .npy files.
path = "/arianna_data"

# Load only a subset of the data here so the remainder can be used for
# validation; the noise and signal subsets should be the same size.
noise = np.load(os.path.join(path, "noise.npy"))
signal = np.load(os.path.join(path, "signal.npy"))

# The trained model is written to / read from f'{model_path}/{model_name}.h5'.
model_name = "trained_CNN_1l-10-8-10_do0.5_mp10_fltn_sigm"
model_path = "/h5_model_path"

# 2-D input (events, samples) gets a length-1 axis inserted at position 1 so
# that both arrays become 3-D. Only signal's rank is checked; noise is
# reshaped alongside it.
if signal.ndim == 2:
    signal = signal.reshape((signal.shape[0], 1, signal.shape[1]))
    noise = noise.reshape((noise.shape[0], 1, noise.shape[1]))
def train_cnn():
    """Build, train and save the single-layer CNN signal/noise classifier.

    Reads the module-level ``noise`` and ``signal`` arrays, labels noise
    events 0 and signal events 1, shuffles them together, fits the network
    and writes it to ``f'{model_path}/{model_name}.h5'``.

    Returns:
        None. The trained model is only persisted to disk.
    """
    x = np.vstack((noise, signal))
    n_channels, n_samples = x.shape[1], x.shape[2]
    # Conv2D expects a trailing feature axis: (events, channels, samples, 1).
    x = np.expand_dims(x, axis=-1)

    # Labels follow the stacking order above: noise -> 0, signal -> 1.
    y = np.vstack((np.zeros((noise.shape[0], 1)), np.ones((signal.shape[0], 1))))

    # Shuffle events and labels with the same permutation. This must happen
    # before fit(): validation_split takes the *last* fraction of the arrays,
    # which would otherwise be signal-only.
    s = np.arange(x.shape[0])
    np.random.shuffle(s)
    x = x[s]
    y = y[s]
    print(x.shape)

    BATCH_SIZE = 32
    EPOCHS = 100
    # NOTE(review): 20% hold-out is an assumed value — adjust if a different
    # validation fraction is wanted.
    VALIDATION_SPLIT = 0.2

    # Stop once the validation loss has not improved for 2 consecutive epochs.
    callbacks_list = [
        keras.callbacks.EarlyStopping(monitor='val_loss', patience=2)]

    model = Sequential()
    # NOTE(review): the (8, 10) kernel with default 'valid' padding needs
    # n_channels >= 8; single-channel input (the 2-D reshape case at module
    # level) presumably cannot be used with this architecture — confirm.
    model.add(Conv2D(10, (8, 10), activation='relu', input_shape=(n_channels, n_samples, 1)))
    model.add(Dropout(0.5))
    model.add(MaxPooling2D(pool_size=(1, 10)))
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))
    model.compile(optimizer='Adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])
    model.summary()

    # Train the network. Without this call the file saved below would contain
    # only the randomly initialised weights. A validation split is required
    # for the 'val_loss' quantity monitored by EarlyStopping to exist.
    model.fit(x, y,
              batch_size=BATCH_SIZE,
              epochs=EPOCHS,
              callbacks=callbacks_list,
              validation_split=VALIDATION_SPLIT,
              verbose=1)

    # Path and file name the model is saved as (in h5 format).
    model.save(f'{model_path}/{model_name}.h5')
def efficiency_curve(h5_name, n_dpt, colors):
    """Scan classifier-output cuts and return signal efficiency vs. noise reduction.

    Loads ``f'{model_path}/{h5_name}.h5'``, evaluates it on the module-level
    ``signal`` and ``noise`` arrays, and for each threshold computes the
    fraction of signal events scoring above it and the noise reduction factor.

    Args:
        h5_name: File name (without ``.h5``) of the saved model in ``model_path``.
        n_dpt: Number of thresholds in each of the two ranges [0, 0.9] and
            [0.9, 1]; ``2 * n_dpt`` thresholds are scanned in total.
        colors: Not used by this function.

    Returns:
        Tuple ``(signal_efficiency, noise_reduction_factor)`` of 1-D arrays,
        each with the first threshold's entry dropped (length ``2 * n_dpt - 1``).
    """
    n_signal = signal.shape[0]
    n_noise = noise.shape[0]

    # Signal events come first, then noise; a trailing axis of length 1 is added.
    x = np.expand_dims(np.vstack((signal, noise)), axis=-1)

    model = keras.models.load_model(f'{model_path}/{h5_name}.h5')
    y_pred = model.predict(x)
    print(y_pred)

    # Column 0 of the prediction, split back into the signal and noise parts.
    signal_scores = y_pred[:n_signal, 0]
    noise_scores = y_pred[n_signal:, 0]

    # Threshold cuts. Building them from two ranges gives more detail at the
    # high cut values, which is usually where the detail is needed.
    thresholds = np.concatenate((np.linspace(0, 0.9, n_dpt),
                                 np.linspace(0.9, 1, n_dpt)))

    signal_efficiency = np.zeros(2 * n_dpt)
    reduction_factors = np.zeros(2 * n_dpt)
    for idx, cut in enumerate(thresholds):
        signal_efficiency[idx] = np.sum(signal_scores > cut) / n_signal

        # Fraction of noise events that do NOT pass the cut.
        rejected_fraction = np.sum(~(noise_scores > cut)) / n_noise
        if rejected_fraction < 1:
            reduction_factors[idx] = 1 / (1 - rejected_fraction)
        else:
            # Every noise event was rejected: 1 / (1 - 1) is undefined, so the
            # factor is set to the number of noise events instead.
            reduction_factors[idx] = n_noise

    return signal_efficiency[1:], reduction_factors[1:]
def main():
    """Train the CNN, then plot its noise reduction factor against signal efficiency."""
    train_cnn()

    signal_eff, noise_reduction = efficiency_curve(
        h5_name=model_name, n_dpt=500, colors='blue')

    # Plot only every 10th point to give a smoother curve.
    stride = 10
    plt.plot(signal_eff[0::stride], noise_reduction[0::stride],
             label='cnn', linewidth=3)
    plt.legend(loc='lower left')
    plt.yscale('log')
    plt.xlim(0.91, 1.1)
    plt.ylim(1, 10**6)
    plt.grid(True, which='major', axis='both', linewidth=0.5)
    plt.xlabel('signal efficiency', fontsize=15)
    plt.ylabel('noise reduction factor', fontsize=15)
    plt.show()


if __name__ == "__main__":
    main()