@@ -48,7 +48,6 @@ public HuggingFaceEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, HuggingFa
48
48
this .runtime = runtime ;
49
49
inputIdsName = config .transformerInputIds ();
50
50
attentionMaskName = config .transformerAttentionMask ();
51
- tokenTypeIdsName = config .transformerTokenTypeIds ();
52
51
outputName = config .transformerOutput ();
53
52
normalize = config .normalize ();
54
53
prependQuery = config .prependQuery ();
@@ -75,15 +74,29 @@ public HuggingFaceEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, HuggingFa
75
74
onnxOpts .setExecutionMode (config .transformerExecutionMode ().toString ());
76
75
onnxOpts .setThreads (config .transformerInterOpThreads (), config .transformerIntraOpThreads ());
77
76
evaluator = onnx .evaluatorOf (config .transformerModel ().toString (), onnxOpts );
77
+ tokenTypeIdsName = detectTokenTypeIds (config , evaluator );
78
78
validateModel ();
79
79
}
80
80
81
+ private static String detectTokenTypeIds (HuggingFaceEmbedderConfig config , OnnxEvaluator evaluator ) {
82
+ String configured = config .transformerTokenTypeIds ();
83
+ Map <String , TensorType > inputs = evaluator .getInputInfo ();
84
+ if (inputs .size () < 3 ) {
85
+ // newer models have only 2 inputs (they do not use token type IDs)
86
+ return "" ;
87
+ } else {
88
+ // could detect fallback from inputs here, currently set as default in .def file
89
+ return configured ;
90
+ }
91
+ }
92
+
81
93
private void validateModel () {
82
94
Map <String , TensorType > inputs = evaluator .getInputInfo ();
83
95
validateName (inputs , inputIdsName , "input" );
84
96
validateName (inputs , attentionMaskName , "input" );
85
- if (!tokenTypeIdsName .isEmpty ()) validateName (inputs , tokenTypeIdsName , "input" );
86
-
97
+ if (!tokenTypeIdsName .isEmpty ()) {
98
+ validateName (inputs , tokenTypeIdsName , "input" );
99
+ }
87
100
Map <String , TensorType > outputs = evaluator .getOutputInfo ();
88
101
validateName (outputs , outputName , "output" );
89
102
}
@@ -250,4 +263,3 @@ protected record HFEmbeddingResult(IndexedTensor output, Tensor attentionMask, S
250
263
/**
 * Immutable pair of an embedder id and an embedded value.
 * NOTE(review): judging by the name this serves as a cache lookup key - confirm against callers.
 *
 * @param embedderId    id of the embedder
 * @param embeddedValue the embedded value
 */
protected record HFEmbedderCacheKey(
        String embedderId,
        Object embeddedValue) {
}
251
264
252
265
}
253
-
0 commit comments