20
20
*/
21
21
package io .bioimage .modelrunner .onnx ;
22
22
23
+ import io .bioimage .modelrunner .apposed .appose .Types ;
23
24
import io .bioimage .modelrunner .engine .DeepLearningEngineInterface ;
24
25
import io .bioimage .modelrunner .exceptions .LoadModelException ;
25
26
import io .bioimage .modelrunner .exceptions .RunModelException ;
26
27
import io .bioimage .modelrunner .onnx .tensor .ImgLib2Builder ;
27
28
import io .bioimage .modelrunner .onnx .tensor .TensorBuilder ;
28
29
import io .bioimage .modelrunner .tensor .Tensor ;
29
30
import net .imglib2 .RandomAccessibleInterval ;
30
- import net .imglib2 .img . array . ArrayImgs ;
31
- import net .imglib2 .type .numeric .real . FloatType ;
31
+ import net .imglib2 .type . NativeType ;
32
+ import net .imglib2 .type .numeric .RealType ;
32
33
33
- import java .io .File ;
34
34
import java .util .ArrayList ;
35
35
import java .util .Iterator ;
36
36
import java .util .LinkedHashMap ;
@@ -85,6 +85,7 @@ public OnnxInterface()
85
85
{
86
86
}
87
87
88
+ /**
88
89
public static void main(String args[]) throws LoadModelException, RunModelException {
89
90
String folderName = "/home/carlos/git/deep-icy/models/NucleiSegmentationBoundaryModel_27112023_190556";
90
91
String source = folderName + File.separator + "weights.onnx";
@@ -104,6 +105,7 @@ public static void main(String args[]) throws LoadModelException, RunModelExcept
104
105
oi.run(inps, outs);
105
106
System.out.println(false);
106
107
}
108
+ */
107
109
108
110
/**
109
111
* {@inheritDoc}
@@ -132,10 +134,12 @@ public void loadModel(String modelFolder, String modelSource) throws LoadModelEx
132
134
*
133
135
* Run a Onnx model on the data provided by the {@link Tensor} input list
134
136
* and modifies the output list with the results obtained
137
+ * @throws RunModelException
135
138
*
136
139
*/
137
140
@ Override
138
- public void run (List <Tensor <?>> inputTensors , List <Tensor <?>> outputTensors ) throws RunModelException {
141
+ public <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >>
142
+ void run (List <Tensor <T >> inputTensors , List <Tensor <R >> outputTensors ) throws RunModelException {
139
143
Result output ;
140
144
LinkedHashMap <String , OnnxTensor > inputMap = new LinkedHashMap <String , OnnxTensor >();
141
145
Iterator <String > inputNames = session .getInputNames ().iterator ();
@@ -160,10 +164,47 @@ public void run(List<Tensor<?>> inputTensors, List<Tensor<?>> outputTensors) thr
160
164
for (OnnxTensor tt : inputMap .values ()) {
161
165
tt .close ();
162
166
}
163
- for (Object tt : output ) {
164
- tt = null ;
167
+ output .close ();
168
+ }
169
+
170
+ @ Override
171
+ public <T extends RealType <T > & NativeType <T >, R extends RealType <R > & NativeType <R >> List <RandomAccessibleInterval <R >> inference (
172
+ List <RandomAccessibleInterval <T >> inputs ) throws RunModelException {
173
+ Result output ;
174
+ LinkedHashMap <String , OnnxTensor > inputMap = new LinkedHashMap <String , OnnxTensor >();
175
+ Iterator <String > inputNames = session .getInputNames ().iterator ();
176
+ try {
177
+ for (RandomAccessibleInterval <T > tt : inputs ) {
178
+ OnnxTensor inT = TensorBuilder .build (tt , env );
179
+ inputMap .put (inputNames .next (), inT );
180
+ }
181
+ output = session .run (inputMap );
182
+ } catch (OrtException ex ) {
183
+ for (OnnxTensor tt : inputMap .values ()) {
184
+ tt .close ();
185
+ }
186
+ throw new RunModelException ("Error trying to run an Onnx model."
187
+ + System .lineSeparator () + Types .stackTrace (ex ));
188
+ }
189
+ for (OnnxTensor tt : inputMap .values ()) {
190
+ tt .close ();
191
+ }
192
+
193
+ // Fill the agnostic output tensors list with data from the inference result
194
+ List <RandomAccessibleInterval <R >> rais = new ArrayList <RandomAccessibleInterval <R >>();
195
+ for (int i = 0 ; i < output .size (); i ++) {
196
+ try {
197
+ rais .add (ImgLib2Builder .build (output .get (i ).getValue ()));
198
+ output .get (i ).close ();
199
+ } catch (IllegalArgumentException | OrtException e ) {
200
+ for (int j = i ; j < output .size (); j ++)
201
+ output .get (j ).close ();
202
+ output .close ();
203
+ throw new RunModelException ("Error converting tensor into RAI" + Types .stackTrace (e ));
204
+ }
165
205
}
166
206
output .close ();
207
+ return rais ;
167
208
}
168
209
169
210
/**
@@ -179,17 +220,22 @@ public void run(List<Tensor<?>> inputTensors, List<Tensor<?>> outputTensors) thr
179
220
* @throws RunModelException If the number of tensors expected is not the same
180
221
* as the number of Tensors outputed by the model
181
222
*/
182
- public static void fillOutputTensors (Result onnxTensors , List <Tensor <?>> outputTensors ) throws RunModelException {
223
+ public static <T extends RealType <T > & NativeType <T >>
224
+ void fillOutputTensors (Result onnxTensors ,
225
+ List <Tensor <T >> outputTensors ) throws RunModelException {
183
226
if (onnxTensors .size () != outputTensors .size ())
184
227
throw new RunModelException (onnxTensors .size (), outputTensors .size ());
185
228
int cc = 0 ;
186
229
for (Tensor tt : outputTensors ) {
187
230
try {
188
- tt .setData (ImgLib2Builder .build (onnxTensors .get (cc ++).getValue ()));
231
+ tt .setData (ImgLib2Builder .build (onnxTensors .get (cc ).getValue ()));
232
+ onnxTensors .get (cc ).close ();
233
+ cc ++;
189
234
} catch (IllegalArgumentException | OrtException e ) {
190
- e .printStackTrace ();
191
- throw new RunModelException ("Unable to recover value of output tensor: " + tt .getName ()
192
- + System .lineSeparator () + e .getCause ().toString ());
235
+ for (int j = cc ; j < onnxTensors .size (); j ++)
236
+ onnxTensors .get (j ).close ();
237
+ onnxTensors .close ();
238
+ throw new RunModelException ("Error converting tensor '" + tt .getName () + "' into RAI" + Types .stackTrace (e ));
193
239
}
194
240
}
195
241
}
0 commit comments