Skip to content

Commit 5241390

Browse files
authored
feat: text embeddings dynamic size; new text embedding models consts (#273)
## Description Text embeddings now support dynamically sized input. Adds new text embedding model constants and a default value for skipSpecialTokens. ### Type of change - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Documentation update (improves or adds clarity to existing documentation) ### Tested on - [x] iOS - [x] Android ### Checklist - [x] I have performed a self-review of my code - [x] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [x] My changes generate no new warnings
1 parent c3124a3 commit 5241390

File tree

4 files changed

+34
-8
lines changed

4 files changed

+34
-8
lines changed

android/src/main/java/com/swmansion/rnexecutorch/models/TextEmbeddings/TextEmbeddingsModel.kt

+3-3
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@ class TextEmbeddingsModel(
1818
fun preprocess(input: String): Array<LongArray> {
1919
val inputIds = tokenizer.encode(input).map { it.toLong() }.toLongArray()
2020
val attentionMask = inputIds.map { if (it != 0L) 1L else 0L }.toLongArray()
21-
return arrayOf(inputIds, attentionMask) // Shape: [2, max_length]
21+
return arrayOf(inputIds, attentionMask) // Shape: [2, tokens]
2222
}
2323

2424
fun postprocess(
25-
modelOutput: FloatArray, // [max_length * embedding_dim]
26-
attentionMask: LongArray, // [max_length]
25+
modelOutput: FloatArray, // [tokens * embedding_dim]
26+
attentionMask: LongArray, // [tokens]
2727
): DoubleArray {
2828
val modelOutputDouble = modelOutput.map { it.toDouble() }.toDoubleArray()
2929
val embeddings = TextEmbeddingsUtils.meanPooling(modelOutputDouble, attentionMask)

ios/RnExecutorch/models/text_embeddings/TextEmbeddingsModel.mm

+15-4
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@ - (NSArray *)preprocess:(NSString *)input {
99
for (int i = 0; i < [input_ids count]; i++) {
1010
[attention_mask addObject:@((int)([input_ids[i] intValue] != 0))];
1111
}
12-
return @[ input_ids, attention_mask ]; // [2, max_length]
12+
return @[ input_ids, attention_mask ]; // [2, tokens]
1313
}
1414

15-
- (NSArray *)postprocess:(NSArray *)modelOutput // [max_length * embedding_dim]
16-
attentionMask:(NSArray *)attentionMask // [max_length]
15+
- (NSArray *)postprocess:(NSArray *)modelOutput // [tokens * embedding_dim]
16+
attentionMask:(NSArray *)attentionMask // [tokens]
1717
{
1818
NSArray *embeddings = [TextEmbeddingsUtils meanPooling:modelOutput
1919
attentionMask:attentionMask];
@@ -22,7 +22,18 @@ - (NSArray *)postprocess:(NSArray *)modelOutput // [max_length * embedding_dim]
2222

2323
- (NSArray *)runModel:(NSString *)input {
2424
NSArray *modelInput = [self preprocess:input];
25-
NSArray *modelOutput = [self forward:modelInput];
25+
26+
NSMutableArray *inputTypes = [NSMutableArray arrayWithObjects:@4, @4, nil];
27+
NSMutableArray *shapes = [NSMutableArray new];
28+
29+
NSNumber *tokenCount = @([modelInput[0] count]);
30+
for (__unused id _ in modelInput) {
31+
[shapes addObject:[NSMutableArray arrayWithObjects:@1, tokenCount, nil]];
32+
}
33+
34+
NSArray *modelOutput = [self forward:modelInput
35+
shapes:shapes
36+
inputTypes:inputTypes];
2637
return [self postprocess:modelOutput[0] attentionMask:modelInput[1]];
2738
}
2839

src/constants/modelUrls.ts

+15
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,21 @@ export const ALL_MINILM_L6_V2 =
168168
export const ALL_MINILM_L6_V2_TOKENIZER =
169169
'https://huggingface.co/software-mansion/react-native-executorch-all-MiniLM-L6-v2/resolve/v0.4.0/tokenizer.json';
170170

171+
export const ALL_MPNET_BASE_V2 =
172+
'https://huggingface.co/software-mansion/react-native-executorch-all-mpnet-base-v2/resolve/v0.4.0/all-mpnet-base-v2_xnnpack.pte';
173+
export const ALL_MPNET_BASE_V2_TOKENIZER =
174+
'https://huggingface.co/software-mansion/react-native-executorch-all-mpnet-base-v2/resolve/v0.4.0/tokenizer.json';
175+
176+
export const MULTI_QA_MINILM_L6_COS_V1 =
177+
'https://huggingface.co/software-mansion/react-native-executorch-multi-qa-MiniLM-L6-cos-v1/resolve/v0.4.0/multi-qa-MiniLM-L6-cos-v1_xnnpack.pte';
178+
export const MULTI_QA_MINILM_L6_COS_V1_TOKENIZER =
179+
'https://huggingface.co/software-mansion/react-native-executorch-multi-qa-MiniLM-L6-cos-v1/resolve/v0.4.0/tokenizer.json';
180+
181+
export const MULTI_QA_MPNET_BASE_DOT_V1 =
182+
'https://huggingface.co/software-mansion/react-native-executorch-multi-qa-mpnet-base-dot-v1/resolve/v0.4.0/multi-qa-mpnet-base-dot-v1_xnnpack.pte';
183+
export const MULTI_QA_MPNET_BASE_DOT_V1_TOKENIZER =
184+
'https://huggingface.co/software-mansion/react-native-executorch-multi-qa-mpnet-base-dot-v1/resolve/v0.4.0/tokenizer.json';
185+
171186
// Backward compatibility
172187
export const LLAMA3_2_3B_URL = LLAMA3_2_3B;
173188
export const LLAMA3_2_3B_QLORA_URL = LLAMA3_2_3B_QLORA;

src/modules/natural_language_processing/TokenizerModule.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ export class TokenizerModule extends BaseModule {
1111

1212
static async decode(
1313
input: number[],
14-
skipSpecialTokens: boolean
14+
skipSpecialTokens = false
1515
): Promise<string> {
1616
return await this.nativeModule.decode(input, skipSpecialTokens);
1717
}

0 commit comments

Comments
 (0)