|
| 1 | +/** |
| 2 | + * @license |
| 3 | + * Copyright 2023 Google LLC. |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + * ============================================================================= |
| 16 | + */ |
| 17 | + |
| 18 | +/** |
| 19 | + * GPT-2 backbone model.
| 20 | + */ |
| 21 | + |
| 22 | +/* Original source: keras_nlp/models/gpt2/gpt2_backbone.py */ |
| 23 | +import { serialization } from '@tensorflow/tfjs-core'; |
| 24 | + |
| 25 | +import { RandomNormal } from '../../../../initializers'; |
| 26 | +import { input } from '../../../../exports'; |
| 27 | +import { Embedding } from '../../../embeddings'; |
| 28 | +import { SymbolicTensor } from '../../../../engine/topology'; |
| 29 | +import { PositionEmbedding } from '../../modeling/position_embedding'; |
| 30 | +import { add } from '../../../../exports_layers'; |
| 31 | +import { Dropout } from '../../../core'; |
| 32 | +import { TransformerDecoder } from '../../modeling/transformer_decoder'; |
| 33 | +import { getActivation } from '../../../../activations'; |
| 34 | +import { LayerNormalization } from '../../../normalization'; |
| 35 | +import { Backbone } from '../backbone'; |
| 36 | + |
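| | +/** RandomNormal kernel initializer with the given stddev, as used by GPT-2. */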
| 37 | +function gpt2KernelInitializer(stddev = 0.02) { |
| 38 | + return new RandomNormal({stddev}); |
| 39 | +} |
| 40 | + |
| 41 | +export interface GPT2BackboneArgs { |
| 42 | + /** |
| 43 | + * Integer. The size of the token vocabulary. |
| 44 | + */ |
| 45 | + vocabularySize: number; |
| 46 | + |
| 47 | + /** |
| 48 | + * Integer. The number of transformer layers. |
| 49 | + */ |
| 50 | + numLayers: number; |
| 51 | + |
| 52 | + /** |
| 53 | + * Integer. The number of attention heads for each transformer. |
| 54 | + * The hidden size must be divisible by the number of attention heads. |
| 55 | + */ |
| 56 | + numHeads: number; |
| 57 | + |
| 58 | + /** |
| 59 | + * Integer. The size of the transformer hidden layers.
| 60 | + */ |
| 61 | + hiddenDim: number; |
| 62 | + |
| 63 | + /** |
| 64 | + * Integer. The output dimension of the first Dense layer in a two-layer |
| 65 | + * feedforward network for each transformer. |
| 66 | + */ |
| 67 | + intermediateDim: number; |
| 68 | + |
| 69 | + /** |
| 70 | + * Float. Dropout probability for the Transformer decoder.
| 71 | + * Defaults to 0.1.
| 72 | + */ |
| 73 | + dropout?: number; |
| 74 | + |
| 75 | + /** |
| 76 | + * Integer. The maximum sequence length that this decoder can consume.
| 77 | + * If `null`, `maxSequenceLength` uses the value from the input sequence length.
| 78 | + * This determines the variable shape for positional embeddings. |
| 79 | + * Defaults to 1024. |
| 80 | + */ |
| 81 | + maxSequenceLength?: number; |
| 82 | +} |
| 83 | + |
| 84 | +/** |
| 85 | + * GPT-2 core network with hyperparameters. |
| 86 | + * |
| 87 | + * This network implements a Transformer-based decoder network, |
| 88 | + * Generative Pretrained Transformer-2 (GPT-2), as described in |
| 89 | + * ["Language Models are Unsupervised Multitask Learners"](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf). |
| 90 | + * It includes the embedding lookups and transformer layers. |
| 91 | + * |
| 92 | + * The default constructor gives a fully customizable, randomly initialized |
| 93 | + * GPT-2 model with any number of layers, heads, and embedding |
| 94 | + * dimensions. To load preset architectures and weights, use the `fromPreset` |
| 95 | + * constructor. |
| 96 | + * |
| 97 | + * Disclaimer: Pre-trained models are provided on an "as is" basis, without |
| 98 | + * warranties or conditions of any kind. The underlying model is provided by a |
| 99 | + * third party and subject to a separate license, available |
| 100 | + * [here](https://github.com/openai/gpt-2). |
| 101 | + * |
| 102 | + * |
| 103 | + * Example usage: |
| 104 | + * ```js |
| 105 | + * const tokenIds = tf.ones([1, 12], 'int32');
| 106 | + * const paddingMask = tf.tensor(
| 107 | + *     [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]], [1, 12], 'int32');
| 108 | + *
| 109 | + * // Pretrained GPT-2 decoder.
| 110 | + * let model = GPT2Backbone.fromPreset('gpt2_base_en');
| 111 | + * model.apply([tokenIds, paddingMask]);
| 112 | + *
| 113 | + * // Randomly initialized GPT-2 decoder with custom config.
| 114 | + * model = new GPT2Backbone({
| 115 | + *   vocabularySize: 50257,
| 116 | + *   numLayers: 12,
| 117 | + *   numHeads: 12,
| 118 | + *   hiddenDim: 768,
| 119 | + *   intermediateDim: 3072,
| 120 | + *   maxSequenceLength: 1024,
| 121 | + * });
| 122 | + * model.apply([tokenIds, paddingMask]);
| 123 | + * ``` |
| 124 | + */ |
| 125 | +export class GPT2Backbone extends Backbone { |
| 126 | + /** @nocollapse */ |
| 127 | + static override className = 'GPT2Backbone'; |
| 128 | + |
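| | + // Hyperparameters retained so getConfig() can serialize this backbone.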
| 129 | + private vocabularySize: number; |
| 130 | + private numLayers: number; |
| 131 | + private numHeads: number; |
| 132 | + private hiddenDim: number; |
| 133 | + private intermediateDim: number; |
| 134 | + private dropout: number; |
| 135 | + private maxSequenceLength: number; |
| 136 | + |
| 137 | + constructor(args: GPT2BackboneArgs) { |
| 138 | + args.dropout = args.dropout ?? 0.1; |
| 139 | + args.maxSequenceLength = args.maxSequenceLength ?? 1024; |
| 140 | + |
| 141 | + // Inputs |
| 142 | + const tokenIds = input({shape: [null], dtype: 'int32', name: 'token_ids'}); |
| 143 | + const paddingMask = |
| 144 | + input({shape: [null], dtype: 'int32', name: 'padding_mask'}); |
| 145 | + |
| 146 | + // Embed tokens, positions. |
| 147 | + const tokenEmbedding = new Embedding({ |
| 148 | + inputDim: args.vocabularySize, |
| 149 | + outputDim: args.hiddenDim, |
| 150 | + embeddingsInitializer: gpt2KernelInitializer(0.01), |
| 151 | + name: 'token_embedding', |
| 152 | + }).apply(tokenIds) as SymbolicTensor; |
| 153 | + |
| 154 | + const positionEmbedding = new PositionEmbedding({ |
| 155 | + initializer: gpt2KernelInitializer(0.02), |
| 156 | + sequenceLength: args.maxSequenceLength, |
| 157 | + name: 'position_embedding', |
| 158 | + }).apply(tokenEmbedding) as SymbolicTensor; |
| 159 | + |
| 160 | + // Sum and apply dropout to embeddings. |
| 161 | + let x = add({name: 'embeddings_add'}) |
| 162 | + .apply([tokenEmbedding, positionEmbedding]) as SymbolicTensor; |
| 163 | + x = new Dropout({rate: args.dropout, name: 'embeddings_dropout'}) |
| 164 | + .apply(x) as SymbolicTensor; |
| 165 | + |
| 166 | + // Apply successive transformer decoder blocks. |
| 167 | + for (let i = 0; i < args.numLayers; i++) {
| 168 | + x = new TransformerDecoder({ |
| 169 | + intermediateDim: args.intermediateDim, |
| 170 | + numHeads: args.numHeads, |
| 171 | + dropout: args.dropout, |
| 172 | + layerNormEpsilon: 1e-05, |
| 173 | + // TODO(pforderique): Implement gelu. |
| 174 | + activation: getActivation('relu'), |
| 175 | + kernelInitializer: gpt2KernelInitializer(0.02), |
| 176 | + normalizeFirst: true, |
| 177 | + name: `transformer_layer_${i}`, |
| 178 | + }).apply(x, {decoderPaddingMask: paddingMask}) as SymbolicTensor; |
| 179 | + } |
| 180 | + |
| 181 | + const sequenceOutput = new LayerNormalization({ |
| 182 | + name: 'layer_norm', |
| 183 | + axis: -1, |
| 184 | + epsilon: 1e-05, |
| 185 | + dtype: 'float32', |
| 186 | + }).apply(x) as SymbolicTensor; |
| 187 | + |
| 188 | + // Instantiate using Functional API Model constructor. |
| 189 | + super({ |
| 190 | + inputs: [tokenIds, paddingMask], |
| 191 | + outputs: sequenceOutput, |
| 192 | + name: 'gpt2_backbone' |
| 193 | + }); |
| 194 | + this.vocabularySize = args.vocabularySize; |
| 195 | + this.numLayers = args.numLayers; |
| 196 | + this.numHeads = args.numHeads; |
| 197 | + this.hiddenDim = args.hiddenDim; |
| 198 | + this.intermediateDim = args.intermediateDim; |
| 199 | + this.dropout = args.dropout ?? 0.1; |
| 200 | + this.maxSequenceLength = args.maxSequenceLength ?? 1024; |
| 201 | + } |
| 202 | + |
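| | + /** Returns the hyperparameters needed to re-instantiate this backbone. */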
| 203 | + override getConfig(): serialization.ConfigDict { |
| 204 | + const config: serialization.ConfigDict = { |
| 205 | + vocabularySize: this.vocabularySize, |
| 206 | + numLayers: this.numLayers, |
| 207 | + numHeads: this.numHeads, |
| 208 | + hiddenDim: this.hiddenDim, |
| 209 | + intermediateDim: this.intermediateDim, |
| 210 | + dropout: this.dropout, |
| 211 | + maxSequenceLength: this.maxSequenceLength, |
| 212 | + }; |
| 213 | + const baseConfig = super.getConfig(); |
| 214 | + Object.assign(config, baseConfig); |
| 215 | + return config; |
| 216 | + } |
| 217 | + |
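| | + /** The embedding layer that maps token ids to vectors of size `hiddenDim`. */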
| 218 | + override get tokenEmbedding() { |
| 219 | + return this.getLayer('token_embedding'); |
| 220 | + } |
| 221 | +} |
| 222 | +serialization.registerClass(GPT2Backbone); |
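
The snippet below is a minimal usage sketch of the backbone added in this diff, written against the public `LayersModel` API (`predict`). The relative import path, the CPU backend import, and the tiny hyperparameters are illustrative assumptions, not part of the change.

```ts
// Minimal smoke test for GPT2Backbone (illustrative only; adjust the import
// path to wherever GPT2Backbone is exported in your build).
import * as tf from '@tensorflow/tfjs-core';
import '@tensorflow/tfjs-backend-cpu';
import { GPT2Backbone } from './gpt2_backbone';

// Small, randomly initialized backbone so the test runs quickly.
const backbone = new GPT2Backbone({
  vocabularySize: 1000,
  numLayers: 2,
  numHeads: 2,
  hiddenDim: 64,
  intermediateDim: 128,
  maxSequenceLength: 128,
});

// One sequence of length 12 whose last two positions are padding.
const tokenIds = tf.ones([1, 12], 'int32');
const paddingMask = tf.tensor(
    [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]], [1, 12], 'int32');

// The functional model has two inputs, so they are passed as an array.
const hiddenStates = backbone.predict([tokenIds, paddingMask]) as tf.Tensor;
console.log(hiddenStates.shape);  // Expected: [1, 12, 64].
```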