@@ -1390,11 +1390,13 @@ def clone(self, **kwargs):

         Example:

-        ```python
-        initial_state = attention_wrapper.get_initial_state(
-            batch_size=..., dtype=...)
-        initial_state = initial_state.clone(cell_state=encoder_state)
-        ```
+        >>> batch_size = 1
+        >>> memory = tf.random.normal(shape=[batch_size, 3, 100])
+        >>> encoder_state = [tf.zeros((batch_size, 100)), tf.zeros((batch_size, 100))]
+        >>> attention_mechanism = tfa.seq2seq.LuongAttention(100, memory=memory, memory_sequence_length=[3] * batch_size)
+        >>> attention_cell = tfa.seq2seq.AttentionWrapper(tf.keras.layers.LSTMCell(100), attention_mechanism, attention_layer_size=10)
+        >>> decoder_initial_state = attention_cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
+        >>> decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state)

         Args:
           **kwargs: Any properties of the state object to replace in the
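As a side note on the new doctest above, here is a minimal sketch of how the cloned state can then drive the wrapped cell. It assumes the same toy setup as the doctest; the step input and its size are arbitrary placeholders and are not part of the commit.

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Same toy setup as the doctest above.
batch_size = 1
memory = tf.random.normal(shape=[batch_size, 3, 100])
encoder_state = [tf.zeros((batch_size, 100)), tf.zeros((batch_size, 100))]
attention_mechanism = tfa.seq2seq.LuongAttention(
    100, memory=memory, memory_sequence_length=[3] * batch_size)
attention_cell = tfa.seq2seq.AttentionWrapper(
    tf.keras.layers.LSTMCell(100), attention_mechanism, attention_layer_size=10)
decoder_initial_state = attention_cell.get_initial_state(
    batch_size=batch_size, dtype=tf.float32)
decoder_initial_state = decoder_initial_state.clone(cell_state=encoder_state)

# One decoding step with the cloned state; the step input is an arbitrary
# embedding chosen only for illustration.
step_input = tf.random.normal([batch_size, 100])
step_output, next_state = attention_cell(step_input, decoder_initial_state)
print(step_output.shape)  # (1, 10), since attention_layer_size=10
```

Only `cell_state` is swapped for the encoder's final state; the attention, alignment, and time fields produced by `get_initial_state` are kept unchanged by `clone`.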
@@ -1611,23 +1613,18 @@ def __init__(

         An example:

-        ```
-        tiled_encoder_outputs = tfa.seq2seq.tile_batch(
-            encoder_outputs, multiplier=beam_width)
-        tiled_encoder_final_state = tfa.seq2seq.tile_batch(
-            encoder_final_state, multiplier=beam_width)
-        tiled_sequence_length = tfa.seq2seq.tile_batch(
-            sequence_length, multiplier=beam_width)
-        attention_mechanism = MyFavoriteAttentionMechanism(
-            num_units=attention_depth,
-            memory=tiled_inputs,
-            memory_sequence_length=tiled_sequence_length)
-        attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
-        decoder_initial_state = attention_cell.get_initial_state(
-            batch_size=true_batch_size * beam_width, dtype=dtype)
-        decoder_initial_state = decoder_initial_state.clone(
-            cell_state=tiled_encoder_final_state)
-        ```
+        >>> batch_size = 1
+        >>> beam_width = 5
+        >>> sequence_length = tf.convert_to_tensor([5])
+        >>> encoder_outputs = tf.random.uniform(shape=(batch_size, 5, 10))
+        >>> encoder_final_state = [tf.zeros((batch_size, 10)), tf.zeros((batch_size, 10))]
+        >>> tiled_encoder_outputs = tfa.seq2seq.tile_batch(encoder_outputs, multiplier=beam_width)
+        >>> tiled_encoder_final_state = tfa.seq2seq.tile_batch(encoder_final_state, multiplier=beam_width)
+        >>> tiled_sequence_length = tfa.seq2seq.tile_batch(sequence_length, multiplier=beam_width)
+        >>> attention_mechanism = tfa.seq2seq.BahdanauAttention(10, memory=tiled_encoder_outputs, memory_sequence_length=tiled_sequence_length)
+        >>> attention_cell = tfa.seq2seq.AttentionWrapper(tf.keras.layers.LSTMCell(10), attention_mechanism)
+        >>> decoder_initial_state = attention_cell.get_initial_state(batch_size=batch_size * beam_width, dtype=tf.float32)
+        >>> decoder_initial_state = decoder_initial_state.clone(cell_state=tiled_encoder_final_state)

         Args:
           cell: A layer that implements the `tf.keras.layers.AbstractRNNCell`
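For context on this second doctest, a sketch of what the tiled tensors look like and of a single step run over the tiled batch. The setup mirrors the doctest; the step input is an illustrative assumption and not part of the commit.

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Same toy setup as the doctest above.
batch_size, beam_width = 1, 5
sequence_length = tf.convert_to_tensor([5])
encoder_outputs = tf.random.uniform(shape=(batch_size, 5, 10))
encoder_final_state = [tf.zeros((batch_size, 10)), tf.zeros((batch_size, 10))]

# Tiling repeats every batch entry beam_width times along the batch axis.
tiled_encoder_outputs = tfa.seq2seq.tile_batch(encoder_outputs, multiplier=beam_width)
tiled_encoder_final_state = tfa.seq2seq.tile_batch(encoder_final_state, multiplier=beam_width)
tiled_sequence_length = tfa.seq2seq.tile_batch(sequence_length, multiplier=beam_width)
print(tiled_encoder_outputs.shape)  # (5, 5, 10): batch axis grows to batch_size * beam_width

attention_mechanism = tfa.seq2seq.BahdanauAttention(
    10, memory=tiled_encoder_outputs, memory_sequence_length=tiled_sequence_length)
attention_cell = tfa.seq2seq.AttentionWrapper(
    tf.keras.layers.LSTMCell(10), attention_mechanism)
decoder_initial_state = attention_cell.get_initial_state(
    batch_size=batch_size * beam_width, dtype=tf.float32)
decoder_initial_state = decoder_initial_state.clone(cell_state=tiled_encoder_final_state)

# One decoding step over the tiled batch; the step input is an arbitrary embedding here.
step_input = tf.random.normal([batch_size * beam_width, 10])
step_output, next_state = attention_cell(step_input, decoder_initial_state)
print(step_output.shape)  # (5, 10): one output per beam entry
```

In a full pipeline this cloned, tiled state would typically serve as the initial state of a `tfa.seq2seq.BeamSearchDecoder` built around the same `attention_cell`.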