Skip to content

Commit

Permalink
Clarify testing function and simplify mapToIdx and tokenizeAndMapToIdx.
Browse files Browse the repository at this point in the history
  • Loading branch information
aliciafmachado committed Feb 1, 2025
1 parent 352524a commit b2fbcc9
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 62 deletions.
75 changes: 23 additions & 52 deletions animated-transformer/src/lib/tokens/token_gemb.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,6 @@ import {
expectedOutputSeqPrepFn,
} from '../tokens/token_gemb';

/**
 * Toy tokenizer used by the tests: maps 'a' to 0 and 'b' to 1, and tokenizes
 * longer strings by splitting them in half and concatenating the tokens of
 * each half.
 *
 * @param input String made up of the characters 'a' and 'b'.
 * @returns One token id per character, in order; `[]` for the empty string.
 * @throws Error if `input` contains a character other than 'a' or 'b'.
 */
function tokenize_fn_test(input: string): number[] {
  if (input === '') return [];
  if (input === 'a') return [0];
  if (input === 'b') return [1];

  if (input.length === 1) {
    // Halving a one-character string yields '' plus the very same string, so
    // without this guard an unknown character would recurse until the stack
    // overflows.
    throw new Error(
      `tokenize_fn_test: unsupported character ${JSON.stringify(input)}`
    );
  }

  // Explicit integer midpoint (substring would otherwise truncate the
  // fractional index implicitly for odd lengths).
  const mid = Math.floor(input.length / 2);
  return tokenize_fn_test(input.substring(0, mid)).concat(
    tokenize_fn_test(input.substring(mid))
  );
}

describe('token_gemb', () => {
it('embed', () => {
const [aEmb, bEmb, padEmb] = [
Expand Down Expand Up @@ -176,50 +164,33 @@ describe('token_gemb', () => {
expect(targetTokensOneHot.tensor.arraySync()).toEqual(expectedOutputArr);
expect(targetTokensOneHot.dimNames).toEqual(['batch', 'pos', 'tokenId'])
});

it('batchEmbed, pad start', () => {
const [aEmb, bEmb, padEmb] = [
[1, 1],
[2, 2],
[0, 0],
];
const tokens = ['a', 'b'];
const tokenEmbedding = new GTensor(tf.tensor([aEmb, bEmb, padEmb]), ['tokenId', 'inputRep']);
it('Test tokenizeAndMapToIdx', () => {
// Mock a tokenizer for testing tokenizeAndMapToIdx.
/**
 * Character-level stand-in tokenizer: emits one id per UTF-16 code unit of
 * `input` — 0 for 'a' and 1 for anything else.
 */
function tokenize_fn_test(input: string): number[] {
  return input.split('').map((ch) => (ch === 'a' ? 0 : 1));
}

const seqsToEmbed = ['aba', 'ab', '', 'b', 'a'];
const seqsIdxs = tokenizeAndMapToIdx(tokenize_fn_test, seqsToEmbed);
const expectedIdxs =
[[0, 1, 0], [0, 1], [], [1], [0]];

const seqEmb = embedBatch(tokenEmbedding, seqsIdxs, {
paddingId: 2,
padAt: 'start',
dtype: 'int32',
maxInputLength: 2,
});

const expectedOutputArr: number[][][] = [
[
[1, 1],
[2, 2],
],
[
[1, 1],
[2, 2],
],
[
[0, 0],
[0, 0],
],
[
[0, 0],
[2, 2],
],
[
[0, 0],
[1, 1],
],
];
expect(seqsIdxs).toEqual(expectedIdxs);
});
it('Test mapToIdx', () => {
const tokens = ['a', 'b', '[pad]'];
const tokenRep = prepareBasicTaskTokenRep(tokens);

expect(seqEmb.tensor.arraySync()).toEqual(expectedOutputArr);
expect(seqEmb.dimNames).toEqual(['batch', 'pos', 'inputRep'])
const seqsToEmbed = [['a', 'b', '[pad]', 'a'], ['a', 'b'], [], ['b'], ['a']];
const seqsIdxs = mapToIdx(tokenRep.tokenToIdx, seqsToEmbed);
const expectedIdxs = [[0, 1, 2, 0], [0, 1], [], [1], [0]];
expect(seqsIdxs).toEqual(expectedIdxs);
});
});
12 changes: 2 additions & 10 deletions animated-transformer/src/lib/tokens/token_gemb.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,7 @@ export function mapToIdx(
tokenToIdx: { [token: string]: number },
examples: string[][]
): number[][] {
const tokenIdxs: number[][] = [];
examples.forEach((example) => {
tokenIdxs.push(example.map((s) => tokenToIdx[s]))
});
return tokenIdxs;
return examples.map((example) => example.map((s) => tokenToIdx[s]));
}

// TODO(@aliciafmachado): Merge this function with the one below
Expand All @@ -75,11 +71,7 @@ export function tokenizeAndMapToIdx(
tokenize_fn: (input: string) => number[],
examples: string[]
): number[][] {
const tokenIdxs: number[][] = [];
examples.forEach((example) => {
tokenIdxs.push(tokenize_fn(example))
});
return tokenIdxs;
return examples.map((example) => tokenize_fn(example));
}

// When batchSize is defined and batchSize > examples.length, then
Expand Down

0 comments on commit b2fbcc9

Please sign in to comment.