
Commit 0cd53ba

pforderique, Linchenn, mattsoulanille, and fengwuyao authored
Implement GPT2 Backbone (#7894)
* Add spec for multi-head attention
* Add CachedMultiHeadAttention cache
* Fix typos
* Lint
* Add Transformer Decoder spec
* lint
* Add Einsum spec
* lint
* Remove unused type declaration
* Move helper functions outside EinsumDense class
* Implement Einsum Dense
* Address comments
* Implement MHA Layer
* Add masked softmax support
* Add transformer utils
* Add CMHA impl and tests
* lint
* Fix typo
* Fix typo
* Check for undef and null
* Implement TransformerDecoder
* Add Transformer Decoder tests
* Make buildFromSignature public
* lint
* Strip debug ops in jax conversion tests (#7889)
  INTERNAL: This fixes an internal issue with jax tests. See cl/550054296.
* Add gpt2backbone
* lint
* Fix return type for tokenEmbedding
* Break up for loop
* lint
* Add classnames
* Don't need presets

---------

Co-authored-by: Linchenn <[email protected]>
Co-authored-by: Matthew Soulanille <[email protected]>
Co-authored-by: fengwuyao <[email protected]>
1 parent 405c453 commit 0cd53ba

File tree

3 files changed: +306 -1 lines changed

  tfjs-layers/src/layers/nlp/models/backbone.ts
  tfjs-layers/src/layers/nlp/models/gpt2/gpt2_backbone.ts
  tfjs-layers/src/layers/nlp/models/gpt2/gpt2_backbone_test.ts


tfjs-layers/src/layers/nlp/models/backbone.ts

+4-1
@@ -25,8 +25,11 @@ import { serialization } from '@tensorflow/tfjs-core';
 import { ContainerArgs } from '../../../engine/container';
 import { LayersModel } from '../../../engine/training';
 import { NotImplementedError } from '../../../errors';
+import { Layer } from '../../../exports_layers';

 export class Backbone extends LayersModel {
+  /** @nocollapse */
+  static override className = 'Backbone';

   constructor(args: ContainerArgs) {
     super(args);

@@ -35,7 +38,7 @@ export class Backbone extends LayersModel {
   /**
    * A `tf.layers.embedding` instance for embedding token ids.
    */
-  get tokenEmbedding() {
+  get tokenEmbedding(): Layer {
     throw new NotImplementedError();
   }

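With the getter now typed as `Layer`, a concrete backbone can expose its embedding layer directly. A minimal sketch of an override, assuming a hypothetical subclass whose functional graph contains a layer named `'token_embedding'` (the same lookup GPT2Backbone performs below):

```ts
import { Layer } from '../../../exports_layers';
import { Backbone } from './backbone';

// Hypothetical subclass: assumes its graph was built with a layer
// constructed as {name: 'token_embedding'}.
class MyBackbone extends Backbone {
  override get tokenEmbedding(): Layer {
    // LayersModel.getLayer retrieves a layer by its unique name.
    return this.getLayer('token_embedding');
  }
}
```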
tfjs-layers/src/layers/nlp/models/gpt2/gpt2_backbone.ts

+222-0

@@ -0,0 +1,222 @@
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

/**
 * GPT-2 backbone model.
 */

/* Original source: keras_nlp/models/gpt2/gpt2_backbone.py */
import { serialization } from '@tensorflow/tfjs-core';

import { RandomNormal } from '../../../../initializers';
import { input } from '../../../../exports';
import { Embedding } from '../../../embeddings';
import { SymbolicTensor } from '../../../../engine/topology';
import { PositionEmbedding } from '../../modeling/position_embedding';
import { add } from '../../../../exports_layers';
import { Dropout } from '../../../core';
import { TransformerDecoder } from '../../modeling/transformer_decoder';
import { getActivation } from '../../../../activations';
import { LayerNormalization } from '../../../normalization';
import { Backbone } from '../backbone';

function gpt2KernelInitializer(stddev = 0.02) {
  return new RandomNormal({stddev});
}

export interface GPT2BackboneArgs {
  /**
   * Integer. The size of the token vocabulary.
   */
  vocabularySize: number;

  /**
   * Integer. The number of transformer layers.
   */
  numLayers: number;

  /**
   * Integer. The number of attention heads for each transformer.
   * The hidden size must be divisible by the number of attention heads.
   */
  numHeads: number;

  /**
   * Integer. The size of the transformer encoding and pooler layers.
   */
  hiddenDim: number;

  /**
   * Integer. The output dimension of the first Dense layer in a two-layer
   * feedforward network for each transformer.
   */
  intermediateDim: number;

  /**
   * Float. Dropout probability for the Transformer encoder.
   * Defaults to 0.1.
   */
  dropout?: number;

  /**
   * Integer. The maximum sequence length that this encoder can consume.
   * If `null`, `maxSequenceLength` uses the value from sequence length.
   * This determines the variable shape for positional embeddings.
   * Defaults to 1024.
   */
  maxSequenceLength?: number;
}

/**
 * GPT-2 core network with hyperparameters.
 *
 * This network implements a Transformer-based decoder network,
 * Generative Pretrained Transformer-2 (GPT-2), as described in
 * ["Language Models are Unsupervised Multitask Learners"](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf).
 * It includes the embedding lookups and transformer layers.
 *
 * The default constructor gives a fully customizable, randomly initialized
 * GPT-2 model with any number of layers, heads, and embedding
 * dimensions. To load preset architectures and weights, use the `fromPreset`
 * constructor.
 *
 * Disclaimer: Pre-trained models are provided on an "as is" basis, without
 * warranties or conditions of any kind. The underlying model is provided by a
 * third party and subject to a separate license, available
 * [here](https://github.com/openai/gpt-2).
 *
 * Example usage:
 * ```js
 * const tokenIds = tf.ones([1, 12], 'int32');
 * const paddingMask = tf.tensor(
 *     [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]], [1, 12], 'int32');
 *
 * // Pretrained GPT-2 decoder.
 * let model = GPT2Backbone.fromPreset('gpt2_base_en');
 * model.apply([tokenIds, paddingMask]);
 *
 * // Randomly initialized GPT-2 decoder with custom config.
 * model = new GPT2Backbone({
 *   vocabularySize: 50257,
 *   numLayers: 12,
 *   numHeads: 12,
 *   hiddenDim: 768,
 *   intermediateDim: 3072,
 *   maxSequenceLength: 1024,
 * });
 * model.apply([tokenIds, paddingMask]);
 * ```
 */
export class GPT2Backbone extends Backbone {
  /** @nocollapse */
  static override className = 'GPT2Backbone';

  private vocabularySize: number;
  private numLayers: number;
  private numHeads: number;
  private hiddenDim: number;
  private intermediateDim: number;
  private dropout: number;
  private maxSequenceLength: number;

  constructor(args: GPT2BackboneArgs) {
    args.dropout = args.dropout ?? 0.1;
    args.maxSequenceLength = args.maxSequenceLength ?? 1024;

    // Inputs
    const tokenIds = input({shape: [null], dtype: 'int32', name: 'token_ids'});
    const paddingMask =
        input({shape: [null], dtype: 'int32', name: 'padding_mask'});

    // Embed tokens, positions.
    const tokenEmbedding = new Embedding({
      inputDim: args.vocabularySize,
      outputDim: args.hiddenDim,
      embeddingsInitializer: gpt2KernelInitializer(0.01),
      name: 'token_embedding',
    }).apply(tokenIds) as SymbolicTensor;

    const positionEmbedding = new PositionEmbedding({
      initializer: gpt2KernelInitializer(0.02),
      sequenceLength: args.maxSequenceLength,
      name: 'position_embedding',
    }).apply(tokenEmbedding) as SymbolicTensor;

    // Sum and apply dropout to embeddings.
    let x = add({name: 'embeddings_add'})
        .apply([tokenEmbedding, positionEmbedding]) as SymbolicTensor;
    x = new Dropout({rate: args.dropout, name: 'embeddings_dropout'})
        .apply(x) as SymbolicTensor;

    // Apply successive transformer decoder blocks.
    for (let i = 0; i < args.numLayers; i++) {
      x = new TransformerDecoder({
        intermediateDim: args.intermediateDim,
        numHeads: args.numHeads,
        dropout: args.dropout,
        layerNormEpsilon: 1e-05,
        // TODO(pforderique): Implement gelu.
        activation: getActivation('relu'),
        kernelInitializer: gpt2KernelInitializer(0.02),
        normalizeFirst: true,
        name: `transformer_layer_${i}`,
      }).apply(x, {decoderPaddingMask: paddingMask}) as SymbolicTensor;
    }

    const sequenceOutput = new LayerNormalization({
      name: 'layer_norm',
      axis: -1,
      epsilon: 1e-05,
      dtype: 'float32',
    }).apply(x) as SymbolicTensor;

    // Instantiate using Functional API Model constructor.
    super({
      inputs: [tokenIds, paddingMask],
      outputs: sequenceOutput,
      name: 'gpt2_backbone'
    });
    this.vocabularySize = args.vocabularySize;
    this.numLayers = args.numLayers;
    this.numHeads = args.numHeads;
    this.hiddenDim = args.hiddenDim;
    this.intermediateDim = args.intermediateDim;
    this.dropout = args.dropout ?? 0.1;
    this.maxSequenceLength = args.maxSequenceLength ?? 1024;
  }

  override getConfig(): serialization.ConfigDict {
    const config: serialization.ConfigDict = {
      vocabularySize: this.vocabularySize,
      numLayers: this.numLayers,
      numHeads: this.numHeads,
      hiddenDim: this.hiddenDim,
      intermediateDim: this.intermediateDim,
      dropout: this.dropout,
      maxSequenceLength: this.maxSequenceLength,
    };
    const baseConfig = super.getConfig();
    Object.assign(config, baseConfig);
    return config;
  }

  override get tokenEmbedding() {
    return this.getLayer('token_embedding');
  }
}
serialization.registerClass(GPT2Backbone);
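A small end-to-end usage sketch, with hyperparameters borrowed from the unit tests below (not an official example):

```ts
import * as tf from '@tensorflow/tfjs-core';
import { GPT2Backbone } from './gpt2_backbone';

// Tiny, randomly initialized backbone (same config as the tests).
const backbone = new GPT2Backbone({
  vocabularySize: 10,
  numLayers: 2,
  numHeads: 2,
  hiddenDim: 2,
  intermediateDim: 4,
  maxSequenceLength: 5,
});

// Both inputs are [batch, sequenceLength] int32 tensors.
const tokenIds = tf.ones([2, 5], 'int32');
const paddingMask = tf.ones([2, 5], 'int32');

// Output shape is [batch, sequenceLength, hiddenDim], here [2, 5, 2].
const output = backbone.predict([tokenIds, paddingMask]) as tf.Tensor;
console.log(output.shape);
```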
tfjs-layers/src/layers/nlp/models/gpt2/gpt2_backbone_test.ts

+80-0

@@ -0,0 +1,80 @@
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

import { Tensor, memory, ones } from '@tensorflow/tfjs-core';

import { GPT2Backbone } from './gpt2_backbone';

/**
 * Tests for GPT-2 backbone models.
 */

describe('GPT2Backbone', () => {
  let backbone: GPT2Backbone;
  let inputBatch: Tensor[];

  beforeAll(() => {
    backbone = new GPT2Backbone({
      vocabularySize: 10,
      numLayers: 2,
      numHeads: 2,
      hiddenDim: 2,
      intermediateDim: 4,
      maxSequenceLength: 5,
    });
    inputBatch = [
      ones([2, 5], 'int32'),  // tokenIds
      ones([2, 5], 'int32'),  // paddingMask
    ];
  });

  it('call', () => {
    expect(() => backbone.apply(inputBatch)).not.toThrow();
  });

  it('token embedding', () => {
    const output = backbone.tokenEmbedding.apply(inputBatch[0]) as Tensor;
    expect(output.shape).toEqual([2, 5, 2]);
  });

  it('name', () => {
    // Check default name passed through.
    expect(backbone.name).toMatch('gpt2_backbone');
  });

  it('variable sequence length', () => {
    let inputData: Tensor[];
    for (const seqLength of [2, 3, 4]) {
      inputData = [
        ones([2, seqLength], 'int32'),  // tokenIds
        ones([2, seqLength], 'int32'),  // paddingMask
      ];
      expect(() => backbone.apply(inputData)).not.toThrow();
    }
  });

  it('predict', () => {
    expect(() => backbone.predict(inputBatch)).not.toThrow();
  });

  it('does not leak memory', () => {
    const numTensors = memory().numTensors;
    backbone.apply(inputBatch);

    // Only one new tensor should remain: the model's output.
    expect(memory().numTensors).toEqual(numTensors + 1);
  });
});
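The leak test expects exactly one extra tensor after `apply`: the model output itself. In application code, intermediates can be released eagerly with `tf.tidy`; a sketch, reusing `backbone` and `inputBatch` as defined in the tests above:

```ts
import * as tf from '@tensorflow/tfjs-core';

// tf.tidy disposes intermediate tensors created inside the callback;
// only the returned tensor survives.
const output = tf.tidy(() => backbone.apply(inputBatch) as tf.Tensor);
// ...use output, then release it explicitly when done.
output.dispose();
```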
