@@ -58,31 +58,34 @@ def __init__(self, layer, data_init=True, **kwargs):
58
58
super (WeightNormalization , self ).__init__ (layer , ** kwargs )
59
59
self .data_init = data_init
60
60
self ._track_trackable (layer , name = 'layer' )
61
+ self .is_rnn = isinstance (self .layer , tf .keras .layers .RNN )
61
62
62
63
def build (self , input_shape ):
63
64
"""Build `Layer`"""
64
- input_shape = tf .TensorShape (input_shape ). as_list ()
65
+ input_shape = tf .TensorShape (input_shape )
65
66
self .input_spec = tf .keras .layers .InputSpec (
66
67
shape = [None ] + input_shape [1 :])
67
68
68
69
if not self .layer .built :
69
70
self .layer .build (input_shape )
70
71
71
- if not hasattr (self .layer , 'kernel' ):
72
+ kernel_layer = self .layer .cell if self .is_rnn else self .layer
73
+
74
+ if not hasattr (kernel_layer , 'kernel' ):
72
75
raise ValueError ('`WeightNormalization` must wrap a layer that'
73
76
' contains a `kernel` for weights' )
74
77
75
78
# The kernel's filter or unit dimension is -1
76
- self .layer_depth = int (self . layer .kernel .shape [- 1 ])
77
- self .kernel_norm_axes = list (range (self . layer .kernel .shape .rank - 1 ))
79
+ self .layer_depth = int (kernel_layer .kernel .shape [- 1 ])
80
+ self .kernel_norm_axes = list (range (kernel_layer .kernel .shape .rank - 1 ))
78
81
79
82
self .g = self .add_weight (
80
83
name = 'g' ,
81
84
shape = (self .layer_depth ,),
82
85
initializer = 'ones' ,
83
- dtype = self . layer .kernel .dtype ,
86
+ dtype = kernel_layer .kernel .dtype ,
84
87
trainable = True )
85
- self .v = self . layer .kernel
88
+ self .v = kernel_layer .kernel
86
89
87
90
self ._initialized = self .add_weight (
88
91
name = 'initialized' ,
@@ -100,7 +103,10 @@ def build(self, input_shape):
100
103
layer_config )
101
104
self ._naked_clone_layer .build (input_shape )
102
105
self ._naked_clone_layer .set_weights (self .layer .get_weights ())
103
- self ._naked_clone_layer .activation = None
106
+ if self .is_rnn :
107
+ self ._naked_clone_layer .cell .activation = None
108
+ else :
109
+ self ._naked_clone_layer .activation = None
104
110
105
111
self .built = True
106
112
0 commit comments