@@ -1390,11 +1390,13 @@ def clone(self, **kwargs):
 
 Example:
 
-```python
-initial_state = attention_wrapper.get_initial_state(
-    batch_size=..., dtype=...)
-initial_state = initial_state.clone(cell_state=encoder_state)
-```
+>>> batch_size = 1
+>>> memory = tf.random.normal(shape=[batch_size, 3, 100])
+>>> encoder_state = [tf.zeros((batch_size, 100)), tf.zeros((batch_size, 100))]
+>>> attention_mechanism = tfa.seq2seq.LuongAttention(100, memory=memory, memory_sequence_length=[3] * batch_size)
+>>> attention_cell = tfa.seq2seq.AttentionWrapper(tf.keras.layers.LSTMCell(100), attention_mechanism, attention_layer_size=10)
+>>> decoder_initial_state = attention_cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
+>>> decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state)
 
 Args:
   **kwargs: Any properties of the state object to replace in the
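For reference, the added doctest ends once the state is cloned. The snippet below (not part of the diff) sketches how that cloned state would drive a single decoder step; the `step_input` tensor and its feature size are illustrative assumptions, and the call relies only on `AttentionWrapper` behaving as a standard Keras RNN cell:

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Rebuild the objects from the doctest above, then run one decoder step.
batch_size = 1
memory = tf.random.normal(shape=[batch_size, 3, 100])
encoder_state = [tf.zeros((batch_size, 100)), tf.zeros((batch_size, 100))]
attention_mechanism = tfa.seq2seq.LuongAttention(
    100, memory=memory, memory_sequence_length=[3] * batch_size)
attention_cell = tfa.seq2seq.AttentionWrapper(
    tf.keras.layers.LSTMCell(100), attention_mechanism, attention_layer_size=10)
decoder_initial_state = attention_cell.get_initial_state(
    batch_size=batch_size, dtype=tf.float32)
decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state)

# An RNN cell maps (inputs, state) -> (output, next_state). The feature
# size of the step input is an arbitrary illustrative choice.
step_input = tf.random.normal([batch_size, 16])
output, next_state = attention_cell(step_input, decoder_initial_state)
print(output.shape)  # (1, 10): attention_layer_size=10 projects the attention output
```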
@@ -1611,23 +1613,18 @@ def __init__(
 
 An example:
 
-```
-tiled_encoder_outputs = tfa.seq2seq.tile_batch(
-    encoder_outputs, multiplier=beam_width)
-tiled_encoder_final_state = tfa.seq2seq.tile_batch(
-    encoder_final_state, multiplier=beam_width)
-tiled_sequence_length = tfa.seq2seq.tile_batch(
-    sequence_length, multiplier=beam_width)
-attention_mechanism = MyFavoriteAttentionMechanism(
-    num_units=attention_depth,
-    memory=tiled_inputs,
-    memory_sequence_length=tiled_sequence_length)
-attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
-decoder_initial_state = attention_cell.get_initial_state(
-    batch_size=true_batch_size * beam_width, dtype=dtype)
-decoder_initial_state = decoder_initial_state.clone(
-    cell_state=tiled_encoder_final_state)
-```
+>>> batch_size = 1
+>>> beam_width = 5
+>>> sequence_length = tf.convert_to_tensor([5])
+>>> encoder_outputs = tf.random.uniform(shape=(batch_size, 5, 10))
+>>> encoder_final_state = [tf.zeros((batch_size, 10)), tf.zeros((batch_size, 10))]
+>>> tiled_encoder_outputs = tfa.seq2seq.tile_batch(encoder_outputs, multiplier=beam_width)
+>>> tiled_encoder_final_state = tfa.seq2seq.tile_batch(encoder_final_state, multiplier=beam_width)
+>>> tiled_sequence_length = tfa.seq2seq.tile_batch(sequence_length, multiplier=beam_width)
+>>> attention_mechanism = tfa.seq2seq.BahdanauAttention(10, memory=tiled_encoder_outputs, memory_sequence_length=tiled_sequence_length)
+>>> attention_cell = tfa.seq2seq.AttentionWrapper(tf.keras.layers.LSTMCell(10), attention_mechanism)
+>>> decoder_initial_state = attention_cell.get_initial_state(batch_size=batch_size * beam_width, dtype=tf.float32)
+>>> decoder_initial_state = decoder_initial_state.clone(cell_state=tiled_encoder_final_state)
 
 Args:
   cell: A layer that implements the `tf.keras.layers.AbstractRNNCell`
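The tiling in this doctest exists to replicate the encoder memory and final state `beam_width` times along the batch dimension, presumably so they can feed a beam search decoder such as `tfa.seq2seq.BeamSearchDecoder`. The sketch below (not part of the diff) completes that setup under stated assumptions: the vocabulary size, embedding matrix, start/end token ids, `Dense` output layer, and `maximum_iterations` cap are all illustrative choices, not values from the source:

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Rebuild the objects from the doctest above.
batch_size, beam_width = 1, 5
sequence_length = tf.convert_to_tensor([5])
encoder_outputs = tf.random.uniform(shape=(batch_size, 5, 10))
encoder_final_state = [tf.zeros((batch_size, 10)), tf.zeros((batch_size, 10))]
tiled_encoder_outputs = tfa.seq2seq.tile_batch(encoder_outputs, multiplier=beam_width)
tiled_encoder_final_state = tfa.seq2seq.tile_batch(encoder_final_state, multiplier=beam_width)
tiled_sequence_length = tfa.seq2seq.tile_batch(sequence_length, multiplier=beam_width)
attention_mechanism = tfa.seq2seq.BahdanauAttention(
    10, memory=tiled_encoder_outputs, memory_sequence_length=tiled_sequence_length)
attention_cell = tfa.seq2seq.AttentionWrapper(
    tf.keras.layers.LSTMCell(10), attention_mechanism)
decoder_initial_state = attention_cell.get_initial_state(
    batch_size=batch_size * beam_width, dtype=tf.float32)
decoder_initial_state = decoder_initial_state.clone(cell_state=tiled_encoder_final_state)

# Hypothetical vocabulary: token ids, embedding matrix, and projection layer.
vocab_size = 20
embeddings = tf.random.normal([vocab_size, 10])
decoder = tfa.seq2seq.BeamSearchDecoder(
    attention_cell,
    beam_width=beam_width,
    output_layer=tf.keras.layers.Dense(vocab_size),
    maximum_iterations=3)  # cap decoding; a randomly initialized model may never emit <end>
output, final_state, lengths = decoder(
    embeddings,
    start_tokens=tf.fill([batch_size], 1),  # assumed <start> id
    end_token=2,                            # assumed <end> id
    initial_state=decoder_initial_state)
print(output.predicted_ids.shape)  # [batch_size, decoded_length, beam_width]
```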