@@ -86,26 +86,16 @@ def add_layer(self, x, out_size, ac=None):
86
86
def plot_histogram (l_in , l_in_bn , pre_ac , pre_ac_bn ):
87
87
for i , (ax_pa , ax_pa_bn , ax , ax_bn ) in enumerate (zip (axs [0 , :], axs [1 , :], axs [2 , :], axs [3 , :])):
88
88
[a .clear () for a in [ax_pa , ax_pa_bn , ax , ax_bn ]]
89
- if i == 0 :
90
- p_range = (- 7 , 10 )
91
- the_range = (- 7 , 10 )
92
- else :
93
- p_range = (- 4 , 4 )
94
- the_range = (- 1 , 1 )
89
+ if i == 0 : p_range = (- 7 , 10 ); the_range = (- 7 , 10 )
90
+ else : p_range = (- 4 , 4 ); the_range = (- 1 , 1 )
95
91
ax_pa .set_title ('L' + str (i ))
96
92
ax_pa .hist (pre_ac [i ].ravel (), bins = 10 , range = p_range , color = '#FF9359' , alpha = 0.5 )
97
93
ax_pa_bn .hist (pre_ac_bn [i ].ravel (), bins = 10 , range = p_range , color = '#74BCFF' , alpha = 0.5 )
98
94
ax .hist (l_in [i ].ravel (), bins = 10 , range = the_range , color = '#FF9359' )
99
95
ax_bn .hist (l_in_bn [i ].ravel (), bins = 10 , range = the_range , color = '#74BCFF' )
100
96
for a in [ax_pa , ax , ax_pa_bn , ax_bn ]:
101
- a .set_yticks (())
102
- a .set_xticks (())
103
- ax_pa_bn .set_xticks (p_range )
104
- ax_bn .set_xticks (the_range )
105
- axs [0 , 0 ].set_ylabel ('PreAct' )
106
- axs [1 , 0 ].set_ylabel ('BN PreAct' )
107
- axs [2 , 0 ].set_ylabel ('Act' )
108
- axs [3 , 0 ].set_ylabel ('BN Act' )
97
+ a .set_yticks (()); a .set_xticks (())
98
+ ax_pa_bn .set_xticks (p_range ); ax_bn .set_xticks (the_range ); axs [2 , 0 ].set_ylabel ('Act' ); axs [3 , 0 ].set_ylabel ('BN Act' )
109
99
plt .pause (0.01 )
110
100
111
101
losses = [[], []] # record test loss
@@ -137,15 +127,12 @@ def plot_histogram(l_in, l_in_bn, pre_ac, pre_ac_bn):
137
127
plt .figure (2 )
138
128
plt .plot (losses [0 ], c = '#FF9359' , lw = 3 , label = 'Original' )
139
129
plt .plot (losses [1 ], c = '#74BCFF' , lw = 3 , label = 'Batch Normalization' )
140
- plt .ylabel ('test loss' )
141
- plt .ylim ((0 , 2000 ))
142
- plt .legend (loc = 'best' )
130
+ plt .ylabel ('test loss' ); plt .ylim ((0 , 2000 )); plt .legend (loc = 'best' )
143
131
144
132
# plot prediction line
145
133
pred , pred_bn = sess .run ([nets [0 ].out , nets [1 ].out ], {tf_x : test_x , tf_is_train : False })
146
134
plt .figure (3 )
147
135
plt .plot (test_x , pred , c = '#FF9359' , lw = 4 , label = 'Original' )
148
136
plt .plot (test_x , pred_bn , c = '#74BCFF' , lw = 4 , label = 'Batch Normalization' )
149
137
plt .scatter (x [:200 ], y [:200 ], c = 'r' , s = 50 , alpha = 0.2 , label = 'train' )
150
- plt .legend (loc = 'best' )
151
- plt .show ()
138
+ plt .legend (loc = 'best' ); plt .show ()
0 commit comments