From b2fbcc962714f083bd8ae1116eed037b8c1527ef Mon Sep 17 00:00:00 2001
From: Alicia Machado
Date: Fri, 31 Jan 2025 10:23:46 +0100
Subject: [PATCH] Clarify testing function and simplify mapToIdx and
 tokenizeAndMapToIdx.

---
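Notes (kept below the --- cut, so not part of the commit message): both
helpers become single map expressions, and the mock tokenizer moves into
the one test that uses it. As a quick sanity check of the behavior the
new tests pin down, here is a minimal standalone sketch. The inline
tokenToIdx table and the toyTokenizer helper are hypothetical stand-ins
for prepareBasicTaskTokenRep and the spec's mock; only plain
Array.prototype.map semantics are exercised.

  // Hypothetical stand-in for the table the spec builds via
  // prepareBasicTaskTokenRep(['a', 'b', '[pad]']).
  const tokenToIdx: { [token: string]: number } = { a: 0, b: 1, '[pad]': 2 };

  // Same one-liner shape as the patched mapToIdx.
  function mapToIdx(
    tokenToIdx: { [token: string]: number },
    examples: string[][]
  ): number[][] {
    return examples.map((example) => example.map((s) => tokenToIdx[s]));
  }

  // Same one-liner shape as the patched tokenizeAndMapToIdx.
  function tokenizeAndMapToIdx(
    tokenize_fn: (input: string) => number[],
    examples: string[]
  ): number[][] {
    return examples.map((example) => tokenize_fn(example));
  }

  // Character tokenizer mirroring the spec's mock: 'a' -> 0, else 1.
  const toyTokenizer = (input: string): number[] =>
    Array.from(input).map((c) => (c === 'a' ? 0 : 1));

  console.log(mapToIdx(tokenToIdx, [['a', 'b', '[pad]', 'a'], [], ['b']]));
  // -> [[0, 1, 2, 0], [], [1]]
  console.log(tokenizeAndMapToIdx(toyTokenizer, ['aba', '', 'b']));
  // -> [[0, 1, 0], [], [1]]
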
 .../src/lib/tokens/token_gemb.spec.ts | 75 ++++++-------------
 .../src/lib/tokens/token_gemb.ts      | 12 +--
 2 files changed, 25 insertions(+), 62 deletions(-)

diff --git a/animated-transformer/src/lib/tokens/token_gemb.spec.ts b/animated-transformer/src/lib/tokens/token_gemb.spec.ts
index 7528992..27e346f 100644
--- a/animated-transformer/src/lib/tokens/token_gemb.spec.ts
+++ b/animated-transformer/src/lib/tokens/token_gemb.spec.ts
@@ -27,18 +27,6 @@ import {
   expectedOutputSeqPrepFn,
 } from '../tokens/token_gemb';
 
-function tokenize_fn_test(input: string): number[] {
-  if (input == "")
-    return [];
-  if (input == "a")
-    return [0];
-  if (input == "b")
-    return [1];
-
-  return tokenize_fn_test(
-    input.substring(0, input.length / 2)).concat(tokenize_fn_test(input.substring(input.length / 2, input.length)));
-};
-
 describe('token_gemb', () => {
   it('embed', () => {
     const [aEmb, bEmb, padEmb] = [
@@ -176,50 +164,33 @@ describe('token_gemb', () => {
     expect(targetTokensOneHot.tensor.arraySync()).toEqual(expectedOutputArr);
     expect(targetTokensOneHot.dimNames).toEqual(['batch', 'pos', 'tokenId'])
   });
 
-  it('batchEmbed, pad start', () => {
-    const [aEmb, bEmb, padEmb] = [
-      [1, 1],
-      [2, 2],
-      [0, 0],
-    ];
-    const tokens = ['a', 'b'];
-    const tokenEmbedding = new GTensor(tf.tensor([aEmb, bEmb, padEmb]), ['tokenId', 'inputRep']);
+  it('Test tokenizeAndMapToIdx', () => {
+    // Mock a tokenizer for testing tokenizeAndMapToIdx.
+    function tokenize_fn_test(input: string): number[] {
+      let output: number[] = [];
+      for (let i = 0; i < input.length; i++) {
+        if (input[i] == 'a')
+          output = output.concat(0);
+        else
+          output = output.concat(1);
+      }
+      return output;
+    };
     const seqsToEmbed = ['aba', 'ab', '', 'b', 'a'];
     const seqsIdxs = tokenizeAndMapToIdx(tokenize_fn_test, seqsToEmbed);
+    const expectedIdxs =
+      [[0, 1, 0], [0, 1], [], [1], [0]];
 
-    const seqEmb = embedBatch(tokenEmbedding, seqsIdxs, {
-      paddingId: 2,
-      padAt: 'start',
-      dtype: 'int32',
-      maxInputLength: 2,
-    });
-
-    const expectedOutputArr: number[][][] = [
-      [
-        [1, 1],
-        [2, 2],
-      ],
-      [
-        [1, 1],
-        [2, 2],
-      ],
-      [
-        [0, 0],
-        [0, 0],
-      ],
-      [
-        [0, 0],
-        [2, 2],
-      ],
-      [
-        [0, 0],
-        [1, 1],
-      ],
-    ];
+    expect(seqsIdxs).toEqual(expectedIdxs);
+  });
+  it('Test mapToIdx', () => {
+    const tokens = ['a', 'b', '[pad]'];
+    const tokenRep = prepareBasicTaskTokenRep(tokens);
 
-
-    expect(seqEmb.tensor.arraySync()).toEqual(expectedOutputArr);
-    expect(seqEmb.dimNames).toEqual(['batch', 'pos', 'inputRep'])
+    const seqsToEmbed = [['a', 'b', '[pad]', 'a'], ['a', 'b'], [], ['b'], ['a']];
+    const seqsIdxs = mapToIdx(tokenRep.tokenToIdx, seqsToEmbed);
+    const expectedIdxs = [[0, 1, 2, 0], [0, 1], [], [1], [0]];
+    expect(seqsIdxs).toEqual(expectedIdxs);
   });
 });
diff --git a/animated-transformer/src/lib/tokens/token_gemb.ts b/animated-transformer/src/lib/tokens/token_gemb.ts
index 672b6af..4586b45 100644
--- a/animated-transformer/src/lib/tokens/token_gemb.ts
+++ b/animated-transformer/src/lib/tokens/token_gemb.ts
@@ -62,11 +62,7 @@ export function mapToIdx(
   tokenToIdx: { [token: string]: number },
   examples: string[][]
 ): number[][] {
-  const tokenIdxs: number[][] = [];
-  examples.forEach((example) => {
-    tokenIdxs.push(example.map((s) => tokenToIdx[s]))
-  });
-  return tokenIdxs;
+  return examples.map((example) => example.map((s) => tokenToIdx[s]));
 }
 
 // TODO(@aliciafmachado): Merge this function with the one below
@@ -75,11 +71,7 @@ export function tokenizeAndMapToIdx(
   tokenize_fn: (input: string) => number[],
   examples: string[]
 ): number[][] {
-  const tokenIdxs: number[][] = [];
-  examples.forEach((example) => {
-    tokenIdxs.push(tokenize_fn(example))
-  });
-  return tokenIdxs;
+  return examples.map((example) => tokenize_fn(example));
 }
 
 // When batchSize is defined and batchSize > examples.length, then
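-- 
Review note, outside the patch itself: both rewrites are
behavior-preserving, including the existing quirk that mapToIdx yields
undefined rather than throwing for a token absent from tokenToIdx, since
the old and new code share the same tokenToIdx[s] lookup. The change
also makes the TODO above tokenizeAndMapToIdx easier to act on: mapToIdx
is now visibly the same pattern, with a per-token dictionary lookup
standing in for tokenize_fn.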