Skip to content

Commit 144d4ad

Browse files
committed
add inference mode
1 parent b3a40b6 commit 144d4ad

File tree

2 files changed

+58
-12
lines changed

2 files changed

+58
-12
lines changed

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@
9292
<!-- NB: Deploy releases to the SciJava Maven repository. -->
9393
<releaseProfiles>sign,deploy-to-scijava</releaseProfiles>
9494

95-
<dl-modelrunner.version>0.5.8</dl-modelrunner.version>
95+
<dl-modelrunner.version>0.5.11-SNAPSHOT</dl-modelrunner.version>
9696
<onnxruntime.version>1.12.1</onnxruntime.version>
9797
</properties>
9898

src/main/java/io/bioimage/modelrunner/onnx/OnnxInterface.java

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,17 @@
2020
*/
2121
package io.bioimage.modelrunner.onnx;
2222

23+
import io.bioimage.modelrunner.apposed.appose.Types;
2324
import io.bioimage.modelrunner.engine.DeepLearningEngineInterface;
2425
import io.bioimage.modelrunner.exceptions.LoadModelException;
2526
import io.bioimage.modelrunner.exceptions.RunModelException;
2627
import io.bioimage.modelrunner.onnx.tensor.ImgLib2Builder;
2728
import io.bioimage.modelrunner.onnx.tensor.TensorBuilder;
2829
import io.bioimage.modelrunner.tensor.Tensor;
2930
import net.imglib2.RandomAccessibleInterval;
30-
import net.imglib2.img.array.ArrayImgs;
31-
import net.imglib2.type.numeric.real.FloatType;
31+
import net.imglib2.type.NativeType;
32+
import net.imglib2.type.numeric.RealType;
3233

33-
import java.io.File;
3434
import java.util.ArrayList;
3535
import java.util.Iterator;
3636
import java.util.LinkedHashMap;
@@ -85,6 +85,7 @@ public OnnxInterface()
8585
{
8686
}
8787

88+
/**
8889
public static void main(String args[]) throws LoadModelException, RunModelException {
8990
String folderName = "/home/carlos/git/deep-icy/models/NucleiSegmentationBoundaryModel_27112023_190556";
9091
String source = folderName + File.separator + "weights.onnx";
@@ -104,6 +105,7 @@ public static void main(String args[]) throws LoadModelException, RunModelExcept
104105
oi.run(inps, outs);
105106
System.out.println(false);
106107
}
108+
*/
107109

108110
/**
109111
* {@inheritDoc}
@@ -132,10 +134,12 @@ public void loadModel(String modelFolder, String modelSource) throws LoadModelEx
132134
*
133135
* Run a Onnx model on the data provided by the {@link Tensor} input list
134136
* and modifies the output list with the results obtained
137+
* @throws RunModelException
135138
*
136139
*/
137140
@Override
138-
public void run(List<Tensor<?>> inputTensors, List<Tensor<?>> outputTensors) throws RunModelException {
141+
public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
142+
void run(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTensors) throws RunModelException {
139143
Result output;
140144
LinkedHashMap<String, OnnxTensor> inputMap = new LinkedHashMap<String, OnnxTensor>();
141145
Iterator<String> inputNames = session.getInputNames().iterator();
@@ -160,10 +164,47 @@ public void run(List<Tensor<?>> inputTensors, List<Tensor<?>> outputTensors) thr
160164
for (OnnxTensor tt : inputMap.values()) {
161165
tt.close();
162166
}
163-
for (Object tt : output) {
164-
tt = null;
167+
output.close();
168+
}
169+
170+
@Override
171+
public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<RandomAccessibleInterval<R>> inference(
172+
List<RandomAccessibleInterval<T>> inputs) throws RunModelException {
173+
Result output;
174+
LinkedHashMap<String, OnnxTensor> inputMap = new LinkedHashMap<String, OnnxTensor>();
175+
Iterator<String> inputNames = session.getInputNames().iterator();
176+
try {
177+
for (RandomAccessibleInterval<T> tt : inputs) {
178+
OnnxTensor inT = TensorBuilder.build(tt, env);
179+
inputMap.put(inputNames.next(), inT);
180+
}
181+
output = session.run(inputMap);
182+
} catch (OrtException ex) {
183+
for (OnnxTensor tt : inputMap.values()) {
184+
tt.close();
185+
}
186+
throw new RunModelException("Error trying to run an Onnx model."
187+
+ System.lineSeparator() + Types.stackTrace(ex));
188+
}
189+
for (OnnxTensor tt : inputMap.values()) {
190+
tt.close();
191+
}
192+
193+
// Fill the agnostic output tensors list with data from the inference result
194+
List<RandomAccessibleInterval<R>> rais = new ArrayList<RandomAccessibleInterval<R>>();
195+
for (int i = 0; i < output.size(); i ++) {
196+
try {
197+
rais.add(ImgLib2Builder.build(output.get(i).getValue()));
198+
output.get(i).close();
199+
} catch (IllegalArgumentException | OrtException e) {
200+
for (int j = i; j < output.size(); j ++)
201+
output.get(j).close();
202+
output.close();
203+
throw new RunModelException("Error converting tensor into RAI" + Types.stackTrace(e));
204+
}
165205
}
166206
output.close();
207+
return rais;
167208
}
168209

169210
/**
@@ -179,17 +220,22 @@ public void run(List<Tensor<?>> inputTensors, List<Tensor<?>> outputTensors) thr
179220
* @throws RunModelException If the number of tensors expected is not the same
180221
* as the number of Tensors outputed by the model
181222
*/
182-
public static void fillOutputTensors(Result onnxTensors, List<Tensor<?>> outputTensors) throws RunModelException{
223+
public static <T extends RealType<T> & NativeType<T>>
224+
void fillOutputTensors(Result onnxTensors,
225+
List<Tensor<T>> outputTensors) throws RunModelException {
183226
if (onnxTensors.size() != outputTensors.size())
184227
throw new RunModelException(onnxTensors.size(), outputTensors.size());
185228
int cc = 0;
186229
for (Tensor tt : outputTensors) {
187230
try {
188-
tt.setData(ImgLib2Builder.build(onnxTensors.get(cc ++).getValue()));
231+
tt.setData(ImgLib2Builder.build(onnxTensors.get(cc).getValue()));
232+
onnxTensors.get(cc).close();
233+
cc ++;
189234
} catch (IllegalArgumentException | OrtException e) {
190-
e.printStackTrace();
191-
throw new RunModelException("Unable to recover value of output tensor: " + tt.getName()
192-
+ System.lineSeparator() + e.getCause().toString());
235+
for (int j = cc; j < onnxTensors.size(); j ++)
236+
onnxTensors.get(j).close();
237+
onnxTensors.close();
238+
throw new RunModelException("Error converting tensor '" + tt.getName() + "' into RAI" + Types.stackTrace(e));
193239
}
194240
}
195241
}

0 commit comments

Comments
 (0)