Skip to content

Commit 1fc09f8

Browse files
authored
Merge pull request #33171 from vespa-engine/arnej/detect-no-token-type-ids
detect if model does not use token_type_ids
2 parents effc2d1 + 2e88fc4 commit 1fc09f8

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ public HuggingFaceEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, HuggingFa
4848
this.runtime = runtime;
4949
inputIdsName = config.transformerInputIds();
5050
attentionMaskName = config.transformerAttentionMask();
51-
tokenTypeIdsName = config.transformerTokenTypeIds();
5251
outputName = config.transformerOutput();
5352
normalize = config.normalize();
5453
prependQuery = config.prependQuery();
@@ -75,15 +74,29 @@ public HuggingFaceEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, HuggingFa
7574
onnxOpts.setExecutionMode(config.transformerExecutionMode().toString());
7675
onnxOpts.setThreads(config.transformerInterOpThreads(), config.transformerIntraOpThreads());
7776
evaluator = onnx.evaluatorOf(config.transformerModel().toString(), onnxOpts);
77+
tokenTypeIdsName = detectTokenTypeIds(config, evaluator);
7878
validateModel();
7979
}
8080

81+
private static String detectTokenTypeIds(HuggingFaceEmbedderConfig config, OnnxEvaluator evaluator) {
82+
String configured = config.transformerTokenTypeIds();
83+
Map<String, TensorType> inputs = evaluator.getInputInfo();
84+
if (inputs.size() < 3) {
85+
// newer models have only 2 inputs (they do not use token type IDs)
86+
return "";
87+
} else {
88+
// could detect fallback from inputs here, currently set as default in .def file
89+
return configured;
90+
}
91+
}
92+
8193
private void validateModel() {
8294
Map<String, TensorType> inputs = evaluator.getInputInfo();
8395
validateName(inputs, inputIdsName, "input");
8496
validateName(inputs, attentionMaskName, "input");
85-
if (!tokenTypeIdsName.isEmpty()) validateName(inputs, tokenTypeIdsName, "input");
86-
97+
if (!tokenTypeIdsName.isEmpty()) {
98+
validateName(inputs, tokenTypeIdsName, "input");
99+
}
87100
Map<String, TensorType> outputs = evaluator.getOutputInfo();
88101
validateName(outputs, outputName, "output");
89102
}
@@ -250,4 +263,3 @@ protected record HFEmbeddingResult(IndexedTensor output, Tensor attentionMask, S
250263
protected record HFEmbedderCacheKey(String embedderId, Object embeddedValue) { }
251264

252265
}
253-

0 commit comments

Comments
 (0)