Skip to content

Commit bf8a809

Browse files
Squadrick and seanpmorgan
authored and committed
Add WeightNorm support for RNNs (#769)
1 parent bcd790f commit bf8a809

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

tensorflow_addons/layers/wrappers.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,31 +58,34 @@ def __init__(self, layer, data_init=True, **kwargs):
5858
super(WeightNormalization, self).__init__(layer, **kwargs)
5959
self.data_init = data_init
6060
self._track_trackable(layer, name='layer')
61+
self.is_rnn = isinstance(self.layer, tf.keras.layers.RNN)
6162

6263
def build(self, input_shape):
6364
"""Build `Layer`"""
64-
input_shape = tf.TensorShape(input_shape).as_list()
65+
input_shape = tf.TensorShape(input_shape)
6566
self.input_spec = tf.keras.layers.InputSpec(
6667
shape=[None] + input_shape[1:])
6768

6869
if not self.layer.built:
6970
self.layer.build(input_shape)
7071

71-
if not hasattr(self.layer, 'kernel'):
72+
kernel_layer = self.layer.cell if self.is_rnn else self.layer
73+
74+
if not hasattr(kernel_layer, 'kernel'):
7275
raise ValueError('`WeightNormalization` must wrap a layer that'
7376
' contains a `kernel` for weights')
7477

7578
# The kernel's filter or unit dimension is -1
76-
self.layer_depth = int(self.layer.kernel.shape[-1])
77-
self.kernel_norm_axes = list(range(self.layer.kernel.shape.rank - 1))
79+
self.layer_depth = int(kernel_layer.kernel.shape[-1])
80+
self.kernel_norm_axes = list(range(kernel_layer.kernel.shape.rank - 1))
7881

7982
self.g = self.add_weight(
8083
name='g',
8184
shape=(self.layer_depth,),
8285
initializer='ones',
83-
dtype=self.layer.kernel.dtype,
86+
dtype=kernel_layer.kernel.dtype,
8487
trainable=True)
85-
self.v = self.layer.kernel
88+
self.v = kernel_layer.kernel
8689

8790
self._initialized = self.add_weight(
8891
name='initialized',
@@ -100,7 +103,10 @@ def build(self, input_shape):
100103
layer_config)
101104
self._naked_clone_layer.build(input_shape)
102105
self._naked_clone_layer.set_weights(self.layer.get_weights())
103-
self._naked_clone_layer.activation = None
106+
if self.is_rnn:
107+
self._naked_clone_layer.cell.activation = None
108+
else:
109+
self._naked_clone_layer.activation = None
104110

105111
self.built = True
106112

tensorflow_addons/layers/wrappers_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,13 @@ def test_weightnorm_with_time_dist(self):
8989
out = tf.keras.layers.TimeDistributed(b)(inputs)
9090
model = tf.keras.Model(inputs, out)
9191

92+
def test_weightnorm_with_rnn(self):
93+
inputs = tf.keras.layers.Input(shape=(None, 3))
94+
rnn_layer = tf.keras.layers.SimpleRNN(4)
95+
wt_rnn = wrappers.WeightNormalization(rnn_layer)
96+
dense = tf.keras.layers.Dense(1)
97+
model = tf.keras.models.Sequential(layers=[inputs, wt_rnn, dense])
98+
9299
def test_save_file_h5(self):
93100
self.create_tempfile('wrapper_test_model.h5')
94101
conv = tf.keras.layers.Conv1D(1, 1)

0 commit comments

Comments (0)