/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

/**
 * GPT2 Causal LM (Language Model).
 */

/* Original source: keras-nlp/models/gpt2/gpt2_causal_lm.py */
import { NamedTensorMap, Tensor, serialization } from '@tensorflow/tfjs-core';

import { GPT2Preprocessor } from './gpt2_preprocessor';
import { NotImplementedError } from '../../../../errors';
import { Layer } from '../../../../exports_layers';
import { LayerArgs } from '../../../../engine/topology';
import { Embedding } from '../../../../layers/embeddings';
import { Shape } from '../../../../keras_format/common';
import { GenerativeTask } from '../generative_task';
import { GPT2Backbone } from './gpt2_backbone';
import { PipelineModelArgs } from '../../utils';
import { Kwargs } from '../../../../types';

declare interface ReverseEmbeddingArgs extends LayerArgs {
  embedding: Embedding;
}

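/**
 * Reverses the `Embedding` layer it wraps.
 *
 * In the upstream keras-nlp implementation, `call()` multiplies its input by
 * the transpose of the wrapped embedding matrix, mapping hidden states of
 * shape `[batchSize, length, hiddenDim]` to vocabulary logits of shape
 * `[batchSize, length, vocabularySize]` (weight tying with the token
 * embedding). The methods below are stubs pending that port.
 */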
class ReverseEmbedding extends Layer {
  protected embedding: Embedding;

  constructor(args: ReverseEmbeddingArgs) {
    super(args);
    this.embedding = args.embedding;
  }

  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {
    throw new NotImplementedError();
  }

  override computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {
    throw new NotImplementedError();
  }
}

export declare interface GPT2CausalLMArgs extends PipelineModelArgs {
  /**
   * A `GPT2Backbone` instance.
   */
  backbone: GPT2Backbone;

  /**
   * Optional `GPT2CausalLMPreprocessor`.
   * If `null`, this model will not apply preprocessing, and inputs should be
   * preprocessed before calling the model.
   */
  preprocessor?: GPT2Preprocessor;
}

/**
 * An end-to-end GPT2 model for causal language modeling.
 *
 * A causal language model (LM) predicts the next token based on previous
 * tokens. This task setup can be used to train the model unsupervised on
 * plain text input, or to autoregressively generate plain text similar to
 * the data used for training. This task can be used for pre-training or
 * fine-tuning a GPT-2 model, simply by calling `fit()`.
 *
 * This model has a `generate()` method, which generates text based on a
 * prompt. The generation strategy used is controlled by an additional
 * `sampler` argument on `compile()`. By default, the top k results will be
 * returned.
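 *
 * For example, a hypothetical `compile()` call mirroring the keras-nlp
 * Python API (the exact options of this TypeScript port are not final):
 * ```js
 * gpt2LM.compile({sampler: 'topK'});
 * gpt2LM.generate('I want to say', {maxLength: 30});
 * ```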
 *
 * This model can optionally be configured with a `preprocessor` layer, in
 * which case it will automatically apply preprocessing to string inputs
 * during `fit()`, `predict()`, `evaluate()` and `generate()`. This is done
 * by default when creating the model with `fromPreset()`.
 *
 * Disclaimer: Pre-trained models are provided on an "as is" basis, without
 * warranties or conditions of any kind. The underlying model is provided by
 * a third party and subject to a separate license, available
 * [here](https://github.com/openai/gpt-2).
 *
 * Use `generate()` to do text generation.
 * ```js
 * const gpt2LM = GPT2CausalLM.fromPreset('gpt2_base_en');
 * gpt2LM.generate('I want to say', {maxLength: 30});
 * // Generate with batched prompts.
 * gpt2LM.generate(['This is a', 'Where are you'], {maxLength: 30});
 * ```
 *
 * Use `generate()` without preprocessing.
 * ```js
 * // Prompt the model with `5338, 318` (the token ids for `"Who is"`).
 * // Use `"paddingMask"` to indicate values that should not be overridden.
 * const prompt = {
 *   tokenIds: tf.tensor([[5338, 318, 0, 0, 0], [5338, 318, 0, 0, 0]]),
 *   paddingMask: tf.tensor([[1, 1, 0, 0, 0], [1, 1, 0, 0, 0]]),
 * };
 * const gpt2LM = GPT2CausalLM.fromPreset('gpt2_base_en', null);
 * gpt2LM.generate(prompt);
 * ```
 *
 * Call `fit()` on a single batch.
 * ```js
 * const features = ['The quick brown fox jumped.', 'I forgot my homework.'];
 * const gpt2LM = GPT2CausalLM.fromPreset('gpt2_base_en');
 * gpt2LM.fit(features, {batchSize: 2});
 * ```
 *
 * Call `fit()` without preprocessing.
 * ```js
 * const x = {
 *   tokenIds: tf.tensor([[50256, 1, 2, 3, 4], [50256, 1, 2, 3, 4]]),
 *   paddingMask: tf.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]),
 * };
 * const y = tf.tensor([[1, 2, 3, 4, 50256], [1, 2, 3, 4, 50256]]);
 * const sw = tf.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]);
 * const gpt2LM = GPT2CausalLM.fromPreset('gpt2_base_en', null);
 * gpt2LM.fit(x, y, {sampleWeight: sw, batchSize: 2});
 * ```
 *
 * Custom backbone and vocabulary.
 * ```js
 * const features = ['a quick fox.', 'a fox quick.'];
 * const vocab = {'<|endoftext|>': 0, 'a': 4, 'Ġquick': 5, 'Ġfox': 6};
 * const merges = [
 *   'Ġ q', 'u i', 'c k', 'ui ck', 'Ġq uick', 'Ġ f', 'o x', 'Ġf ox'
 * ];
 * const tokenizer = new GPT2Tokenizer({vocabulary: vocab, merges});
 * const preprocessor = new GPT2CausalLMPreprocessor({
 *   tokenizer,
 *   sequenceLength: 128,
 * });
 * const backbone = new GPT2Backbone({
 *   vocabularySize: 30552,
 *   numLayers: 4,
 *   numHeads: 4,
 *   hiddenDim: 256,
 *   intermediateDim: 512,
 *   maxSequenceLength: 128,
 * });
 * const gpt2LM = new GPT2CausalLM({backbone, preprocessor});
 * gpt2LM.fit(features, {batchSize: 2});
 * ```
 */
export class GPT2CausalLM extends GenerativeTask {
  /** @nocollapse */
  static override className = 'GPT2CausalLM';

  constructor(args: GPT2CausalLMArgs) {
    super(args);
    throw new NotImplementedError(`Uses ${ReverseEmbedding}.`);
  }

  static override presets<T extends serialization.Serializable>(
      cls: serialization.SerializableConstructor<T>
  ): {} {
    throw new NotImplementedError();
  }

  /**
   * Forward pass of `GPT2CausalLM` with cache.
   *
   * `callWithCache` adds an additional forward pass to the model for
   * autoregressive inference. Unlike calling the model directly, this method
   * caches previous key/value tensors in the multi-head attention layers,
   * and avoids recomputing the outputs of seen tokens.
   *
   * @param tokenIds a dense int Tensor with shape `[batchSize, maxLength]`.
   * @param cache a dense float Tensor, the cache of key and value.
   * @param cacheUpdateIndex Integer. The index of current inputs in the whole
   *     sequence.
   * @returns `[logits, hiddenStates, cache]`, where `logits` is the
   *     language model logits for the input `tokenIds`, `hiddenStates` is
   *     the final hidden representation of the input tokens, and `cache` is
   *     the decoding cache.
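   *
   * A sketch of the intended decoding loop (illustrative only; the method is
   * not yet implemented here, and the variable names are hypothetical):
   * ```js
   * // Seed the cache with the full prompt, then feed one new token at a
   * // time, reusing the cached keys/values for everything already seen.
   * let [logits, hiddenStates, cache] =
   *     gpt2LM.callWithCache(promptIds, emptyCache, 0);
   * [logits, hiddenStates, cache] =
   *     gpt2LM.callWithCache(nextTokenId, cache, promptLength);
   * ```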
   */
  callWithCache(
      tokenIds: Tensor,
      cache: Tensor,
      cacheUpdateIndex: number
  ): [Tensor, Tensor, Tensor] {
    throw new NotImplementedError();
  }

  /**
   * Build an empty cache for use with `callWithCache()`.
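   *
   * In the upstream keras-nlp implementation, the cache is a single tensor
   * of shape `[batchSize, numLayers, 2, maxLength, numHeads, headDim]`
   * holding the key and value projections of every decoder layer, returned
   * alongside the hidden states of the prompt.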
   */
  private buildCache(tokenIds: Tensor): [Tensor, Tensor] {
    throw new NotImplementedError();
  }

  /**
   * A compilable generation function for a single batch of inputs.
   *
   * This function represents the inner generation function for a single
   * batch of inputs.
   *
   * @param inputs An object with two keys `tokenIds` and `paddingMask` and
   *     batched tensor values.
   * @param endTokenId The id of the end token to stop on. If all
   *     sequences have produced a new `endTokenId`, generation will stop.
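   *
   * A hypothetical invocation, reusing the `prompt` tensors from the
   * `generate()` example in the class docs (`50256` is GPT-2's
   * `<|endoftext|>` token id):
   * ```js
   * const output = gpt2LM.generateStep(
   *     {tokenIds: prompt.tokenIds, paddingMask: prompt.paddingMask}, 50256);
   * ```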
   */
  override generateStep(
      inputs: NamedTensorMap,
      endTokenId: number
  ): NamedTensorMap {
    throw new NotImplementedError(`Uses ${this.buildCache}`);
  }
}

serialization.registerClass(GPT2CausalLM);