Skip to content

Commit 4420cd5

Browse files
committed
add new inference method
1 parent 80951cb commit 4420cd5

File tree

4 files changed

+217
-33
lines changed

4 files changed

+217
-33
lines changed

src/main/java/io/bioimage/modelrunner/pytorch/javacpp/JavaWorker.java

+12-3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ public class JavaWorker {
1616

1717
private static LinkedHashMap<String, Object> tasks = new LinkedHashMap<String, Object>();
1818

19+
private Map<String, Object> outputs = new HashMap<String, Object>();
20+
1921
private final String uuid;
2022

2123
private final PytorchJavaCPPInterface pi;
@@ -88,10 +90,17 @@ private void executeScript(String script, Map<String, Object> inputs) {
8890
try {
8991
if (script.equals("loadModel")) {
9092
pi.loadModel((String) inputs.get("modelFolder"), (String) inputs.get("modelSource"));
91-
} else if (script.equals("inference")) {
93+
} else if (script.equals("run")) {
9294
pi.runFromShmas((List<String>) inputs.get("inputs"), (List<String>) inputs.get("outputs"));
95+
} else if (script.equals("inference")) {
96+
List<String> encodedOutputs = pi.inferenceFromShmas((List<String>) inputs.get("inputs"));
97+
HashMap<String, List<String>> out = new HashMap<String, List<String>>();
98+
out.put("encoded", encodedOutputs);
99+
outputs.put("outputs", out);
93100
} else if (script.equals("close")) {
94101
pi.closeModel();
102+
} else if (script.equals("closeTensors")) {
103+
pi.closeFromInterp();
95104
}
96105
} catch(Exception | Error ex) {
97106
this.fail(Types.stackTrace(ex));
@@ -109,7 +118,7 @@ private void reportLaunch() {
109118
}
110119

111120
private void reportCompletion() {
112-
respond(ResponseType.COMPLETION, null);
121+
respond(ResponseType.COMPLETION, outputs);
113122
}
114123

115124
private void update(String message, Integer current, Integer maximum) {
@@ -130,7 +139,7 @@ private void respond(ResponseType responseType, Map<String, Object> args) {
130139
Map<String, Object> response = new HashMap<String, Object>();
131140
response.put("task", uuid);
132141
response.put("responseType", responseType);
133-
if (args != null)
142+
if (args != null && args.keySet().size() > 0)
134143
response.putAll(args);
135144
try {
136145
System.out.println(Types.encode(response));

src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java

+172-13
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,52 @@ void run(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTensors) throws Run
237237
inputsVector.close();
238238
inputsVector.deallocate();
239239
}
240+
241+
/**
242+
* {@inheritDoc}
243+
*
244+
* Run a Pytorch model using JavaCpp on the data provided by the input list of {@link RandomAccessibleInterval}s
245+
* and returns a list of {@link RandomAccessibleInterval}s with the results obtained
246+
*
247+
*/
248+
public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
249+
List<RandomAccessibleInterval<R>> inference(List<RandomAccessibleInterval<T>> inputs) throws RunModelException {
250+
if (interprocessing) {
251+
return runInterprocessing(inputs);
252+
}
253+
IValueVector inputsVector = new IValueVector();
254+
for (RandomAccessibleInterval<T> tt : inputs) {
255+
inputsVector.put(new IValue(JavaCPPTensorBuilder.build(tt)));
256+
}
257+
// Run model
258+
model.eval();
259+
IValue output = model.forward(inputsVector);
260+
TensorVector outputTensorVector = null;
261+
if (output.isTensorList()) {
262+
outputTensorVector = output.toTensorVector();
263+
} else {
264+
outputTensorVector = new TensorVector();
265+
outputTensorVector.put(output.toTensor());
266+
}
267+
// Fill the agnostic output tensors list with data from the inference result
268+
List<RandomAccessibleInterval<R>> rais = new ArrayList<RandomAccessibleInterval<R>>();
269+
for (int i = 0; i < outputTensorVector.size(); i ++) {
270+
rais.add(ImgLib2Builder.build(outputTensorVector.get(i)));
271+
outputTensorVector.get(i).close();
272+
outputTensorVector.get(i).deallocate();
273+
}
274+
outputTensorVector.close();
275+
outputTensorVector.deallocate();
276+
output.close();
277+
output.deallocate();
278+
for (int i = 0; i < inputsVector.size(); i ++) {
279+
inputsVector.get(i).close();
280+
inputsVector.get(i).deallocate();
281+
}
282+
inputsVector.close();
283+
inputsVector.deallocate();
284+
return rais;
285+
}
240286

241287
protected void runFromShmas(List<String> inputs, List<String> outputs) throws IOException {
242288
IValueVector inputsVector = new IValueVector();
@@ -276,17 +322,46 @@ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IO
276322
inputsVector.deallocate();
277323
}
278324

279-
/**
280-
* Method only used in MacOS Intel and Windows systems that makes all the arrangements
281-
* to create another process, communicate the model info and tensors to the other
282-
* process and then retrieve the results of the other process
283-
* @param inputTensors
284-
* tensors that are going to be run on the model
285-
* @param outputTensors
286-
* expected results of the model
287-
* @throws RunModelException if there is any issue running the model
288-
*/
289-
public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
325+
protected List<String> inferenceFromShmas(List<String> inputs) throws IOException, RunModelException {
326+
IValueVector inputsVector = new IValueVector();
327+
for (String ee : inputs) {
328+
Map<String, Object> decoded = Types.decode(ee);
329+
SharedMemoryArray shma = SharedMemoryArray.read((String) decoded.get(MEM_NAME_KEY));
330+
org.bytedeco.pytorch.Tensor inT = TensorBuilder.build(shma);
331+
inputsVector.put(new IValue(inT));
332+
if (PlatformDetection.isWindows()) shma.close();
333+
}
334+
// Run model
335+
model.eval();
336+
IValue output = model.forward(inputsVector);
337+
TensorVector outputTensorVector = null;
338+
if (output.isTensorList()) {
339+
outputTensorVector = output.toTensorVector();
340+
} else {
341+
outputTensorVector = new TensorVector();
342+
outputTensorVector.put(output.toTensor());
343+
}
344+
345+
shmaNamesList = new ArrayList<String>();
346+
for (int i = 0; i < outputTensorVector.size(); i ++) {
347+
String name = SharedMemoryArray.createShmName();
348+
ShmBuilder.build(outputTensorVector.get(i), name, false);
349+
shmaNamesList.add(name);
350+
}
351+
outputTensorVector.close();
352+
outputTensorVector.deallocate();
353+
output.close();
354+
output.deallocate();
355+
for (int i = 0; i < inputsVector.size(); i ++) {
356+
inputsVector.get(i).close();
357+
inputsVector.get(i).deallocate();
358+
}
359+
inputsVector.close();
360+
inputsVector.deallocate();
361+
return shmaNamesList;
362+
}
363+
364+
private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
290365
void runInterprocessing(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTensors) throws RunModelException {
291366
shmaInputList = new ArrayList<SharedMemoryArray>();
292367
shmaOutputList = new ArrayList<SharedMemoryArray>();
@@ -297,7 +372,7 @@ void runInterprocessing(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTens
297372
args.put("outputs", encOuts);
298373

299374
try {
300-
Task task = runner.task("inference", args);
375+
Task task = runner.task("run", args);
301376
task.waitFor();
302377
if (task.status == TaskStatus.CANCELED)
303378
throw new RuntimeException();
@@ -328,7 +403,89 @@ else if (task.status == TaskStatus.CRASHED) {
328403
closeShmas();
329404
}
330405

331-
private void closeShmas() {
406+
private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
407+
List<RandomAccessibleInterval<R>> runInterprocessing(List<RandomAccessibleInterval<T>> inputs) throws RunModelException {
408+
shmaInputList = new ArrayList<SharedMemoryArray>();
409+
List<String> encIns = new ArrayList<String>();
410+
Gson gson = new Gson();
411+
for (RandomAccessibleInterval<T> tt : inputs) {
412+
SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt, false, true);
413+
shmaInputList.add(shma);
414+
HashMap<String, Object> map = new HashMap<String, Object>();
415+
map.put(SHAPE_KEY, tt.dimensionsAsLongArray());
416+
map.put(DTYPE_KEY, CommonUtils.getDataTypeFromRAI(tt));
417+
map.put(IS_INPUT_KEY, true);
418+
map.put(MEM_NAME_KEY, shma.getName());
419+
encIns.add(gson.toJson(map));
420+
}
421+
LinkedHashMap<String, Object> args = new LinkedHashMap<String, Object>();
422+
args.put("inputs", encIns);
423+
424+
try {
425+
Task task = runner.task("inference", args);
426+
task.waitFor();
427+
if (task.status == TaskStatus.CANCELED)
428+
throw new RuntimeException();
429+
else if (task.status == TaskStatus.FAILED)
430+
throw new RuntimeException(task.error);
431+
else if (task.status == TaskStatus.CRASHED) {
432+
this.runner.close();
433+
runner = null;
434+
throw new RuntimeException(task.error);
435+
} else if (task.outputs == null)
436+
throw new RuntimeException("No outputs generated");
437+
List<String> outputs = (List<String>) task.outputs.get("encoded");
438+
List<RandomAccessibleInterval<R>> rais = new ArrayList<RandomAccessibleInterval<R>>();
439+
for (String out : outputs) {
440+
String name = (String) Types.decode(out).get(MEM_NAME_KEY);
441+
SharedMemoryArray shm = SharedMemoryArray.read(name);
442+
RandomAccessibleInterval<R> rai = shm.getSharedRAI();
443+
rais.add(Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(rai), Util.getTypeFromInterval(Cast.unchecked(rai))));
444+
shm.close();
445+
}
446+
closeShmas();
447+
return rais;
448+
} catch (Exception e) {
449+
closeShmas();
450+
if (e instanceof RunModelException)
451+
throw (RunModelException) e;
452+
throw new RunModelException(Types.stackTrace(e));
453+
}
454+
}
455+
456+
private void closeInterprocess() throws RunModelException {
457+
try {
458+
Task task = runner.task("closeTensors");
459+
task.waitFor();
460+
if (task.status == TaskStatus.CANCELED)
461+
throw new RuntimeException();
462+
else if (task.status == TaskStatus.FAILED)
463+
throw new RuntimeException(task.error);
464+
else if (task.status == TaskStatus.CRASHED) {
465+
this.runner.close();
466+
runner = null;
467+
throw new RuntimeException(task.error);
468+
}
469+
} catch (Exception e) {
470+
if (e instanceof RunModelException)
471+
throw (RunModelException) e;
472+
throw new RunModelException(Types.stackTrace(e));
473+
}
474+
}
475+
476+
protected void closeFromInterp() {
477+
if (!PlatformDetection.isWindows())
478+
return;
479+
this.shmaNamesList.stream().forEach(nn -> {
480+
try {
481+
SharedMemoryArray.read(nn).close();
482+
} catch (IOException e) {
483+
e.printStackTrace();
484+
}
485+
});
486+
}
487+
488+
private void closeShmas() throws RunModelException {
332489
shmaInputList.forEach(shm -> {
333490
try { shm.close(); } catch (IOException e1) { e1.printStackTrace();}
334491
});
@@ -337,6 +494,8 @@ private void closeShmas() {
337494
try { shm.close(); } catch (IOException e1) { e1.printStackTrace();}
338495
});
339496
shmaOutputList = null;
497+
if (interprocessing)
498+
closeInterprocess();
340499
}
341500

342501
/**

src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java

+31-15
Original file line numberDiff line numberDiff line change
@@ -66,24 +66,40 @@ private ShmBuilder()
6666
* @throws IOException if there is any error creating the shared memory array
6767
*/
6868
public static void build(Tensor tensor, String memoryName) throws IllegalArgumentException, IOException
69+
{
70+
build(tensor, memoryName, true);
71+
}
72+
73+
/**
74+
* Create a {@link SharedMemoryArray} from a {@link Tensor}
75+
* @param tensor
76+
* the tensor to be passed into the other process through the shared memory
77+
* @param memoryName
78+
* the name of the memory region where the tensor is going to be copied
79+
* @param close
80+
* on Windows, whether to close the shma after reading or creating it
81+
* @throws IllegalArgumentException if the data type of the tensor is not supported
82+
* @throws IOException if there is any error creating the shared memory array
83+
*/
84+
public static void build(Tensor tensor, String memoryName, boolean close) throws IllegalArgumentException, IOException
6985
{
7086
if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Byte)
7187
|| tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Char)) {
72-
buildFromTensorByte(tensor, memoryName);
88+
buildFromTensorByte(tensor, memoryName, close);
7389
} else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Int)) {
74-
buildFromTensorInt(tensor, memoryName);
90+
buildFromTensorInt(tensor, memoryName, close);
7591
} else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Float)) {
76-
buildFromTensorFloat(tensor, memoryName);
92+
buildFromTensorFloat(tensor, memoryName, close);
7793
} else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Double)) {
78-
buildFromTensorDouble(tensor, memoryName);
94+
buildFromTensorDouble(tensor, memoryName, close);
7995
} else if (tensor.dtype().isScalarType(org.bytedeco.pytorch.global.torch.ScalarType.Long)) {
80-
buildFromTensorLong(tensor, memoryName);
96+
buildFromTensorLong(tensor, memoryName, close);
8197
} else {
8298
throw new IllegalArgumentException("Unsupported tensor type: " + tensor.scalar_type());
8399
}
84100
}
85101

86-
private static void buildFromTensorByte(Tensor tensor, String memoryName) throws IOException
102+
private static void buildFromTensorByte(Tensor tensor, String memoryName, boolean close) throws IOException
87103
{
88104
long[] arrayShape = tensor.shape();
89105
if (CommonUtils.int32Overflows(arrayShape, 1))
@@ -97,10 +113,10 @@ private static void buildFromTensorByte(Tensor tensor, String memoryName) throws
97113
tensor.data_ptr_byte().get(flat);
98114
byteBuffer.put(flat);
99115
shma.getDataBufferNoHeader().put(byteBuffer);
100-
if (PlatformDetection.isWindows()) shma.close();
116+
if (PlatformDetection.isWindows() && close) shma.close();
101117
}
102118

103-
private static void buildFromTensorInt(Tensor tensor, String memoryName) throws IOException
119+
private static void buildFromTensorInt(Tensor tensor, String memoryName, boolean close) throws IOException
104120
{
105121
long[] arrayShape = tensor.shape();
106122
if (CommonUtils.int32Overflows(arrayShape, 4))
@@ -115,10 +131,10 @@ private static void buildFromTensorInt(Tensor tensor, String memoryName) throws
115131
tensor.data_ptr_int().get(flat);
116132
floatBuffer.put(flat);
117133
shma.getDataBufferNoHeader().put(byteBuffer);
118-
if (PlatformDetection.isWindows()) shma.close();
134+
if (PlatformDetection.isWindows() && close) shma.close();
119135
}
120136

121-
private static void buildFromTensorFloat(Tensor tensor, String memoryName) throws IOException
137+
private static void buildFromTensorFloat(Tensor tensor, String memoryName, boolean close) throws IOException
122138
{
123139
long[] arrayShape = tensor.shape();
124140
if (CommonUtils.int32Overflows(arrayShape, 4))
@@ -133,10 +149,10 @@ private static void buildFromTensorFloat(Tensor tensor, String memoryName) throw
133149
tensor.data_ptr_float().get(flat);
134150
floatBuffer.put(flat);
135151
shma.getDataBufferNoHeader().put(byteBuffer);
136-
if (PlatformDetection.isWindows()) shma.close();
152+
if (PlatformDetection.isWindows() && close) shma.close();
137153
}
138154

139-
private static void buildFromTensorDouble(Tensor tensor, String memoryName) throws IOException
155+
private static void buildFromTensorDouble(Tensor tensor, String memoryName, boolean close) throws IOException
140156
{
141157
long[] arrayShape = tensor.shape();
142158
if (CommonUtils.int32Overflows(arrayShape, 8))
@@ -151,10 +167,10 @@ private static void buildFromTensorDouble(Tensor tensor, String memoryName) thro
151167
tensor.data_ptr_double().get(flat);
152168
floatBuffer.put(flat);
153169
shma.getDataBufferNoHeader().put(byteBuffer);
154-
if (PlatformDetection.isWindows()) shma.close();
170+
if (PlatformDetection.isWindows() && close) shma.close();
155171
}
156172

157-
private static void buildFromTensorLong(Tensor tensor, String memoryName) throws IOException
173+
private static void buildFromTensorLong(Tensor tensor, String memoryName, boolean close) throws IOException
158174
{
159175
long[] arrayShape = tensor.shape();
160176
if (CommonUtils.int32Overflows(arrayShape, 8))
@@ -169,6 +185,6 @@ private static void buildFromTensorLong(Tensor tensor, String memoryName) throws
169185
tensor.data_ptr_long().get(flat);
170186
floatBuffer.put(flat);
171187
shma.getDataBufferNoHeader().put(byteBuffer);
172-
if (PlatformDetection.isWindows()) shma.close();
188+
if (PlatformDetection.isWindows() && close) shma.close();
173189
}
174190
}

src/main/java/io/bioimage/modelrunner/pytorch/javacpp/tensor/JavaCPPTensorBuilder.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ public class JavaCPPTensorBuilder {
5858
*/
5959
public static < T extends RealType< T > & NativeType< T > > org.bytedeco.pytorch.Tensor build(Tensor<T> tensor) throws IllegalArgumentException
6060
{
61-
return buildFromRai(tensor.getData());
61+
return build(tensor.getData());
6262
}
6363

6464
/**
@@ -71,7 +71,7 @@ public static < T extends RealType< T > & NativeType< T > > org.bytedeco.pytorch
7171
* @return The {@link org.bytedeco.pytorch.Tensor} built from the {@link RandomAccessibleInterval}.
7272
* @throws IllegalArgumentException if the {@link RandomAccessibleInterval} is not supported
7373
*/
74-
public static < T extends RealType< T > & NativeType< T > > org.bytedeco.pytorch.Tensor buildFromRai(RandomAccessibleInterval<T> tensor) throws IllegalArgumentException
74+
public static < T extends RealType< T > & NativeType< T > > org.bytedeco.pytorch.Tensor build(RandomAccessibleInterval<T> tensor) throws IllegalArgumentException
7575
{
7676
if (Util.getTypeFromInterval(tensor) instanceof ByteType) {
7777
return buildFromTensorByte(Cast.unchecked(tensor));

0 commit comments

Comments
 (0)