Commit f686c26

add inference
1 parent c7b2984 commit f686c26

File tree

3 files changed (+187, −37)

src/main/java/io/bioimage/modelrunner/pytorch/JavaWorker.java (+12, −2)
src/main/java/io/bioimage/modelrunner/pytorch/PytorchInterface.java (+143, −19)
src/main/java/io/bioimage/modelrunner/pytorch/shm/ShmBuilder.java (+32, −16)

src/main/java/io/bioimage/modelrunner/pytorch/JavaWorker.java

+12 −2
@@ -16,6 +16,8 @@ public class JavaWorker {
 
     private static LinkedHashMap<String, Object> tasks = new LinkedHashMap<String, Object>();
 
+    private Map<String, Object> outputs;
+
     private final String uuid;
 
     private final PytorchInterface pi;
@@ -90,10 +92,18 @@ private void executeScript(String script, Map<String, Object> inputs) {
         try {
             if (script.equals("loadModel")) {
                 pi.loadModel((String) inputs.get("modelFolder"), (String) inputs.get("modelSource"));
-            } else if (script.equals("inference")) {
+            } else if (script.equals("run")) {
                 pi.runFromShmas((List<String>) inputs.get("inputs"), (List<String>) inputs.get("outputs"));
+            } else if (script.equals("inference")) {
+                List<String> encodedOutputs = pi.inferenceFromShmas((List<String>) inputs.get("inputs"));
+                outputs = new HashMap<String, Object>();
+                HashMap<String, List<String>> out = new HashMap<String, List<String>>();
+                out.put("encoded", encodedOutputs);
+                outputs.put("outputs", out);
             } else if (script.equals("close")) {
                 pi.closeModel();
+            } else if (script.equals("closeTensors")) {
+                pi.closeFromInterp();
             }
         } catch(Exception | Error ex) {
             this.fail(Types.stackTrace(ex));
@@ -111,7 +121,7 @@ private void reportLaunch() {
     }
 
     private void reportCompletion() {
-        respond(ResponseType.COMPLETION, null);
+        respond(ResponseType.COMPLETION, outputs);
    }
 
     private void update(String message, Integer current, Integer maximum) {
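Note (not part of the commit): with the new `inference` branch the worker replies to COMPLETION with a payload instead of `null`. A minimal sketch of that payload's shape, built from plain maps and lists; only the key names (`outputs`, `encoded`) come from this diff, everything else is illustrative:

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class InferencePayloadSketch {

    /** Mirror the nested structure the worker sends back: {"outputs": {"encoded": [...]}}. */
    public static Map<String, Object> completionPayload(List<String> encodedOutputs) {
        Map<String, List<String>> out = new HashMap<>();
        out.put("encoded", encodedOutputs);   // JSON descriptors of the output shared-memory regions
        Map<String, Object> outputs = new HashMap<>();
        outputs.put("outputs", out);          // wrapper the task runner exposes as task.outputs
        return outputs;
    }

    public static void main(String[] args) {
        System.out.println(completionPayload(Arrays.asList("{\"memoryName\":\"shm_example\"}")));
    }
}
```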

src/main/java/io/bioimage/modelrunner/pytorch/PytorchInterface.java

+143 −19
@@ -27,7 +27,6 @@
 import io.bioimage.modelrunner.engine.DeepLearningEngineInterface;
 import io.bioimage.modelrunner.exceptions.LoadModelException;
 import io.bioimage.modelrunner.exceptions.RunModelException;
-import io.bioimage.modelrunner.numpy.DecodeNumpy;
 import io.bioimage.modelrunner.pytorch.shm.ShmBuilder;
 import io.bioimage.modelrunner.pytorch.shm.TensorBuilder;
 import io.bioimage.modelrunner.pytorch.tensor.ImgLib2Builder;
@@ -37,7 +36,6 @@
 import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
 import io.bioimage.modelrunner.utils.CommonUtils;
 import net.imglib2.RandomAccessibleInterval;
-import net.imglib2.img.array.ArrayImgs;
 import net.imglib2.type.NativeType;
 import net.imglib2.type.numeric.RealType;
 import net.imglib2.util.Cast;
@@ -246,6 +244,34 @@ private static String getModelName(String modelSource) throws IOException {
         return modelName.substring(0, ind);
     }
 
+    @Override
+    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<RandomAccessibleInterval<R>> inference(
+        List<RandomAccessibleInterval<T>> inputs) throws RunModelException {
+        if (interprocessing) {
+            return runInterprocessing(inputs);
+        }
+        try (NDManager manager = NDManager.newBaseManager()) {
+            // Create the input lists of engine tensors (NDArrays) and their
+            // corresponding names
+            NDList inputList = new NDList();
+            for (RandomAccessibleInterval<T> tt : inputs) {
+                inputList.add(NDArrayBuilder.build(tt, manager));
+            }
+            // Run model
+            Predictor<NDList, NDList> predictor = model.newPredictor();
+            NDList outputNDArrays = predictor.predict(inputList);
+            // Fill the agnostic output tensors list with data from the inference
+            // result
+            List<RandomAccessibleInterval<R>> outs = new ArrayList<RandomAccessibleInterval<R>>();
+            for (int i = 0; i < outputNDArrays.size(); i ++)
+                outs.add(ImgLib2Builder.build(outputNDArrays.get(i)));
+            return outs;
+        }
+        catch (TranslateException e) {
+            throw new RunModelException(Types.stackTrace(e));
+        }
+    }
+
     /**
      * {@inheritDoc}
      *
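Note (not part of the commit): a hedged usage sketch of the new `inference(List<RandomAccessibleInterval<T>>)` entry point. The model paths, the no-arg constructor, and the 1×3×224×224 input shape are assumptions for illustration; only `loadModel`, `inference`, and `closeModel` appear in this diff, and running it still requires the PyTorch/DJL natives to be available:

```java
import java.util.Arrays;
import java.util.List;

import io.bioimage.modelrunner.pytorch.PytorchInterface;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.numeric.real.FloatType;

public class DirectInferenceExample {

    public static void main(String[] args) throws Exception {
        PytorchInterface pi = new PytorchInterface();
        // Hypothetical paths; point these at a real TorchScript model folder and weights file.
        pi.loadModel("/models/my-model", "/models/my-model/weights.pt");
        // A zero-filled float image just to exercise the call.
        RandomAccessibleInterval<FloatType> input = ArrayImgs.floats(1, 3, 224, 224);
        List<RandomAccessibleInterval<FloatType>> outs = pi.inference(Arrays.asList(input));
        System.out.println("number of outputs: " + outs.size());
        pi.closeModel();
    }
}
```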
@@ -310,21 +336,36 @@ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IO
         }
     }
 
-    /**
-     * MEthod only used in MacOS Intel and Windows systems that makes all the arrangements
-     * to create another process, communicate the model info and tensors to the other
-     * process and then retrieve the results of the other process
-     * @param <T>
-     *     ImgLib2 data type of the input tensors
-     * @param <R>
-     *     ImgLib2 data type of the output tensors, it can be the same as the input tensors' data type
-     * @param inputTensors
-     *     tensors that are going to be run on the model
-     * @param outputTensors
-     *     expected results of the model
-     * @throws RunModelException if there is any issue running the model
-     */
-    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
+    protected List<String> inferenceFromShmas(List<String> inputs) throws IOException, RunModelException {
+        try (NDManager manager = NDManager.newBaseManager()) {
+            // Create the input lists of engine tensors (NDArrays) and their
+            // corresponding names
+            NDList inputList = new NDList();
+            for (String ee : inputs) {
+                Map<String, Object> decoded = Types.decode(ee);
+                SharedMemoryArray shma = SharedMemoryArray.read((String) decoded.get(MEM_NAME_KEY));
+                NDArray inT = TensorBuilder.build(shma, manager);
+                if (PlatformDetection.isWindows()) shma.close();
+                inputList.add(inT);
+            }
+            // Run model
+            Predictor<NDList, NDList> predictor = model.newPredictor();
+            NDList outputNDArrays = predictor.predict(inputList);
+
+            shmaNamesList = new ArrayList<String>();
+            for (int i = 0; i < outputNDArrays.size(); i ++) {
+                String name = SharedMemoryArray.createShmName();
+                ShmBuilder.build(outputNDArrays.get(i), name, false);
+                shmaNamesList.add(name);
+            }
+            return shmaNamesList;
+        }
+        catch (TranslateException e) {
+            throw new RunModelException(Types.stackTrace(e));
+        }
+    }
+
+    private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
     void runInterprocessing(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTensors) throws RunModelException {
         shmaInputList = new ArrayList<SharedMemoryArray>();
         shmaOutputList = new ArrayList<SharedMemoryArray>();
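Note (not part of the commit): each string handed to `inferenceFromShmas` is a small JSON document built with Gson on the caller side (see `runInterprocessing(List<RandomAccessibleInterval<T>>)` below): shape, dtype, an is-input flag, and the shared-memory region name. A minimal sketch of that encoding; the key strings below are placeholders standing in for the real `SHAPE_KEY`/`DTYPE_KEY`/`IS_INPUT_KEY`/`MEM_NAME_KEY` constants, whose actual values are not shown in this diff:

```java
import java.util.HashMap;
import java.util.Map;

import com.google.gson.Gson;

public class ShmDescriptorSketch {

    // Placeholder key names; the real constants live in the model-runner core.
    static final String SHAPE_KEY = "shape";
    static final String DTYPE_KEY = "dtype";
    static final String IS_INPUT_KEY = "isInput";
    static final String MEM_NAME_KEY = "memoryName";

    public static String encode(long[] shape, String dtype, String shmName) {
        Map<String, Object> map = new HashMap<>();
        map.put(SHAPE_KEY, shape);        // tensor dimensions
        map.put(DTYPE_KEY, dtype);        // e.g. "float32"
        map.put(IS_INPUT_KEY, true);      // descriptor describes a model input
        map.put(MEM_NAME_KEY, shmName);   // name of the shared-memory region holding the data
        return new Gson().toJson(map);    // one entry of the "inputs" list sent to the worker
    }

    public static void main(String[] args) {
        System.out.println(encode(new long[] {1, 3, 224, 224}, "float32", "shm_example"));
    }
}
```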
@@ -335,7 +376,7 @@ void runInterprocessing(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTens
         args.put("outputs", encOuts);
 
         try {
-            Task task = runner.task("inference", args);
+            Task task = runner.task("run", args);
             task.waitFor();
             if (task.status == TaskStatus.CANCELED)
                 throw new RuntimeException();
@@ -365,6 +406,76 @@ else if (task.status == TaskStatus.CRASHED) {
         }
         closeShmas();
     }
+
+    private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
+    List<RandomAccessibleInterval<R>> runInterprocessing(List<RandomAccessibleInterval<T>> inputs) throws RunModelException {
+        shmaInputList = new ArrayList<SharedMemoryArray>();
+        List<String> encIns = new ArrayList<String>();
+        Gson gson = new Gson();
+        for (RandomAccessibleInterval<T> tt : inputs) {
+            SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt, false, true);
+            shmaInputList.add(shma);
+            HashMap<String, Object> map = new HashMap<String, Object>();
+            map.put(SHAPE_KEY, tt.dimensionsAsLongArray());
+            map.put(DTYPE_KEY, CommonUtils.getDataTypeFromRAI(tt));
+            map.put(IS_INPUT_KEY, true);
+            map.put(MEM_NAME_KEY, shma.getName());
+            encIns.add(gson.toJson(map));
+        }
+        LinkedHashMap<String, Object> args = new LinkedHashMap<String, Object>();
+        args.put("inputs", encIns);
+
+        try {
+            Task task = runner.task("inference", args);
+            task.waitFor();
+            if (task.status == TaskStatus.CANCELED)
+                throw new RuntimeException();
+            else if (task.status == TaskStatus.FAILED)
+                throw new RuntimeException(task.error);
+            else if (task.status == TaskStatus.CRASHED) {
+                this.runner.close();
+                runner = null;
+                throw new RuntimeException(task.error);
+            } else if (task.outputs == null)
+                throw new RuntimeException("No outputs generated");
+            List<String> outputs = (List<String>) task.outputs.get("encoded");
+            List<RandomAccessibleInterval<R>> rais = new ArrayList<RandomAccessibleInterval<R>>();
+            for (String out : outputs) {
+                String name = (String) Types.decode(out).get(MEM_NAME_KEY);
+                SharedMemoryArray shm = SharedMemoryArray.read(name);
+                RandomAccessibleInterval<R> rai = shm.getSharedRAI();
+                rais.add(Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(rai), Util.getTypeFromInterval(Cast.unchecked(rai))));
+                shm.close();
+            }
+            closeShmas();
+            return rais;
+        } catch (Exception e) {
+            closeShmas();
+            if (e instanceof RunModelException)
+                throw (RunModelException) e;
+            throw new RunModelException(Types.stackTrace(e));
+        }
+    }
+
+    private void closeInterprocess() throws RunModelException {
+        try {
+            Task task = runner.task("closeTensors");
+            task.waitFor();
+            if (task.status == TaskStatus.CANCELED)
+                throw new RuntimeException();
+            else if (task.status == TaskStatus.FAILED)
+                throw new RuntimeException(task.error);
+            else if (task.status == TaskStatus.CRASHED) {
+                this.runner.close();
+                runner = null;
+                throw new RuntimeException(task.error);
+            }
+        } catch (Exception e) {
+            if (e instanceof RunModelException)
+                throw (RunModelException) e;
+            throw new RunModelException(Types.stackTrace(e));
+        }
+    }
 
     /**
      * Create the list a list of output tensors agnostic to the Deep Learning
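Note (not part of the commit): the parent process reads each worker-produced output region and copies it into ordinary ImgLib2 storage before calling `shm.close()`, so the result outlives the shared memory. The diff does this via `Tensor.createCopyOfRaiInWantedDataType`; the snippet below is a generic ImgLib2 sketch of the same detach-by-copy step for the float case, using a hypothetical helper name rather than the repo's API:

```java
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.Img;
import net.imglib2.img.array.ArrayImgFactory;
import net.imglib2.loops.LoopBuilder;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Intervals;

public class CopyOutOfShmSketch {

    /** Copy a float RAI into heap-backed storage so the backing shared memory can be released. */
    public static Img<FloatType> detach(RandomAccessibleInterval<FloatType> sharedBacked) {
        Img<FloatType> copy = new ArrayImgFactory<>(new FloatType())
                .create(Intervals.dimensionsAsLongArray(sharedBacked));
        LoopBuilder.setImages(sharedBacked, copy).forEachPixel((s, t) -> t.set(s));
        return copy; // safe to use after the shared-memory region is closed
    }
}
```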
@@ -391,7 +502,19 @@ void fillOutputTensors(NDList outputNDArrays,
         }
     }
 
-    private void closeShmas() {
+    protected void closeFromInterp() {
+        if (!PlatformDetection.isWindows())
+            return;
+        this.shmaNamesList.stream().forEach(nn -> {
+            try {
+                SharedMemoryArray.read(nn).close();
+            } catch (IOException e) {
+                e.printStackTrace();
+            }
+        });
+    }
+
+    private void closeShmas() throws RunModelException {
         shmaInputList.forEach(shm -> {
             try { shm.close(); } catch (IOException e1) { e1.printStackTrace();}
         });
@@ -400,6 +523,7 @@ private void closeShmas() {
             try { shm.close(); } catch (IOException e1) { e1.printStackTrace();}
         });
         shmaOutputList = null;
+        closeInterprocess();
     }
 
 
src/main/java/io/bioimage/modelrunner/pytorch/shm/ShmBuilder.java

+32 −16
@@ -58,42 +58,58 @@ private ShmBuilder()
      * @throws IllegalArgumentException if the data type of the tensor is not supported
      * @throws IOException if there is any error creating the shared memory array
      */
-    public static void build(NDArray tensor, String memoryName) throws IllegalArgumentException, IOException
+    public static void build(NDArray tensor, String memoryName) throws IllegalArgumentException, IOException
+    {
+        build(tensor, memoryName, true);
+    }
+
+    /**
+     * Create a {@link SharedMemoryArray} from a {@link NDArray}
+     * @param tensor
+     *     the tensor to be passed into the other process through the shared memory
+     * @param memoryName
+     *     the name of the memory region where the tensor is going to be copied
+     * @param close
+     *     for windows, whether to close the shm once it has been copied
+     * @throws IllegalArgumentException if the data type of the tensor is not supported
+     * @throws IOException if there is any error creating the shared memory array
+     */
+    public static void build(NDArray tensor, String memoryName, boolean close) throws IllegalArgumentException, IOException
     {
         switch (tensor.getDataType())
         {
             case UINT8:
-                buildFromTensorUByte(tensor, memoryName);
+                buildFromTensorUByte(tensor, memoryName, close);
                 break;
             case INT32:
-                buildFromTensorInt(tensor, memoryName);
+                buildFromTensorInt(tensor, memoryName, close);
                 break;
             case FLOAT32:
-                buildFromTensorFloat(tensor, memoryName);
+                buildFromTensorFloat(tensor, memoryName, close);
                 break;
             case FLOAT64:
-                buildFromTensorDouble(tensor, memoryName);
+                buildFromTensorDouble(tensor, memoryName, close);
                 break;
             case INT64:
-                buildFromTensorLong(tensor, memoryName);
+                buildFromTensorLong(tensor, memoryName, close);
                 break;
             default:
                 throw new IllegalArgumentException("Unsupported tensor type: " + tensor.getDataType().asNumpy());
         }
     }
 
-    private static void buildFromTensorUByte(NDArray tensor, String memoryName) throws IOException
+    private static void buildFromTensorUByte(NDArray tensor, String memoryName, boolean close) throws IOException
     {
         long[] arrayShape = tensor.getShape().getShape();
         if (CommonUtils.int32Overflows(arrayShape, 1))
             throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
                 + " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
         SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
         shma.getDataBufferNoHeader().put(tensor.toByteArray());
-        if (PlatformDetection.isWindows()) shma.close();
+        if (PlatformDetection.isWindows() && close) shma.close();
     }
 
-    private static void buildFromTensorInt(NDArray tensor, String memoryName) throws IOException
+    private static void buildFromTensorInt(NDArray tensor, String memoryName, boolean close) throws IOException
     {
         long[] arrayShape = tensor.getShape().getShape();
         if (CommonUtils.int32Overflows(arrayShape, 4))
@@ -102,10 +118,10 @@ private static void buildFromTensorInt(NDArray tensor, String memoryName) throws
 
         SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true);
         shma.getDataBufferNoHeader().put(tensor.toByteArray());
-        if (PlatformDetection.isWindows()) shma.close();
+        if (PlatformDetection.isWindows() && close) shma.close();
     }
 
-    private static void buildFromTensorFloat(NDArray tensor, String memoryName) throws IOException
+    private static void buildFromTensorFloat(NDArray tensor, String memoryName, boolean close) throws IOException
     {
         long[] arrayShape = tensor.getShape().getShape();
         if (CommonUtils.int32Overflows(arrayShape, 4))
@@ -114,10 +130,10 @@ private static void buildFromTensorFloat(NDArray tensor, String memoryName) thro
 
         SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true);
         shma.getDataBufferNoHeader().put(tensor.toByteArray());
-        if (PlatformDetection.isWindows()) shma.close();
+        if (PlatformDetection.isWindows() && close) shma.close();
     }
 
-    private static void buildFromTensorDouble(NDArray tensor, String memoryName) throws IOException
+    private static void buildFromTensorDouble(NDArray tensor, String memoryName, boolean close) throws IOException
     {
         long[] arrayShape = tensor.getShape().getShape();
         if (CommonUtils.int32Overflows(arrayShape, 8))
@@ -126,10 +142,10 @@ private static void buildFromTensorDouble(NDArray tensor, String memoryName) thr
 
         SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new DoubleType(), false, true);
         shma.getDataBufferNoHeader().put(tensor.toByteArray());
-        if (PlatformDetection.isWindows()) shma.close();
+        if (PlatformDetection.isWindows() && close) shma.close();
     }
 
-    private static void buildFromTensorLong(NDArray tensor, String memoryName) throws IOException
+    private static void buildFromTensorLong(NDArray tensor, String memoryName, boolean close) throws IOException
     {
         long[] arrayShape = tensor.getShape().getShape();
         if (CommonUtils.int32Overflows(arrayShape, 8))
@@ -139,6 +155,6 @@ private static void buildFromTensorLong(NDArray tensor, String memoryName) throw
 
         SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new LongType(), false, true);
         shma.getDataBufferNoHeader().put(tensor.toByteArray());
-        if (PlatformDetection.isWindows()) shma.close();
+        if (PlatformDetection.isWindows() && close) shma.close();
     }
 }
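Note (not part of the commit): a hedged sketch of the new `ShmBuilder.build(tensor, name, close)` overload. Passing `close=false` keeps the Windows file mapping open so the parent process can still read the region; the worker releases it later via `PytorchInterface.closeFromInterp()` when the "closeTensors" task arrives. The shape and dtype below are illustrative:

```java
import java.io.IOException;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import io.bioimage.modelrunner.pytorch.shm.ShmBuilder;
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;

public class ShmBuilderExample {

    public static void main(String[] args) throws IOException {
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray tensor = manager.ones(new Shape(1, 3, 4, 4)); // FLOAT32 by default
            String name = SharedMemoryArray.createShmName();
            ShmBuilder.build(tensor, name, false); // keep the region open (relevant on Windows)
            // ... the consuming process would read the region here ...
            SharedMemoryArray.read(name).close(); // explicit cleanup once it has been consumed
        }
    }
}
```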
