Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 68 additions & 52 deletions examples/generative/ipynb/text_generation_fnet.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"\n",
"**Author:** [Darshan Deshpande](https://twitter.com/getdarshan)<br>\n",
"**Date created:** 2021/10/05<br>\n",
"**Last modified:** 2021/10/05<br>\n",
"**Last modified:** 2026/03/18<br>\n",
"**Description:** FNet transformer for text generation in Keras."
]
},
Expand Down Expand Up @@ -54,11 +54,13 @@
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"import os\n",
"\n",
"os.environ[\"KERAS_BACKEND\"] = \"tensorflow\" # or \"jax\" , \"torch\"\n",
"import keras\n",
"from keras import layers, ops\n",
"import tensorflow as tf\n",
"\n",
"# Defining hyperparameters\n",
"\n",
"VOCAB_SIZE = 8192\n",
Expand Down Expand Up @@ -98,7 +100,9 @@
")\n",
"\n",
"path_to_dataset = os.path.join(\n",
" os.path.dirname(path_to_zip), \"cornell movie-dialogs corpus\"\n",
" os.path.dirname(path_to_zip),\n",
" \"cornell_movie_dialogs_extracted\",\n",
" \"cornell movie-dialogs corpus\",\n",
")\n",
"path_to_movie_lines = os.path.join(path_to_dataset, \"movie_lines.txt\")\n",
"path_to_movie_conversations = os.path.join(path_to_dataset, \"movie_conversations.txt\")\n",
Expand Down Expand Up @@ -153,6 +157,7 @@
},
"outputs": [],
"source": [
"\n",
"def preprocess_text(sentence):\n",
" sentence = tf.strings.lower(sentence)\n",
" # Adding a space between the punctuation and the last word to allow better tokenization\n",
Expand Down Expand Up @@ -195,13 +200,14 @@
},
"outputs": [],
"source": [
"\n",
"def vectorize_text(inputs, outputs):\n",
" inputs, outputs = vectorizer(inputs), vectorizer(outputs)\n",
" # One extra padding token to the right to match the output shape\n",
" outputs = tf.pad(outputs, [[0, 1]])\n",
" return (\n",
" {\"encoder_inputs\": inputs, \"decoder_inputs\": outputs[:-1]},\n",
" {\"outputs\": outputs[1:]},\n",
" outputs[1:],\n",
" )\n",
"\n",
"\n",
Expand Down Expand Up @@ -245,6 +251,7 @@
},
"outputs": [],
"source": [
"\n",
"class FNetEncoder(layers.Layer):\n",
" def __init__(self, embed_dim, dense_dim, **kwargs):\n",
" super().__init__(**kwargs)\n",
Expand All @@ -260,14 +267,16 @@
" self.layernorm_2 = layers.LayerNormalization()\n",
"\n",
" def call(self, inputs):\n",
" # Casting the inputs to complex64\n",
" inp_complex = tf.cast(inputs, tf.complex64)\n",
" # Projecting the inputs to the frequency domain using FFT2D and\n",
" # extracting the real part of the output\n",
" fft = tf.math.real(tf.signal.fft2d(inp_complex))\n",
" proj_input = self.layernorm_1(inputs + fft)\n",
" # Cast inputs to float32 and create imaginary component\n",
" inp_real = ops.cast(inputs, \"float32\")\n",
" inp_imag = ops.zeros_like(inp_real)\n",
"\n",
" # Apply 2D FFT - returns tuple of (real, imaginary)\n",
" fft_real, fft_imag = ops.fft((inp_real, inp_imag))\n",
" # Use only the real component\n",
" proj_input = self.layernorm_1(inputs + fft_real)\n",
" proj_output = self.dense_proj(proj_input)\n",
" return self.layernorm_2(proj_input + proj_output)"
" return self.layernorm_2(proj_input + proj_output)\n"
]
},
{
Expand All @@ -293,6 +302,7 @@
},
"outputs": [],
"source": [
"\n",
"class PositionalEmbedding(layers.Layer):\n",
" def __init__(self, sequence_length, vocab_size, embed_dim, **kwargs):\n",
" super().__init__(**kwargs)\n",
Expand All @@ -307,14 +317,14 @@
" self.embed_dim = embed_dim\n",
"\n",
" def call(self, inputs):\n",
" length = tf.shape(inputs)[-1]\n",
" positions = tf.range(start=0, limit=length, delta=1)\n",
" length = ops.shape(inputs)[-1]\n",
" positions = ops.arange(0, length, 1)\n",
" embedded_tokens = self.token_embeddings(inputs)\n",
" embedded_positions = self.position_embeddings(positions)\n",
" return embedded_tokens + embedded_positions\n",
"\n",
" def compute_mask(self, inputs, mask=None):\n",
" return tf.math.not_equal(inputs, 0)\n",
" return ops.not_equal(inputs, 0)\n",
"\n",
"\n",
"class FNetDecoder(layers.Layer):\n",
Expand Down Expand Up @@ -342,9 +352,11 @@
"\n",
" def call(self, inputs, encoder_outputs, mask=None):\n",
" causal_mask = self.get_causal_attention_mask(inputs)\n",
"\n",
" if mask is not None:\n",
" padding_mask = tf.cast(mask[:, tf.newaxis, :], dtype=\"int32\")\n",
" padding_mask = tf.minimum(padding_mask, causal_mask)\n",
" padding_mask = ops.cast(mask[:, None, :], \"int32\")\n",
" else:\n",
" padding_mask = None\n",
"\n",
" attention_output_1 = self.attention_1(\n",
" query=inputs, value=inputs, key=inputs, attention_mask=causal_mask\n",
Expand All @@ -363,17 +375,14 @@
" return self.layernorm_3(out_2 + proj_output)\n",
"\n",
" def get_causal_attention_mask(self, inputs):\n",
" input_shape = tf.shape(inputs)\n",
" input_shape = ops.shape(inputs)\n",
" batch_size, sequence_length = input_shape[0], input_shape[1]\n",
" i = tf.range(sequence_length)[:, tf.newaxis]\n",
" j = tf.range(sequence_length)\n",
" mask = tf.cast(i >= j, dtype=\"int32\")\n",
" mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))\n",
" mult = tf.concat(\n",
" [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],\n",
" axis=0,\n",
" )\n",
" return tf.tile(mask, mult)\n",
" i = ops.arange(sequence_length)[:, None]\n",
" j = ops.arange(sequence_length)\n",
" mask = ops.cast(i >= j, dtype=\"int32\")\n",
" mask = ops.reshape(mask, (1, input_shape[1], input_shape[1]))\n",
" multiples = [batch_size, 1, 1]\n",
" return ops.tile(mask, multiples)\n",
"\n",
"\n",
"def create_model():\n",
Expand All @@ -394,7 +403,7 @@
" )\n",
" decoder_outputs = decoder([decoder_inputs, encoder_outputs])\n",
" fnet = keras.Model([encoder_inputs, decoder_inputs], decoder_outputs, name=\"fnet\")\n",
" return fnet"
" return fnet\n"
]
},
{
Expand Down Expand Up @@ -424,9 +433,9 @@
"colab_type": "text"
},
"source": [
"Here, the `epochs` parameter is set to a single epoch, but in practice the model will take around\n",
"**20-30 epochs** of training to start outputting comprehensible sentences. Although accuracy\n",
"is not a good measure for this task, we will use it just to get a hint of the improvement\n",
"The model as configured here uses a simplified architecture to keep training time manageable for a tutorial. The text generation quality\n",
"will be limited - outputs may be generic.\n",
"Although accuracy is not a good measure for this task, we will use it just to get a hint of the improvement\n",
"of the network."
]
},
Expand Down Expand Up @@ -464,38 +473,45 @@
"def decode_sentence(input_sentence):\n",
" # Mapping the input sentence to tokens and adding start and end tokens\n",
" tokenized_input_sentence = vectorizer(\n",
" tf.constant(\"[start] \" + preprocess_text(input_sentence) + \" [end]\")\n",
" \"[start] \" + preprocess_text(input_sentence) + \" [end]\"\n",
" )\n",
" # Initializing the initial sentence consisting of only the start token.\n",
" tokenized_target_sentence = tf.expand_dims(VOCAB.index(\"[start]\"), 0)\n",
" decoded_sentence = \"\"\n",
"\n",
" # Start token\n",
" start_token_index = VOCAB.index(\"[start]\")\n",
" end_token_index = VOCAB.index(\"[end]\")\n",
"\n",
" tokenized_target_sentence = ops.expand_dims(start_token_index, axis=0)\n",
" decoded_sentence = []\n",
"\n",
" for i in range(MAX_LENGTH):\n",
" # Get the predictions\n",
" predictions = fnet.predict(\n",
" {\n",
" \"encoder_inputs\": tf.expand_dims(tokenized_input_sentence, 0),\n",
" \"decoder_inputs\": tf.expand_dims(\n",
" tf.pad(\n",
" \"encoder_inputs\": ops.expand_dims(tokenized_input_sentence, axis=0),\n",
" \"decoder_inputs\": ops.expand_dims(\n",
" ops.pad(\n",
" tokenized_target_sentence,\n",
" [[0, MAX_LENGTH - tf.shape(tokenized_target_sentence)[0]]],\n",
" [[0, MAX_LENGTH - ops.shape(tokenized_target_sentence)[0]]],\n",
" ),\n",
" 0,\n",
" axis=0,\n",
" ),\n",
" }\n",
" )\n",
" # Calculating the token with maximum probability and getting the corresponding word\n",
" sampled_token_index = tf.argmax(predictions[0, i, :])\n",
" sampled_token = VOCAB[sampled_token_index.numpy()]\n",
" # If sampled token is the end token then stop generating and return the sentence\n",
" if tf.equal(sampled_token_index, VOCAB.index(\"[end]\")):\n",
" sampled_token_index = ops.argmax(predictions[0, i, :])\n",
" sampled_token_index = int(sampled_token_index)\n",
"\n",
" if sampled_token_index == end_token_index:\n",
" break\n",
" decoded_sentence += sampled_token + \" \"\n",
" tokenized_target_sentence = tf.concat(\n",
" [tokenized_target_sentence, [sampled_token_index]], 0\n",
"\n",
" decoded_sentence.append(VOCAB[sampled_token_index])\n",
"\n",
" tokenized_target_sentence = ops.concatenate(\n",
" [tokenized_target_sentence, ops.expand_dims(sampled_token_index, axis=0)],\n",
" axis=0,\n",
" )\n",
"\n",
" return decoded_sentence\n",
" return \" \".join(decoded_sentence)\n",
"\n",
"\n",
"decode_sentence(\"Where have you been all this time?\")"
Expand All @@ -517,7 +533,7 @@
"2. [Attention Is All You Need](https://arxiv.org/abs/1706.03762v5) (Vaswani et al.,\n",
"2017)\n",
"\n",
"Thanks to Fran\u00e7ois Chollet for his Keras example on\n",
"Thanks to François Chollet for his Keras example on\n",
"[English-to-Spanish translation with a sequence-to-sequence Transformer](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/)\n",
"from which the decoder implementation was extracted."
]
Expand Down Expand Up @@ -558,9 +574,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
"version": "3.11.14"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
}
Loading
Loading