Skip to content

Commit 91d07ba

Browse files
[gpt2pre 3] Preprocessor Layer (#7794)
* Add Preprocessor layer * Remove unneeded args * Use LayerArgs * Remove import from src * Add fromConfig method * Serialize tokenizer properly * Add test cases for preprocessor * Preprocessor tests with no set tokenizer --------- Co-authored-by: Linchenn <[email protected]>
1 parent d45c6af commit 91d07ba

File tree

3 files changed

+112
-1
lines changed

3 files changed

+112
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/**
2+
* @license
3+
* Copyright 2023 Google LLC.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
/* Original source: keras-nlp/models/preprocessor.py */
19+
import { serialization } from '@tensorflow/tfjs-core';
20+
21+
import { Layer, LayerArgs } from '../../../engine/topology';
22+
import { Tokenizer } from '../tokenizers';
23+
import { Kwargs } from '../../../types';
24+
import { deserializeKerasObject, serializeKerasObject } from '../../../utils/generic_utils';
25+
26+
/**
27+
* Base class for model Preprocessors.
28+
*/
29+
export class Preprocessor extends Layer {
  /** @nocollapse */
  static readonly className = 'Preprocessor';

  // Backing field for the `tokenizer` accessor; stays undefined until set.
  private _tokenizer: Tokenizer;

  /**
   * @param args Standard layer arguments, forwarded unchanged to `Layer`.
   */
  constructor(args: LayerArgs) {
    super(args);
  }

  /**
   * The tokenizer used to tokenize strings.
   */
  get tokenizer(): Tokenizer {
    return this._tokenizer;
  }

  set tokenizer(value: Tokenizer) {
    this._tokenizer = value;
  }

  /**
   * Returns the base layer config extended with a `tokenizer` entry holding
   * the result of `serializeKerasObject` applied to the current tokenizer.
   */
  override getConfig(): serialization.ConfigDict {
    const config = super.getConfig();
    config.tokenizer = serializeKerasObject(this.tokenizer);
    return config;
  }

  /**
   * Builds an instance of `cls` from `config`.
   *
   * When `config.tokenizer` is present and is not already a `Tokenizer`
   * instance, it is treated as a serialized config and deserialized through
   * the global serialization class-name map before being handed to the
   * constructor.
   *
   * @param cls Constructor of the class to instantiate.
   * @param config Configuration, typically produced by `getConfig()`.
   * @returns A new instance of `cls`.
   */
  static override fromConfig<T extends serialization.Serializable>(
    cls: serialization.SerializableConstructor<T>,
    config: serialization.ConfigDict
  ): T {
    // Shallow-copy so that swapping in the deserialized tokenizer below does
    // not mutate the config object owned by the caller.
    const kwargs: Kwargs = Object.assign({}, config);

    if (config.tokenizer != null && !(config.tokenizer instanceof Tokenizer)) {
      const tokenizerConfigDict = config.tokenizer as serialization.ConfigDict;

      kwargs.tokenizer = deserializeKerasObject(
        tokenizerConfigDict,
        serialization.SerializationMap.getMap().classNameMap,
        {}, 'preprocessor');
    }
    return new cls(kwargs);
  }

  /**
   * Intentionally empty in this base class.
   * NOTE(review): presumably subclasses override this to report their
   * tokenizer class — confirm against the subclasses.
   */
  static tokenizerCls<T extends serialization.Serializable>(
    cls: serialization.SerializableConstructor<T>) {}
}
76+
// Register Preprocessor with the tfjs serialization map so that it can be
// resolved by its className when deserializing.
serialization.registerClass(Preprocessor);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/**
2+
* @license
3+
* Copyright 2023 Google LLC.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
/**
19+
* Unit Tests for Preprocessor Layers.
20+
*/
21+
import { Preprocessor } from './preprocessor';
22+
23+
describe('Preprocessor', () => {
  // Rebuilt before every spec so that no state leaks between tests.
  let layer: Preprocessor;

  beforeEach(() => {
    layer = new Preprocessor({});
  });

  it('serialization round-trip with no set tokenizer', () => {
    // Serialize, rebuild from that config, and check that the rebuilt
    // layer serializes to an equal config.
    const originalConfig = layer.getConfig();
    const restored = Preprocessor.fromConfig(Preprocessor, originalConfig);

    expect(restored.getConfig()).toEqual(layer.getConfig());
  });
});

tfjs-layers/src/layers/nlp/tokenizers.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ export class BytePairTokenizer extends Tokenizer {
331331

332332
override getConfig(): serialization.ConfigDict {
333333
const config = {
334-
vocabulary: this.vocabulary,
334+
vocabulary: Array.from(this._vocabulary.entries()),
335335
merges: this.merges,
336336
sequenceLength: this.sequenceLength,
337337
addPrefixSpace: this.addPrefixSpace,

0 commit comments

Comments
 (0)