@@ -1390,11 +1390,13 @@ def clone(self, **kwargs):
Example:

- ```python
- initial_state = attention_wrapper.get_initial_state(
-     batch_size=..., dtype=...)
- initial_state = initial_state.clone(cell_state=encoder_state)
- ```
+ >>> batch_size = 1
+ >>> memory = tf.random.normal(shape=[batch_size, 3, 100])
+ >>> encoder_state = [tf.zeros((batch_size, 100)), tf.zeros((batch_size, 100))]
+ >>> attention_mechanism = tfa.seq2seq.LuongAttention(100, memory=memory, memory_sequence_length=[3] * batch_size)
+ >>> attention_cell = tfa.seq2seq.AttentionWrapper(tf.keras.layers.LSTMCell(100), attention_mechanism, attention_layer_size=10)
+ >>> decoder_initial_state = attention_cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
+ >>> decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state)

Args:
  **kwargs: Any properties of the state object to replace in the
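
The new doctest assumes `tf` and `tfa` are already imported by the doctest fixtures. Below is a standalone sketch of the same flow; the `TrainingSampler`/`BasicDecoder` step at the end and the `decoder_inputs` tensor are illustrative assumptions added here to show where the cloned state is typically consumed, not part of this change:

```python
# Sketch only: mirrors the new doctest, then adds a hypothetical decode step.
import tensorflow as tf
import tensorflow_addons as tfa

batch_size = 1
memory = tf.random.normal(shape=[batch_size, 3, 100])
encoder_state = [tf.zeros((batch_size, 100)), tf.zeros((batch_size, 100))]

attention_mechanism = tfa.seq2seq.LuongAttention(
    100, memory=memory, memory_sequence_length=[3] * batch_size)
attention_cell = tfa.seq2seq.AttentionWrapper(
    tf.keras.layers.LSTMCell(100), attention_mechanism, attention_layer_size=10)

# Build a zero state for the wrapper, then replace only the wrapped cell's
# state with the encoder's final state; every other field keeps its shape.
decoder_initial_state = attention_cell.get_initial_state(
    batch_size=batch_size, dtype=tf.float32)
decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state)

# Illustrative continuation (not in the diff): consume the cloned state with
# a teacher-forcing decoder. `decoder_inputs` is a made-up embedded target
# sequence of shape [batch_size, time, embedding_dim].
sampler = tfa.seq2seq.TrainingSampler()
decoder = tfa.seq2seq.BasicDecoder(attention_cell, sampler)
decoder_inputs = tf.random.uniform([batch_size, 4, 16])
outputs, final_state, lengths = decoder(
    decoder_inputs,
    initial_state=decoder_initial_state,
    sequence_length=tf.convert_to_tensor([4]))
# outputs.rnn_output has shape [batch_size, time, attention_layer_size].
```
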
@@ -1611,23 +1613,18 @@ def __init__(
An example:

- ```
- tiled_encoder_outputs = tfa.seq2seq.tile_batch(
-     encoder_outputs, multiplier=beam_width)
- tiled_encoder_final_state = tfa.seq2seq.tile_batch(
-     encoder_final_state, multiplier=beam_width)
- tiled_sequence_length = tfa.seq2seq.tile_batch(
-     sequence_length, multiplier=beam_width)
- attention_mechanism = MyFavoriteAttentionMechanism(
-     num_units=attention_depth,
-     memory=tiled_inputs,
-     memory_sequence_length=tiled_sequence_length)
- attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
- decoder_initial_state = attention_cell.get_initial_state(
-     batch_size=true_batch_size * beam_width, dtype=dtype)
- decoder_initial_state = decoder_initial_state.clone(
-     cell_state=tiled_encoder_final_state)
- ```
+ >>> batch_size = 1
+ >>> beam_width = 5
+ >>> sequence_length = tf.convert_to_tensor([5])
+ >>> encoder_outputs = tf.random.uniform(shape=(batch_size, 5, 10))
+ >>> encoder_final_state = [tf.zeros((batch_size, 10)), tf.zeros((batch_size, 10))]
+ >>> tiled_encoder_outputs = tfa.seq2seq.tile_batch(encoder_outputs, multiplier=beam_width)
+ >>> tiled_encoder_final_state = tfa.seq2seq.tile_batch(encoder_final_state, multiplier=beam_width)
+ >>> tiled_sequence_length = tfa.seq2seq.tile_batch(sequence_length, multiplier=beam_width)
+ >>> attention_mechanism = tfa.seq2seq.BahdanauAttention(10, memory=tiled_encoder_outputs, memory_sequence_length=tiled_sequence_length)
+ >>> attention_cell = tfa.seq2seq.AttentionWrapper(tf.keras.layers.LSTMCell(10), attention_mechanism)
+ >>> decoder_initial_state = attention_cell.get_initial_state(batch_size=batch_size * beam_width, dtype=tf.float32)
+ >>> decoder_initial_state = decoder_initial_state.clone(cell_state=tiled_encoder_final_state)

Args:
  cell: A layer that implements the `tf.keras.layers.AbstractRNNCell`
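
The new doctest stops once the tiled state has been cloned. The sketch below, under the assumption that the cloned state is then fed to `tfa.seq2seq.BeamSearchDecoder`, continues one step further; the embedding matrix, vocabulary size, start/end token ids, and `maximum_iterations` value are made-up illustration values, not part of this change:

```python
# Sketch only: continues the new doctest into an actual beam search decode.
import tensorflow as tf
import tensorflow_addons as tfa

batch_size = 1
beam_width = 5
units = 10
vocab_size = 20  # made-up vocabulary size for the illustration

sequence_length = tf.convert_to_tensor([5])
encoder_outputs = tf.random.uniform(shape=(batch_size, 5, units))
encoder_final_state = [tf.zeros((batch_size, units)), tf.zeros((batch_size, units))]

# Tile every encoder tensor by beam_width so all beams share the same memory.
tiled_encoder_outputs = tfa.seq2seq.tile_batch(encoder_outputs, multiplier=beam_width)
tiled_encoder_final_state = tfa.seq2seq.tile_batch(encoder_final_state, multiplier=beam_width)
tiled_sequence_length = tfa.seq2seq.tile_batch(sequence_length, multiplier=beam_width)

attention_mechanism = tfa.seq2seq.BahdanauAttention(
    units, memory=tiled_encoder_outputs, memory_sequence_length=tiled_sequence_length)
attention_cell = tfa.seq2seq.AttentionWrapper(
    tf.keras.layers.LSTMCell(units), attention_mechanism)

# The zero state is built for batch_size * beam_width, then only the wrapped
# cell's state is swapped for the tiled encoder state.
decoder_initial_state = attention_cell.get_initial_state(
    batch_size=batch_size * beam_width, dtype=tf.float32)
decoder_initial_state = decoder_initial_state.clone(cell_state=tiled_encoder_final_state)

# Illustrative continuation (not in the diff): run beam search from the
# cloned state. The embedding matrix and start/end token ids are made up.
embedding_matrix = tf.random.uniform([vocab_size, units])
decoder = tfa.seq2seq.BeamSearchDecoder(
    attention_cell, beam_width=beam_width, maximum_iterations=8)
outputs, final_state, lengths = decoder(
    embedding_matrix,
    start_tokens=tf.fill([batch_size], 1),
    end_token=2,
    initial_state=decoder_initial_state)
# outputs.predicted_ids has shape [batch_size, time, beam_width].
```
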