Skip to content

Commit 08090d0

Browse files
fix state size to be a TensorShape instead of a Tensor (#1121)
* fix state size to be a TensorShape instead of a Tensor
* fix Eager mode tests where alignments_size is an int
* fix 2
* add a test case
* update module name
* create an empty TensorShape when self._alignments_size is a Tensor

Co-authored-by: Gabriel de Marmiesse <[email protected]>
1 parent 062f026 commit 08090d0

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

tensorflow_addons/seq2seq/attention_wrapper.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -411,11 +411,14 @@ def deserialize_inner_layer_from_config(cls, config, custom_objects):
411411

412412
@property
413413
def alignments_size(self):
414-
return self._alignments_size
414+
if isinstance(self._alignments_size, int):
415+
return self._alignments_size
416+
else:
417+
return tf.TensorShape([None])
415418

416419
@property
417420
def state_size(self):
418-
return self._alignments_size
421+
return self.alignments_size
419422

420423
def initial_alignments(self, batch_size, dtype):
421424
"""Creates the initial alignment values for the `AttentionWrapper`

tensorflow_addons/seq2seq/attention_wrapper_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -999,6 +999,19 @@ def test_attention_state_with_keras_rnn(self):
999999
initial_state = cell.get_initial_state(batch_size=2, dtype=tf.float32)
10001000
_ = layer(inputs=tf.ones((2, 4, 8)), initial_state=initial_state)
10011001

1002+
def test_attention_state_with_variable_length_input(self):
1003+
cell = tf.keras.layers.LSTMCell(3)
1004+
mechanism = wrapper.LuongAttention(units=3)
1005+
cell = wrapper.AttentionWrapper(cell, mechanism)
1006+
1007+
var_len = tf.random.uniform(shape=(), minval=2, maxval=10, dtype=tf.int32)
1008+
data = tf.ones(shape=(var_len, var_len, 3))
1009+
1010+
mechanism.setup_memory(data)
1011+
layer = tf.keras.layers.RNN(cell)
1012+
1013+
_ = layer(data)
1014+
10021015

10031016
if __name__ == "__main__":
10041017
tf.test.main()

0 commit comments

Comments (0)