Commit a0115ea

Add spec for GPT2CausalLM and dependencies (#7897)
Adds the specs for GPT2CausalLM, GPT2CausalLMPreprocessor, GenerativeTask, Task, and PipelineModel.
1 parent d500a0d commit a0115ea

File tree

10 files changed: +730 −31 lines changed

tfjs-layers/src/layers/nlp/models/backbone.ts

Lines changed: 2 additions & 2 deletions
```diff
@@ -25,7 +25,7 @@ import { serialization } from '@tensorflow/tfjs-core';
 import { ContainerArgs } from '../../../engine/container';
 import { LayersModel } from '../../../engine/training';
 import { NotImplementedError } from '../../../errors';
-import { Layer } from '../../../exports_layers';
+import { Embedding } from '../../embeddings';

 export class Backbone extends LayersModel {
   /** @nocollapse */
@@ -38,7 +38,7 @@ export class Backbone extends LayersModel {
   /**
    * A `tf.layers.embedding` instance for embedding token ids.
    */
-  get tokenEmbedding(): Layer {
+  get tokenEmbedding(): Embedding {
     throw new NotImplementedError();
   }
```
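A minimal sketch of what the narrower return type enables, with hypothetical names (`backbone` stands for any built `Backbone` subclass; nothing below is part of the diff): an `Embedding` exposes the token table as its first weight, so a causal LM head can reuse it for the output projection without casting from `Layer` at every call site.

```ts
import * as tf from '@tensorflow/tfjs';
import { Backbone } from './backbone';

// Sketch: project hidden states back onto the vocabulary by reusing the
// token embedding table (weight tying). Assumes `backbone` is built and
// that getWeights()[0] is the [vocabSize, hiddenDim] lookup table.
function tiedLogits(backbone: Backbone, hidden: tf.Tensor3D): tf.Tensor {
  const [table] = backbone.tokenEmbedding.getWeights();
  // [batch, seq, hiddenDim] x [vocabSize, hiddenDim]^T
  //   -> [batch, seq, vocabSize]
  return tf.matMul(hidden, table, false, true);
}
```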

tfjs-layers/src/layers/nlp/models/generative_task.ts

Lines changed: 116 additions & 0 deletions

```ts
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

/**
 * Base class for Generative Task models.
 */

/* Original source: keras_nlp/models/generative_task.py */
import { NamedTensorMap, Tensor } from '@tensorflow/tfjs-core';

import { NotImplementedError } from '../../../errors';
import { ModelCompileArgs } from '../../../engine/training';

import { Task } from './task';

export type GenerateFn =
    (inputs: NamedTensorMap, endTokenId?: number) => NamedTensorMap;

/**
 * Base class for Generative Task models.
 */
export class GenerativeTask extends Task {
  /** @nocollapse */
  static override className = 'GenerativeTask';

  protected generateFunction: GenerateFn;

  override compile(args: ModelCompileArgs): void {
    throw new NotImplementedError();
  }

  /**
   * Run the generation on a single batch of input.
   */
  generateStep(
    inputs: NamedTensorMap,
    endTokenId: number
  ): NamedTensorMap {
    throw new NotImplementedError();
  }

  /**
   * Create or return the compiled generation function.
   */
  makeGenerateFunction(): GenerateFn {
    throw new NotImplementedError();
  }

  /**
   * Normalize user input to the generate function.
   *
   * This function converts all inputs to tensors, adds a batch dimension if
   * necessary, and returns an iterable "dataset like" object.
   */
  protected normalizeGenerateInputs(inputs: Tensor): [Tensor, boolean] {
    throw new NotImplementedError();
  }

  /**
   * Normalize user output from the generate function.
   *
   * This function converts all output to tensors (for integer output) or
   * strings (for string output). If a batch dimension was added to the
   * input, it is removed from the output (so generate can be string in,
   * string out).
   */
  protected normalizeGenerateOutputs(
    outputs: Tensor,
    inputIsScalar: boolean
  ): Tensor {
    throw new NotImplementedError();
  }

  /**
   * Generate text given prompt `inputs`.
   *
   * This method generates text based on given `inputs`. The sampling method
   * used for generation can be set via the `compile()` method.
   *
   * `inputs` will be handled as a single batch.
   *
   * If a `preprocessor` is attached to the model, `inputs` will be
   * preprocessed inside the `generate()` function and should match the
   * structure expected by the `preprocessor` layer (usually raw strings).
   * If a `preprocessor` is not attached, inputs should match the structure
   * expected by the `backbone`.
   *
   * @param inputs tensor data. If a `preprocessor` is attached to the model,
   *  `inputs` should match the structure expected by the `preprocessor`
   *  layer. If a `preprocessor` is not attached, `inputs` should match the
   *  structure expected by the `backbone` model.
   * @param maxLength Integer. The max length of the generated sequence.
   *  Will default to the max configured `sequenceLength` of the
   *  `preprocessor`. If `preprocessor` is `null`, `inputs` should be padded
   *  to the desired maximum length and this argument will be ignored.
   */
  generate(inputs: Tensor, maxLength?: number) {
    throw new NotImplementedError();
  }
}
```
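Every method in this spec throws `NotImplementedError`; only the signatures are pinned down. As a hedged sketch of how the pieces are expected to compose (the glue below follows the keras-nlp original and is an assumption, not part of the spec; at runtime each call still throws):

```ts
import { Tensor } from '@tensorflow/tfjs-core';
import { GenerativeTask } from './generative_task';

// Hypothetical subclass illustrating the intended data flow only.
class GenerateFlowSketch extends GenerativeTask {
  /** @nocollapse */
  static override className = 'GenerateFlowSketch';

  override generate(inputs: Tensor, maxLength?: number) {
    // Add a batch dimension if the caller passed a single prompt.
    const [batched, inputIsScalar] = this.normalizeGenerateInputs(inputs);
    // Build (or fetch) one compiled per-batch function and reuse it.
    const generateFn = this.makeGenerateFunction();
    // A preprocessor would map raw strings to {tokenIds, paddingMask} here.
    const outputs = generateFn({tokenIds: batched});
    // Strip the batch dimension again for scalar (string-in) prompts.
    return this.normalizeGenerateOutputs(outputs['tokenIds'], inputIsScalar);
  }
}
```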

tfjs-layers/src/layers/nlp/models/gpt2/gpt2_backbone.ts

Lines changed: 2 additions & 2 deletions
```diff
@@ -215,8 +215,8 @@ export class GPT2Backbone extends Backbone {
     return config;
   }

-  override get tokenEmbedding() {
-    return this.getLayer('token_embedding');
+  override get tokenEmbedding(): Embedding {
+    return this.getLayer('token_embedding') as Embedding;
   }
 }
 serialization.registerClass(GPT2Backbone);
```
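A short usage sketch of the typed getter (constructor values are illustrative, echoing the docstring example later in this commit; the shape comment is an assumption about the table layout):

```ts
import { GPT2Backbone } from './gpt2_backbone';

// Sketch: with the override typed as Embedding, callers reach the token
// table directly instead of first casting a generic Layer.
const backbone = new GPT2Backbone({
  vocabularySize: 30552,
  numLayers: 4,
  numHeads: 4,
  hiddenDim: 256,
  intermediateDim: 512,
  maxSequenceLength: 128,
});
const [tokenTable] = backbone.tokenEmbedding.getWeights();
console.log(tokenTable.shape);  // expected: [30552, 256], i.e. vocab x dim
```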
tfjs-layers/src/layers/nlp/models/gpt2/gpt2_causal_lm.ts

Lines changed: 224 additions & 0 deletions

```ts
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

/**
 * GPT2 Causal LM (Language Model).
 */

/* Original source: keras-nlp/models/gpt2/gpt2_causal_lm.py */
import { NamedTensorMap, Tensor, serialization } from '@tensorflow/tfjs-core';

import { GPT2Preprocessor } from './gpt2_preprocessor';
import { NotImplementedError } from '../../../../errors';
import { Layer } from '../../../../exports_layers';
import { LayerArgs } from '../../../../engine/topology';
import { Embedding } from '../../../../layers/embeddings';
import { Shape } from '../../../../keras_format/common';
import { GenerativeTask } from '../generative_task';
import { GPT2Backbone } from './gpt2_backbone';
import { PipelineModelArgs } from '../../utils';
import { Kwargs } from '../../../../types';

declare interface ReverseEmbeddingArgs extends LayerArgs {
  embedding: Embedding;
}

class ReverseEmbedding extends Layer {
  protected embedding: Embedding;

  constructor(args: ReverseEmbeddingArgs) {
    super(args);
    this.embedding = args.embedding;
  }

  override call(inputs: Tensor|Tensor[], kwargs: Kwargs): Tensor|Tensor[] {
    throw new NotImplementedError();
  }

  override computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {
    throw new NotImplementedError();
  }
}

export declare interface GPT2CausalLMArgs extends PipelineModelArgs {
  /**
   * A `GPT2Backbone` instance.
   */
  backbone: GPT2Backbone;

  /**
   * Optional `GPT2CausalLMPreprocessor`.
   * If `null`, this model will not apply preprocessing, and inputs should
   * be preprocessed before calling the model.
   */
  preprocessor?: GPT2Preprocessor;
}

/**
 * An end-to-end GPT2 model for causal language modeling.
 *
 * A causal language model (LM) predicts the next token based on previous
 * tokens. This task setup can be used to train the model unsupervised on
 * plain text input, or to autoregressively generate plain text similar to
 * the data used for training. This task can be used for pre-training or
 * fine-tuning a GPT-2 model, simply by calling `fit()`.
 *
 * This model has a `generate()` method, which generates text based on a
 * prompt. The generation strategy used is controlled by an additional
 * `sampler` argument on `compile()`.
 * By default, the top k results will be returned.
 *
 * This model can optionally be configured with a `preprocessor` layer, in
 * which case it will automatically apply preprocessing to string inputs
 * during `fit()`, `predict()`, `evaluate()` and `generate()`. This is done
 * by default when creating the model with `fromPreset()`.
 *
 * Disclaimer: Pre-trained models are provided on an "as is" basis, without
 * warranties or conditions of any kind. The underlying model is provided by
 * a third party and subject to a separate license, available
 * [here](https://github.com/openai/gpt-2).
 *
 * Use `generate()` to do text generation.
 * ```js
 * const gpt2LM = GPT2CausalLM.fromPreset('gpt2_base_en');
 * gpt2LM.generate('I want to say', 30);
 * // Generate with batched prompts.
 * gpt2LM.generate(['This is a', 'Where are you'], 30);
 * ```
 *
 * Use `generate()` without preprocessing.
 * ```js
 * // Prompt the model with `5338, 318` (the token ids for `"Who is"`).
 * // Use `"paddingMask"` to indicate values that should not be overridden.
 * const prompt = {
 *   tokenIds: tf.tensor([[5338, 318, 0, 0, 0], [5338, 318, 0, 0, 0]]),
 *   paddingMask: tf.tensor([[1, 1, 0, 0, 0], [1, 1, 0, 0, 0]]),
 * };
 * const gpt2LM = GPT2CausalLM.fromPreset('gpt2_base_en', null);
 * gpt2LM.generate(prompt);
 * ```
 *
 * Call `fit()` on a single batch.
 * ```js
 * const features = ['The quick brown fox jumped.', 'I forgot my homework.'];
 * const gpt2LM = GPT2CausalLM.fromPreset('gpt2_base_en');
 * gpt2LM.fit(features, {batchSize: 2});
 * ```
 *
 * Call `fit()` without preprocessing.
 * ```js
 * const x = {
 *   tokenIds: tf.tensor([[50256, 1, 2, 3, 4], [50256, 1, 2, 3, 4]]),
 *   paddingMask: tf.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]),
 * };
 * const y = tf.tensor([[1, 2, 3, 4, 50256], [1, 2, 3, 4, 50256]]);
 * const sw = tf.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]);
 * const gpt2LM = GPT2CausalLM.fromPreset('gpt2_base_en', null);
 * gpt2LM.fit(x, y, {sampleWeight: sw, batchSize: 2});
 * ```
 *
 * Custom backbone and vocabulary.
 * ```js
 * const features = ['a quick fox.', 'a fox quick.'];
 * const vocab = {'<|endoftext|>': 0, 'a': 4, 'Ġquick': 5, 'Ġfox': 6};
 * const merges = [
 *   'Ġ q', 'u i', 'c k', 'ui ck', 'Ġq uick', 'Ġ f', 'o x', 'Ġf ox'
 * ];
 * const tokenizer = new GPT2Tokenizer({vocabulary: vocab, merges});
 * const preprocessor = new GPT2CausalLMPreprocessor({
 *   tokenizer,
 *   sequenceLength: 128,
 * });
 * const backbone = new GPT2Backbone({
 *   vocabularySize: 30552,
 *   numLayers: 4,
 *   numHeads: 4,
 *   hiddenDim: 256,
 *   intermediateDim: 512,
 *   maxSequenceLength: 128,
 * });
 * const gpt2LM = new GPT2CausalLM({backbone, preprocessor});
 * gpt2LM.fit(features, {batchSize: 2});
 * ```
 */
export class GPT2CausalLM extends GenerativeTask {
  /** @nocollapse */
  static override className = 'GPT2CausalLM';

  constructor(args: GPT2CausalLMArgs) {
    super(args);
    throw new NotImplementedError(`Uses ${ReverseEmbedding}.`);
  }

  static override presets<T extends serialization.Serializable>(
    cls: serialization.SerializableConstructor<T>
  ): {} {
    throw new NotImplementedError();
  }

  /**
   * Forward pass of `GPT2CausalLM` with cache.
   *
   * `callWithCache` adds an additional forward pass for the model for
   * autoregressive inference. Unlike calling the model directly, this
   * method allows caching previous key/value Tensors in the multi-head
   * attention layers, and avoids recomputing the outputs of seen tokens.
   *
   * @param tokenIds a dense int Tensor with shape `[batchSize, maxLength]`.
   * @param cache a dense float Tensor, the cache of key and value.
   * @param cacheUpdateIndex Integer. The index of the current inputs in the
   *  whole sequence.
   * @returns `[logits, hiddenStates, cache]`, where `logits` is the
   *  language model logits for the input `tokenIds`, `hiddenStates` is
   *  the final hidden representation of the input tokens, and `cache` is
   *  the decoding cache.
   */
  callWithCache(
    tokenIds: Tensor,
    cache: Tensor,
    cacheUpdateIndex: number
  ): [Tensor, Tensor, Tensor] {
    throw new NotImplementedError();
  }

  /**
   * Build an empty cache for use with `callWithCache()`.
   */
  private buildCache(tokenIds: Tensor): [Tensor, Tensor] {
    throw new NotImplementedError();
  }

  /**
   * A compilable generation function for a single batch of inputs.
   *
   * This function represents the inner generation function for a single
   * batch of inputs.
   *
   * @param inputs An object with two keys, `tokenIds` and `paddingMask`,
   *  and batched tensor values.
   * @param endTokenId The id of the end token to stop on. If all
   *  sequences have produced a new `endTokenId`, generation will stop.
   */
  override generateStep(
    inputs: NamedTensorMap,
    endTokenId: number
  ): NamedTensorMap {
    throw new NotImplementedError(`Uses ${this.buildCache}`);
  }
}
serialization.registerClass(GPT2CausalLM);
```
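The `callWithCache()` docstring describes autoregressive decoding against a key/value cache. A hedged sketch of the loop `generateStep()` is expected to run (greedy argmax stands in for the sampler configured via `compile()`, the cache is assumed to have been seeded on the prompt by the private `buildCache()`, and every spec method still throws `NotImplementedError`, so this shows intended data flow only):

```ts
import * as tf from '@tensorflow/tfjs';
import { GPT2CausalLM } from './gpt2_causal_lm';

// Sketch: feed one token per step; cached key/value tensors mean attention
// over the already-seen prefix is never recomputed.
function decodeLoopSketch(
    lm: GPT2CausalLM,
    lastPromptToken: tf.Tensor,  // shape [batch, 1]
    cache: tf.Tensor,            // key/value cache seeded on the prompt
    startIndex: number,
    maxLength: number): tf.Tensor[] {
  const generated: tf.Tensor[] = [];
  let token = lastPromptToken;
  for (let i = startIndex; i < maxLength; i++) {
    // logits: [batch, 1, vocabSize] for the single new position.
    const [logits, , newCache] = lm.callWithCache(token, cache, i);
    token = tf.argMax(logits, -1);  // greedy: pick the most likely token id
    cache = newCache;
    generated.push(token);
  }
  return generated;
}
```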
