Commit b2fbcc9

Clarify testing function and simplify mapToIdx and tokenizeAndMapToIdx.
1 parent: 352524a

File tree

2 files changed: +25 -62 lines changed


animated-transformer/src/lib/tokens/token_gemb.spec.ts

Lines changed: 23 additions & 52 deletions
@@ -27,18 +27,6 @@ import {
   expectedOutputSeqPrepFn,
 } from '../tokens/token_gemb';
 
-function tokenize_fn_test(input: string): number[] {
-  if (input == "")
-    return [];
-  if (input == "a")
-    return [0];
-  if (input == "b")
-    return [1];
-
-  return tokenize_fn_test(
-    input.substring(0, input.length / 2)).concat(tokenize_fn_test(input.substring(input.length / 2, input.length)));
-};
-
 describe('token_gemb', () => {
   it('embed', () => {
     const [aEmb, bEmb, padEmb] = [
@@ -176,50 +164,33 @@ describe('token_gemb', () => {
     expect(targetTokensOneHot.tensor.arraySync()).toEqual(expectedOutputArr);
     expect(targetTokensOneHot.dimNames).toEqual(['batch', 'pos', 'tokenId'])
   });
-
-  it('batchEmbed, pad start', () => {
-    const [aEmb, bEmb, padEmb] = [
-      [1, 1],
-      [2, 2],
-      [0, 0],
-    ];
-    const tokens = ['a', 'b'];
-    const tokenEmbedding = new GTensor(tf.tensor([aEmb, bEmb, padEmb]), ['tokenId', 'inputRep']);
+  it('Test tokenizeAndMapToIdx', () => {
+    // Mock a tokenizer for testing tokenizeAndMapToIdx.
+    function tokenize_fn_test(input: string): number[] {
+      let output: number[] = [];
+      for (let i = 0; i < input.length; i++) {
+        if (input[i] == 'a')
+          output = output.concat(0);
+        else
+          output = output.concat(1);
+      }
+      return output;
+    };
 
     const seqsToEmbed = ['aba', 'ab', '', 'b', 'a'];
     const seqsIdxs = tokenizeAndMapToIdx(tokenize_fn_test, seqsToEmbed);
+    const expectedIdxs =
+      [[0, 1, 0], [0, 1], [], [1], [0]];
 
-    const seqEmb = embedBatch(tokenEmbedding, seqsIdxs, {
-      paddingId: 2,
-      padAt: 'start',
-      dtype: 'int32',
-      maxInputLength: 2,
-    });
-
-    const expectedOutputArr: number[][][] = [
-      [
-        [1, 1],
-        [2, 2],
-      ],
-      [
-        [1, 1],
-        [2, 2],
-      ],
-      [
-        [0, 0],
-        [0, 0],
-      ],
-      [
-        [0, 0],
-        [2, 2],
-      ],
-      [
-        [0, 0],
-        [1, 1],
-      ],
-    ];
+    expect(seqsIdxs).toEqual(expectedIdxs);
+  });
+  it('Test mapToIdx', () => {
+    const tokens = ['a', 'b', '[pad]'];
+    const tokenRep = prepareBasicTaskTokenRep(tokens);
 
-    expect(seqEmb.tensor.arraySync()).toEqual(expectedOutputArr);
-    expect(seqEmb.dimNames).toEqual(['batch', 'pos', 'inputRep'])
+    const seqsToEmbed = [['a', 'b', '[pad]', 'a'], ['a', 'b'], [], ['b'], ['a']];
+    const seqsIdxs = mapToIdx(tokenRep.tokenToIdx, seqsToEmbed);
+    const expectedIdxs = [[0, 1, 2, 0], [0, 1], [], [1], [0]];
+    expect(seqsIdxs).toEqual(expectedIdxs);
   });
 });

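For reference, the mock tokenizer defined inline in the new spec maps 'a' to 0 and every other character to 1. A minimal standalone sketch of that behavior, outside the test harness (the function name here is illustrative, not part of the spec):

// Standalone sketch of the spec's mock tokenizer: 'a' -> 0, anything else -> 1.
// Illustrative only; the spec defines its own tokenize_fn_test inline.
function mockTokenize(input: string): number[] {
  return Array.from(input).map((c) => (c === 'a' ? 0 : 1));
}

console.log(mockTokenize('aba')); // [0, 1, 0]
console.log(mockTokenize(''));    // []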
animated-transformer/src/lib/tokens/token_gemb.ts

Lines changed: 2 additions & 10 deletions
@@ -62,11 +62,7 @@ export function mapToIdx(
   tokenToIdx: { [token: string]: number },
   examples: string[][]
 ): number[][] {
-  const tokenIdxs: number[][] = [];
-  examples.forEach((example) => {
-    tokenIdxs.push(example.map((s) => tokenToIdx[s]))
-  });
-  return tokenIdxs;
+  return examples.map((example) => example.map((s) => tokenToIdx[s]));
 }
 
 // TODO(@aliciafmachado): Merge this function with the one below
@@ -75,11 +71,7 @@ export function tokenizeAndMapToIdx(
   tokenize_fn: (input: string) => number[],
   examples: string[]
 ): number[][] {
-  const tokenIdxs: number[][] = [];
-  examples.forEach((example) => {
-    tokenIdxs.push(tokenize_fn(example))
-  });
-  return tokenIdxs;
+  return examples.map((example) => tokenize_fn(example));
 }
 
 // When batchSize is defined and batchSize > examples.length, then

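After this change, both helpers reduce to a single map over the examples. A self-contained sketch of the simplified behavior: the two functions are copied from the diff above, while the token table and character tokenizer below are toy stand-ins for what prepareBasicTaskTokenRep and a real tokenizer would provide.

// Simplified helpers, as in the diff above.
function mapToIdx(
  tokenToIdx: { [token: string]: number },
  examples: string[][]
): number[][] {
  return examples.map((example) => example.map((s) => tokenToIdx[s]));
}

function tokenizeAndMapToIdx(
  tokenize_fn: (input: string) => number[],
  examples: string[]
): number[][] {
  return examples.map((example) => tokenize_fn(example));
}

// Toy token table and tokenizer, for illustration only.
const tokenToIdx: { [token: string]: number } = { a: 0, b: 1, '[pad]': 2 };
const charTokenizer = (s: string) => Array.from(s).map((c) => (c === 'a' ? 0 : 1));

console.log(mapToIdx(tokenToIdx, [['a', 'b', '[pad]', 'a'], []]));
// [[0, 1, 2, 0], []]
console.log(tokenizeAndMapToIdx(charTokenizer, ['aba', 'b', '']));
// [[0, 1, 0], [1], []]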