Skip to content

Commit c223185

Browse files
committed
make BN shape compatible to previous version
1 parent 92de448 commit c223185

File tree

5 files changed

+13
-5
lines changed

5 files changed

+13
-5
lines changed

tensorlayer/files/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2666,6 +2666,10 @@ def _load_weights_from_hdf5_group(f, layers, skip=False):
26662666
elif isinstance(layer, tl.layers.Layer):
26672667
weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
26682668
for iid, w_name in enumerate(weight_names):
2669+
# FIXME : this is only for compatibility
2670+
if isinstance(layer, tl.layers.BatchNorm) and np.asarray(g[w_name]).ndim > 1:
2671+
assign_tf_variable(layer.all_weights[iid], np.asarray(g[w_name]).squeeze())
2672+
continue
26692673
assign_tf_variable(layer.all_weights[iid], np.asarray(g[w_name]))
26702674
else:
26712675
raise Exception("Only layer or model can be saved into hdf5.")

tensorlayer/layers/normalization.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,7 @@ class BatchNorm1d(BatchNorm):
315315
>>> bn = tl.layers.BatchNorm1d(num_features=32)
316316
317317
"""
318+
318319
def _check_input_shape(self, inputs):
319320
if inputs.ndim != 2 and inputs.ndim != 3:
320321
raise ValueError('expected input to be 2D or 3D, but got {}D input'.format(inputs.ndim))
@@ -337,6 +338,7 @@ class BatchNorm2d(BatchNorm):
337338
>>> bn = tl.layers.BatchNorm2d(num_features=32)
338339
339340
"""
341+
340342
def _check_input_shape(self, inputs):
341343
if inputs.ndim != 4:
342344
raise ValueError('expected input to be 4D, but got {}D input'.format(inputs.ndim))
@@ -359,6 +361,7 @@ class BatchNorm3d(BatchNorm):
359361
>>> bn = tl.layers.BatchNorm3d(num_features=32)
360362
361363
"""
364+
362365
def _check_input_shape(self, inputs):
363366
if inputs.ndim != 5:
364367
raise ValueError('expected input to be 5D, but got {}D input'.format(inputs.ndim))

tensorlayer/models/mobilenetv1.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ def restore_params(network, path='models'):
4343
expected_bytes=25600116
4444
) # ls -al
4545
params = load_npz(name=os.path.join(path, 'mobilenet.npz'))
46-
for idx, net_weight in enumerate(network.all_weights):
47-
if 'batchnorm' in net_weight.name:
48-
params[idx] = params[idx].reshape(1, 1, 1, -1)
46+
# for idx, net_weight in enumerate(network.all_weights):
47+
# if 'batchnorm' in net_weight.name:
48+
# params[idx] = params[idx].reshape(1, 1, 1, -1)
4949
assign_weights(params[:len(network.all_weights)], network)
5050
del params
5151

tensorlayer/models/resnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,8 @@ def restore_params(network, path='models'):
194194
continue
195195
w_names = list(f[layer.name])
196196
params = [f[layer.name][n][:] for n in w_names]
197-
if 'bn' in layer.name:
198-
params = [x.reshape(1, 1, 1, -1) for x in params]
197+
# if 'bn' in layer.name:
198+
# params = [x.reshape(1, 1, 1, -1) for x in params]
199199
assign_weights(params, layer)
200200
del params
201201

tests/layers/test_layers_normalization.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ def test_input_shape(self):
261261
self.assertIsInstance(e, ValueError)
262262
print(e)
263263

264+
264265
if __name__ == '__main__':
265266

266267
tl.logging.set_verbosity(tl.logging.DEBUG)

0 commit comments

Comments
 (0)