Skip to content

Commit 0a7a43d

Browse files
committed
adapt to new inference method
1 parent 2b37f9b commit 0a7a43d

File tree

3 files changed

+216
-33
lines changed

3 files changed

+216
-33
lines changed

src/main/java/io/bioimage/modelrunner/tensorflow/v1/JavaWorker.java

+12-3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ public class JavaWorker {
1919
private final String uuid;
2020

2121
private final Tensorflow1Interface ti;
22+
23+
private Map<String, Object> outputs = new HashMap<String, Object>();
2224

2325
private boolean cancelRequested = false;
2426

@@ -88,10 +90,17 @@ private void executeScript(String script, Map<String, Object> inputs) {
8890
try {
8991
if (script.equals("loadModel")) {
9092
ti.loadModel((String) inputs.get("modelFolder"), null);
91-
} else if (script.equals("inference")) {
93+
} else if (script.equals("run")) {
9294
ti.runFromShmas((List<String>) inputs.get("inputs"), (List<String>) inputs.get("outputs"));
95+
} else if (script.equals("inference")) {
96+
List<String> encodedOutputs = ti.inferenceFromShmas((List<String>) inputs.get("inputs"));
97+
HashMap<String, List<String>> out = new HashMap<String, List<String>>();
98+
out.put("encoded", encodedOutputs);
99+
outputs.put("outputs", out);
93100
} else if (script.equals("close")) {
94101
ti.closeModel();
102+
} else if (script.equals("closeTensors")) {
103+
ti.closeFromInterp();
95104
}
96105
} catch(Exception | Error ex) {
97106
this.fail(Types.stackTrace(ex));
@@ -109,7 +118,7 @@ private void reportLaunch() {
109118
}
110119

111120
private void reportCompletion() {
112-
respond(ResponseType.COMPLETION, null);
121+
respond(ResponseType.COMPLETION, outputs);
113122
}
114123

115124
private void update(String message, Integer current, Integer maximum) {
@@ -130,7 +139,7 @@ private void respond(ResponseType responseType, Map<String, Object> args) {
130139
Map<String, Object> response = new HashMap<String, Object>();
131140
response.put("task", uuid);
132141
response.put("responseType", responseType);
133-
if (args != null)
142+
if (args != null && args.keySet().size() > 0)
134143
response.putAll(args);
135144
try {
136145
System.out.println(Types.encode(response));

src/main/java/io/bioimage/modelrunner/tensorflow/v1/Tensorflow1Interface.java

+172-14
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,51 @@ void run(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTensors)
316316
tt.close();
317317
}
318318
}
319+
320+
@Override
321+
public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<RandomAccessibleInterval<R>> inference(
322+
List<RandomAccessibleInterval<T>> inputs) throws RunModelException {
323+
if (interprocessing) {
324+
return runInterprocessing(inputs);
325+
}
326+
Session session = model.session();
327+
Session.Runner runner = session.runner();
328+
List<org.tensorflow.Tensor<?>> inTensors =
329+
new ArrayList<org.tensorflow.Tensor<?>>();
330+
int c = 0;
331+
for (RandomAccessibleInterval<T> tt : inputs) {
332+
org.tensorflow.Tensor<?> inT = TensorBuilder.build(tt);
333+
inTensors.add(inT);
334+
String inputName = getModelInputName("input" + c, c ++);
335+
runner.feed(inputName, inT);
336+
}
337+
c = 0;
338+
List<String> outputInfo = sig.getOutputsMap().values().stream()
339+
.map(nn -> {
340+
String name = nn.getName();
341+
if (name.endsWith(":0"))
342+
return name.substring(0, name.length() - 2);
343+
return name;
344+
}).collect(Collectors.toList());
345+
346+
for (String name : outputInfo)
347+
runner = runner.fetch(name);
348+
// Run runner
349+
List<org.tensorflow.Tensor<?>> resultPatchTensors = runner.run();
350+
for (org.tensorflow.Tensor<?> tt : inTensors)
351+
tt.close();
352+
List<RandomAccessibleInterval<R>> rais = new ArrayList<RandomAccessibleInterval<R>>();
353+
for (int i = 0; i < resultPatchTensors.size(); i++) {
354+
try {
355+
rais.add(ImgLib2Builder.build(resultPatchTensors.get(i)));
356+
} catch (IllegalArgumentException ex) {
357+
for (org.tensorflow.Tensor<?> tt : resultPatchTensors)
358+
tt.close();
359+
throw new RunModelException(Types.stackTrace(ex));
360+
}
361+
}
362+
return rais;
363+
}
319364

320365
protected void runFromShmas(List<String> inputs, List<String> outputs) throws IOException {
321366
Session session = model.session();
@@ -354,17 +399,46 @@ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IO
354399
}
355400
}
356401

357-
/**
358-
* MEthod only used in MacOS Intel systems that makes all the arangements
359-
* to create another process, communicate the model info and tensors to the other
360-
* process and then retrieve the results of the other process
361-
* @param inputTensors
362-
* tensors that are going to be run on the model
363-
* @param outputTensors
364-
* expected results of the model
365-
* @throws RunModelException if there is any issue running the model
366-
*/
367-
public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
402+
protected List<String> inferenceFromShmas(List<String> inputs) throws IOException, RunModelException {
403+
Session session = model.session();
404+
Session.Runner runner = session.runner();
405+
List<org.tensorflow.Tensor<?>> inTensors =
406+
new ArrayList<org.tensorflow.Tensor<?>>();
407+
int c = 0;
408+
for (String ee : inputs) {
409+
Map<String, Object> decoded = Types.decode(ee);
410+
SharedMemoryArray shma = SharedMemoryArray.read((String) decoded.get(MEM_NAME_KEY));
411+
org.tensorflow.Tensor<?> inT = io.bioimage.modelrunner.tensorflow.v1.shm.TensorBuilder.build(shma);
412+
if (PlatformDetection.isWindows()) shma.close();
413+
inTensors.add(inT);
414+
String inputName = getModelInputName((String) decoded.get(NAME_KEY), c ++);
415+
runner.feed(inputName, inT);
416+
}
417+
List<String> outputInfo = sig.getOutputsMap().values().stream()
418+
.map(nn -> {
419+
String name = nn.getName();
420+
if (name.endsWith(":0"))
421+
return name.substring(0, name.length() - 2);
422+
return name;
423+
}).collect(Collectors.toList());
424+
425+
for (String name : outputInfo)
426+
runner = runner.fetch(name);
427+
// Run runner
428+
List<org.tensorflow.Tensor<?>> resultPatchTensors = runner.run();
429+
for (org.tensorflow.Tensor<?> tt : inTensors)
430+
tt.close();
431+
432+
shmaNamesList = new ArrayList<String>();
433+
for (int i = 0; i < resultPatchTensors.size(); i ++) {
434+
String name = SharedMemoryArray.createShmName();
435+
ShmBuilder.build(resultPatchTensors.get(i), name, false);
436+
shmaNamesList.add(name);
437+
}
438+
return shmaNamesList;
439+
}
440+
441+
private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
368442
void runInterprocessing(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTensors) throws RunModelException {
369443
shmaInputList = new ArrayList<SharedMemoryArray>();
370444
shmaOutputList = new ArrayList<SharedMemoryArray>();
@@ -375,7 +449,7 @@ void runInterprocessing(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTens
375449
args.put("outputs", encOuts);
376450

377451
try {
378-
Task task = runner.task("inference", args);
452+
Task task = runner.task("run", args);
379453
task.waitFor();
380454
if (task.status == TaskStatus.CANCELED)
381455
throw new RuntimeException();
@@ -406,7 +480,89 @@ else if (task.status == TaskStatus.CRASHED) {
406480
closeShmas();
407481
}
408482

409-
private void closeShmas() {
483+
private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
484+
List<RandomAccessibleInterval<R>> runInterprocessing(List<RandomAccessibleInterval<T>> inputs) throws RunModelException {
485+
shmaInputList = new ArrayList<SharedMemoryArray>();
486+
List<String> encIns = new ArrayList<String>();
487+
Gson gson = new Gson();
488+
for (RandomAccessibleInterval<T> tt : inputs) {
489+
SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt, false, true);
490+
shmaInputList.add(shma);
491+
HashMap<String, Object> map = new HashMap<String, Object>();
492+
map.put(SHAPE_KEY, tt.dimensionsAsLongArray());
493+
map.put(DTYPE_KEY, CommonUtils.getDataTypeFromRAI(tt));
494+
map.put(IS_INPUT_KEY, true);
495+
map.put(MEM_NAME_KEY, shma.getName());
496+
encIns.add(gson.toJson(map));
497+
}
498+
LinkedHashMap<String, Object> args = new LinkedHashMap<String, Object>();
499+
args.put("inputs", encIns);
500+
501+
try {
502+
Task task = runner.task("inference", args);
503+
task.waitFor();
504+
if (task.status == TaskStatus.CANCELED)
505+
throw new RuntimeException();
506+
else if (task.status == TaskStatus.FAILED)
507+
throw new RuntimeException(task.error);
508+
else if (task.status == TaskStatus.CRASHED) {
509+
this.runner.close();
510+
runner = null;
511+
throw new RuntimeException(task.error);
512+
} else if (task.outputs == null)
513+
throw new RuntimeException("No outputs generated");
514+
List<String> outputs = (List<String>) task.outputs.get("encoded");
515+
List<RandomAccessibleInterval<R>> rais = new ArrayList<RandomAccessibleInterval<R>>();
516+
for (String out : outputs) {
517+
String name = (String) Types.decode(out).get(MEM_NAME_KEY);
518+
SharedMemoryArray shm = SharedMemoryArray.read(name);
519+
RandomAccessibleInterval<R> rai = shm.getSharedRAI();
520+
rais.add(Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(rai), Util.getTypeFromInterval(Cast.unchecked(rai))));
521+
shm.close();
522+
}
523+
closeShmas();
524+
return rais;
525+
} catch (Exception e) {
526+
closeShmas();
527+
if (e instanceof RunModelException)
528+
throw (RunModelException) e;
529+
throw new RunModelException(Types.stackTrace(e));
530+
}
531+
}
532+
533+
private void closeInterprocess() throws RunModelException {
534+
try {
535+
Task task = runner.task("closeTensors");
536+
task.waitFor();
537+
if (task.status == TaskStatus.CANCELED)
538+
throw new RuntimeException();
539+
else if (task.status == TaskStatus.FAILED)
540+
throw new RuntimeException(task.error);
541+
else if (task.status == TaskStatus.CRASHED) {
542+
this.runner.close();
543+
runner = null;
544+
throw new RuntimeException(task.error);
545+
}
546+
} catch (Exception e) {
547+
if (e instanceof RunModelException)
548+
throw (RunModelException) e;
549+
throw new RunModelException(Types.stackTrace(e));
550+
}
551+
}
552+
553+
protected void closeFromInterp() {
554+
if (!PlatformDetection.isWindows())
555+
return;
556+
this.shmaNamesList.stream().forEach(nn -> {
557+
try {
558+
SharedMemoryArray.read(nn).close();
559+
} catch (IOException e) {
560+
e.printStackTrace();
561+
}
562+
});
563+
}
564+
565+
private void closeShmas() throws RunModelException {
410566
shmaInputList.forEach(shm -> {
411567
try { shm.close(); } catch (IOException e1) { e1.printStackTrace();}
412568
});
@@ -415,6 +571,8 @@ private void closeShmas() {
415571
try { shm.close(); } catch (IOException e1) { e1.printStackTrace();}
416572
});
417573
shmaOutputList = null;
574+
if (interprocessing)
575+
closeInterprocess();
418576
}
419577

420578

@@ -483,7 +641,7 @@ public static <T extends RealType<T> & NativeType<T>> void fillOutputTensors(
483641
try {
484642
outputTensors.get(i).setData(ImgLib2Builder.build(outputNDArrays.get(i)));
485643
} catch (IllegalArgumentException ex) {
486-
throw new RunModelException(ex.toString());
644+
throw new RunModelException(Types.stackTrace(ex));
487645
}
488646
}
489647
}

src/main/java/io/bioimage/modelrunner/tensorflow/v1/shm/ShmBuilder.java

+32-16
Original file line numberDiff line numberDiff line change
@@ -61,32 +61,48 @@ private ShmBuilder()
6161
* @throws IllegalArgumentException if the data type of the tensor is not supported
6262
* @throws IOException if there is any error creating the shared memory array
6363
*/
64-
@SuppressWarnings("unchecked")
6564
public static void build(Tensor<?> tensor, String memoryName) throws IllegalArgumentException, IOException
65+
{
66+
build(tensor, memoryName, true);
67+
}
68+
69+
/**
70+
* Create a {@link SharedMemoryArray} from a {@link Tensor} tensor
71+
* @param tensor
72+
* the tensor to be passed into the other process through the shared memory
73+
* @param memoryName
74+
* the name of the memory region where the tensor is going to be copied
75+
* @param close
76+
* on Windows, whether to close the shma after creating it or reading it
77+
* @throws IllegalArgumentException if the data type of the tensor is not supported
78+
* @throws IOException if there is any error creating the shared memory array
79+
*/
80+
@SuppressWarnings("unchecked")
81+
public static void build(Tensor<?> tensor, String memoryName, boolean close) throws IllegalArgumentException, IOException
6682
{
6783
switch (tensor.dataType())
6884
{
6985
case UINT8:
70-
buildFromTensorUByte((Tensor<UInt8>) tensor, memoryName);
86+
buildFromTensorUByte((Tensor<UInt8>) tensor, memoryName, close);
7187
break;
7288
case INT32:
73-
buildFromTensorInt((Tensor<Integer>) tensor, memoryName);
89+
buildFromTensorInt((Tensor<Integer>) tensor, memoryName, close);
7490
break;
7591
case FLOAT:
76-
buildFromTensorFloat((Tensor<Float>) tensor, memoryName);
92+
buildFromTensorFloat((Tensor<Float>) tensor, memoryName, close);
7793
break;
7894
case DOUBLE:
79-
buildFromTensorDouble((Tensor<Double>) tensor, memoryName);
95+
buildFromTensorDouble((Tensor<Double>) tensor, memoryName, close);
8096
break;
8197
case INT64:
82-
buildFromTensorLong((Tensor<Long>) tensor, memoryName);
98+
buildFromTensorLong((Tensor<Long>) tensor, memoryName, close);
8399
break;
84100
default:
85101
throw new IllegalArgumentException("Unsupported tensor type: " + tensor.dataType().name());
86102
}
87103
}
88104

89-
private static void buildFromTensorUByte(Tensor<UInt8> tensor, String memoryName) throws IOException
105+
private static void buildFromTensorUByte(Tensor<UInt8> tensor, String memoryName, boolean close) throws IOException
90106
{
91107
long[] arrayShape = tensor.shape();
92108
if (CommonUtils.int32Overflows(arrayShape, 1))
@@ -95,10 +111,10 @@ private static void buildFromTensorUByte(Tensor<UInt8> tensor, String memoryName
95111
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
96112
ByteBuffer buff = shma.getDataBufferNoHeader();
97113
tensor.writeTo(buff);
98-
if (PlatformDetection.isWindows()) shma.close();
114+
if (PlatformDetection.isWindows() && close) shma.close();
99115
}
100116

101-
private static void buildFromTensorInt(Tensor<Integer> tensor, String memoryName) throws IOException
117+
private static void buildFromTensorInt(Tensor<Integer> tensor, String memoryName, boolean close) throws IOException
102118
{
103119
long[] arrayShape = tensor.shape();
104120
if (CommonUtils.int32Overflows(arrayShape, 4))
@@ -108,10 +124,10 @@ private static void buildFromTensorInt(Tensor<Integer> tensor, String memoryName
108124
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true);
109125
ByteBuffer buff = shma.getDataBufferNoHeader();
110126
tensor.writeTo(buff);
111-
if (PlatformDetection.isWindows()) shma.close();
127+
if (PlatformDetection.isWindows() && close) shma.close();
112128
}
113129

114-
private static void buildFromTensorFloat(Tensor<Float> tensor, String memoryName) throws IOException
130+
private static void buildFromTensorFloat(Tensor<Float> tensor, String memoryName, boolean close) throws IOException
115131
{
116132
long[] arrayShape = tensor.shape();
117133
if (CommonUtils.int32Overflows(arrayShape, 4))
@@ -121,10 +137,10 @@ private static void buildFromTensorFloat(Tensor<Float> tensor, String memoryName
121137
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true);
122138
ByteBuffer buff = shma.getDataBufferNoHeader();
123139
tensor.writeTo(buff);
124-
if (PlatformDetection.isWindows()) shma.close();
140+
if (PlatformDetection.isWindows() && close) shma.close();
125141
}
126142

127-
private static void buildFromTensorDouble(Tensor<Double> tensor, String memoryName) throws IOException
143+
private static void buildFromTensorDouble(Tensor<Double> tensor, String memoryName, boolean close) throws IOException
128144
{
129145
long[] arrayShape = tensor.shape();
130146
if (CommonUtils.int32Overflows(arrayShape, 8))
@@ -134,10 +150,10 @@ private static void buildFromTensorDouble(Tensor<Double> tensor, String memoryNa
134150
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new DoubleType(), false, true);
135151
ByteBuffer buff = shma.getDataBufferNoHeader();
136152
tensor.writeTo(buff);
137-
if (PlatformDetection.isWindows()) shma.close();
153+
if (PlatformDetection.isWindows() && close) shma.close();
138154
}
139155

140-
private static void buildFromTensorLong(Tensor<Long> tensor, String memoryName) throws IOException
156+
private static void buildFromTensorLong(Tensor<Long> tensor, String memoryName, boolean close) throws IOException
141157
{
142158
long[] arrayShape = tensor.shape();
143159
if (CommonUtils.int32Overflows(arrayShape, 8))
@@ -148,6 +164,6 @@ private static void buildFromTensorLong(Tensor<Long> tensor, String memoryName)
148164
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new LongType(), false, true);
149165
ByteBuffer buff = shma.getDataBufferNoHeader();
150166
tensor.writeTo(buff);
151-
if (PlatformDetection.isWindows()) shma.close();
167+
if (PlatformDetection.isWindows() && close) shma.close();
152168
}
153169
}

0 commit comments

Comments
 (0)