
Commit 927f667

#2066 seq2seq.beamsearch (#2198)
* beamsearch with attention wrapper
* beamsearch with attention wrapperv1
* flake suggestions
* Update attention_wrapper.py
* flake suggestions2
* changes
* Apply suggestions from code review
* Update tensorflow_addons/seq2seq/attention_wrapper.py
* Update tensorflow_addons/seq2seq/attention_wrapper.py
* Apply suggestions from code review
* Apply suggestions from code review
* Apply suggestions from code review
* Update tensorflow_addons/seq2seq/attention_wrapper.py

Co-authored-by: Tzu-Wei Sung <[email protected]>
1 parent 0cb4674 commit 927f667

File tree

1 file changed: +19 −22 lines changed

tensorflow_addons/seq2seq/attention_wrapper.py (+19 −22)
@@ -1390,11 +1390,13 @@ def clone(self, **kwargs):

         Example:

-        ```python
-        initial_state = attention_wrapper.get_initial_state(
-            batch_size=..., dtype=...)
-        initial_state = initial_state.clone(cell_state=encoder_state)
-        ```
+        >>> batch_size = 1
+        >>> memory = tf.random.normal(shape=[batch_size, 3, 100])
+        >>> encoder_state = [tf.zeros((batch_size, 100)), tf.zeros((batch_size, 100))]
+        >>> attention_mechanism = tfa.seq2seq.LuongAttention(100, memory=memory, memory_sequence_length=[3] * batch_size)
+        >>> attention_cell = tfa.seq2seq.AttentionWrapper(tf.keras.layers.LSTMCell(100), attention_mechanism, attention_layer_size=10)
+        >>> decoder_initial_state = attention_cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
+        >>> decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state)

         Args:
           **kwargs: Any properties of the state object to replace in the
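For context, the cloned state produced by the new doctest is what a decoder consumes at training time. The doctest-style sketch below is not part of this commit; it shows one plausible way to feed the cloned state to a `tfa.seq2seq.BasicDecoder`, with the vocabulary size of 7, the random decoder inputs, and the target length of 4 as hypothetical placeholders.

>>> sampler = tfa.seq2seq.TrainingSampler()
>>> output_layer = tf.keras.layers.Dense(7)  # hypothetical vocabulary size
>>> decoder = tfa.seq2seq.BasicDecoder(attention_cell, sampler, output_layer=output_layer)
>>> decoder_inputs = tf.random.normal([batch_size, 4, 100])  # hypothetical embedded target sequence
>>> outputs, _, _ = decoder(decoder_inputs, initial_state=decoder_initial_state, sequence_length=[4] * batch_size)

Under these assumptions, `outputs.rnn_output` would have shape `[batch_size, 4, 7]`.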
@@ -1611,23 +1613,18 @@ def __init__(

         An example:

-        ```
-        tiled_encoder_outputs = tfa.seq2seq.tile_batch(
-            encoder_outputs, multiplier=beam_width)
-        tiled_encoder_final_state = tfa.seq2seq.tile_batch(
-            encoder_final_state, multiplier=beam_width)
-        tiled_sequence_length = tfa.seq2seq.tile_batch(
-            sequence_length, multiplier=beam_width)
-        attention_mechanism = MyFavoriteAttentionMechanism(
-            num_units=attention_depth,
-            memory=tiled_inputs,
-            memory_sequence_length=tiled_sequence_length)
-        attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
-        decoder_initial_state = attention_cell.get_initial_state(
-            batch_size=true_batch_size * beam_width, dtype=dtype)
-        decoder_initial_state = decoder_initial_state.clone(
-            cell_state=tiled_encoder_final_state)
-        ```
+        >>> batch_size = 1
+        >>> beam_width = 5
+        >>> sequence_length = tf.convert_to_tensor([5])
+        >>> encoder_outputs = tf.random.uniform(shape=(batch_size, 5, 10))
+        >>> encoder_final_state = [tf.zeros((batch_size, 10)), tf.zeros((batch_size, 10))]
+        >>> tiled_encoder_outputs = tfa.seq2seq.tile_batch(encoder_outputs, multiplier=beam_width)
+        >>> tiled_encoder_final_state = tfa.seq2seq.tile_batch(encoder_final_state, multiplier=beam_width)
+        >>> tiled_sequence_length = tfa.seq2seq.tile_batch(sequence_length, multiplier=beam_width)
+        >>> attention_mechanism = tfa.seq2seq.BahdanauAttention(10, memory=tiled_encoder_outputs, memory_sequence_length=tiled_sequence_length)
+        >>> attention_cell = tfa.seq2seq.AttentionWrapper(tf.keras.layers.LSTMCell(10), attention_mechanism)
+        >>> decoder_initial_state = attention_cell.get_initial_state(batch_size=batch_size * beam_width, dtype=tf.float32)
+        >>> decoder_initial_state = decoder_initial_state.clone(cell_state=tiled_encoder_final_state)

         Args:
           cell: A layer that implements the `tf.keras.layers.AbstractRNNCell`
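The tiled initial state prepared in the new doctest is the form expected by `tfa.seq2seq.BeamSearchDecoder`, which is the use case this commit targets (#2066). The doctest-style sketch below is not part of the diff; it outlines that final step under the assumption of a hypothetical vocabulary of 7 tokens, placeholder start/end token ids, and a `maximum_iterations` cap of 10.

>>> embedding_matrix = tf.random.normal([7, 10])  # hypothetical vocabulary of 7 tokens, embedding dim 10
>>> start_tokens = tf.fill([batch_size], 1)  # placeholder start token id
>>> end_token = 2  # placeholder end token id
>>> beam_decoder = tfa.seq2seq.BeamSearchDecoder(attention_cell, beam_width=beam_width, output_layer=tf.keras.layers.Dense(7), maximum_iterations=10)
>>> outputs, final_state, lengths = beam_decoder(embedding_matrix, start_tokens=start_tokens, end_token=end_token, initial_state=decoder_initial_state)

Under these assumptions, `outputs.predicted_ids` would hold the `beam_width` candidate sequences for each batch entry.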

0 commit comments