Skip to content

Commit 790f7a4

Browse files
committed
add new method inference
1 parent 9ede4b1 commit 790f7a4

File tree

3 files changed

+215
-22
lines changed

3 files changed

+215
-22
lines changed

Diff for: src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/JavaWorker.java

+12-3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ public class JavaWorker {
1919
private final String uuid;
2020

2121
private final Tensorflow2Interface ti;
22+
23+
private Map<String, Object> outputs = new HashMap<String, Object>();
2224

2325
private boolean cancelRequested = false;
2426

@@ -88,10 +90,17 @@ private void executeScript(String script, Map<String, Object> inputs) {
8890
try {
8991
if (script.equals("loadModel")) {
9092
ti.loadModel((String) inputs.get("modelFolder"), null);
91-
} else if (script.equals("inference")) {
93+
} else if (script.equals("run")) {
9294
ti.runFromShmas((List<String>) inputs.get("inputs"), (List<String>) inputs.get("outputs"));
95+
} else if (script.equals("inference")) {
96+
List<String> encodedOutputs = ti.inferenceFromShmas((List<String>) inputs.get("inputs"));
97+
HashMap<String, List<String>> out = new HashMap<String, List<String>>();
98+
out.put("encoded", encodedOutputs);
99+
outputs.put("outputs", out);
93100
} else if (script.equals("close")) {
94101
ti.closeModel();
102+
} else if (script.equals("closeTensors")) {
103+
ti.closeFromInterp();
95104
}
96105
} catch(Exception | Error ex) {
97106
this.fail(Types.stackTrace(ex));
@@ -109,7 +118,7 @@ private void reportLaunch() {
109118
}
110119

111120
private void reportCompletion() {
112-
respond(ResponseType.COMPLETION, null);
121+
respond(ResponseType.COMPLETION, outputs);
113122
}
114123

115124
private void update(String message, Integer current, Integer maximum) {
@@ -130,7 +139,7 @@ private void respond(ResponseType responseType, Map<String, Object> args) {
130139
Map<String, Object> response = new HashMap<String, Object>();
131140
response.put("task", uuid);
132141
response.put("responseType", responseType);
133-
if (args != null)
142+
if (args != null && args.keySet().size() > 0)
134143
response.putAll(args);
135144
try {
136145
System.out.println(Types.encode(response));

Diff for: src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/Tensorflow2Interface.java

+170-3
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,51 @@ void run(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTensors)
307307
tt.close();
308308
}
309309
}
310+
311+
@Override
312+
public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<RandomAccessibleInterval<R>> inference(
313+
List<RandomAccessibleInterval<T>> inputs) throws RunModelException {
314+
if (interprocessing) {
315+
return runInterprocessing(inputs);
316+
}
317+
Session session = model.session();
318+
Session.Runner runner = session.runner();
319+
List<String> inputListNames = new ArrayList<String>();
320+
List<org.tensorflow.Tensor<?>> inTensors = new ArrayList<org.tensorflow.Tensor<?>>();
321+
int c = 0;
322+
for (RandomAccessibleInterval<T> tt : inputs) {
323+
org.tensorflow.Tensor<?> inT = TensorBuilder.build(tt);
324+
inTensors.add(inT);
325+
String inputName = getModelInputName("input" + c, c ++);
326+
runner.feed(inputName, inT);
327+
}
328+
c = 0;
329+
List<String> outputInfo = sig.getOutputsMap().values().stream()
330+
.map(nn -> {
331+
String name = nn.getName();
332+
if (name.endsWith(":0"))
333+
return name.substring(0, name.length() - 2);
334+
return name;
335+
}).collect(Collectors.toList());
336+
337+
for (String name : outputInfo)
338+
runner = runner.fetch(name);
339+
// Run runner
340+
List<org.tensorflow.Tensor<?>> resultPatchTensors = runner.run();
341+
for (org.tensorflow.Tensor<?> tt : inTensors)
342+
tt.close();
343+
List<RandomAccessibleInterval<R>> rais = new ArrayList<RandomAccessibleInterval<R>>();
344+
for (int i = 0; i < resultPatchTensors.size(); i++) {
345+
try {
346+
rais.add(ImgLib2Builder.build(resultPatchTensors.get(i)));
347+
} catch (IllegalArgumentException ex) {
348+
for (org.tensorflow.Tensor<?> tt : resultPatchTensors)
349+
tt.close();
350+
throw new RunModelException(Types.stackTrace(ex));
351+
}
352+
}
353+
return rais;
354+
}
310355

311356
protected void runFromShmas(List<String> inputs, List<String> outputs) throws IOException {
312357
Session session = model.session();
@@ -345,6 +390,45 @@ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IO
345390
}
346391
}
347392

393+
protected List<String> inferenceFromShmas(List<String> inputs) throws IOException, RunModelException {
394+
Session session = model.session();
395+
Session.Runner runner = session.runner();
396+
List<org.tensorflow.Tensor<?>> inTensors =
397+
new ArrayList<org.tensorflow.Tensor<?>>();
398+
int c = 0;
399+
for (String ee : inputs) {
400+
Map<String, Object> decoded = Types.decode(ee);
401+
SharedMemoryArray shma = SharedMemoryArray.read((String) decoded.get(MEM_NAME_KEY));
402+
org.tensorflow.Tensor<?> inT = io.bioimage.modelrunner.tensorflow.v2.api020.shm.TensorBuilder.build(shma);
403+
if (PlatformDetection.isWindows()) shma.close();
404+
inTensors.add(inT);
405+
String inputName = getModelInputName((String) decoded.get(NAME_KEY), c ++);
406+
runner.feed(inputName, inT);
407+
}
408+
List<String> outputInfo = sig.getOutputsMap().values().stream()
409+
.map(nn -> {
410+
String name = nn.getName();
411+
if (name.endsWith(":0"))
412+
return name.substring(0, name.length() - 2);
413+
return name;
414+
}).collect(Collectors.toList());
415+
416+
for (String name : outputInfo)
417+
runner = runner.fetch(name);
418+
// Run runner
419+
List<org.tensorflow.Tensor<?>> resultPatchTensors = runner.run();
420+
for (org.tensorflow.Tensor<?> tt : inTensors)
421+
tt.close();
422+
423+
shmaNamesList = new ArrayList<String>();
424+
for (int i = 0; i < resultPatchTensors.size(); i ++) {
425+
String name = SharedMemoryArray.createShmName();
426+
ShmBuilder.build(resultPatchTensors.get(i), name, false);
427+
shmaNamesList.add(name);
428+
}
429+
return shmaNamesList;
430+
}
431+
348432
/**
349433
* MEthod only used in MacOS Intel and Windows systems that makes all the arrangements
350434
* to create another process, communicate the model info and tensors to the other
@@ -397,7 +481,89 @@ else if (task.status == TaskStatus.CRASHED) {
397481
closeShmas();
398482
}
399483

400-
private void closeShmas() {
484+
private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
485+
List<RandomAccessibleInterval<R>> runInterprocessing(List<RandomAccessibleInterval<T>> inputs) throws RunModelException {
486+
shmaInputList = new ArrayList<SharedMemoryArray>();
487+
List<String> encIns = new ArrayList<String>();
488+
Gson gson = new Gson();
489+
for (RandomAccessibleInterval<T> tt : inputs) {
490+
SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt, false, true);
491+
shmaInputList.add(shma);
492+
HashMap<String, Object> map = new HashMap<String, Object>();
493+
map.put(SHAPE_KEY, tt.dimensionsAsLongArray());
494+
map.put(DTYPE_KEY, CommonUtils.getDataTypeFromRAI(tt));
495+
map.put(IS_INPUT_KEY, true);
496+
map.put(MEM_NAME_KEY, shma.getName());
497+
encIns.add(gson.toJson(map));
498+
}
499+
LinkedHashMap<String, Object> args = new LinkedHashMap<String, Object>();
500+
args.put("inputs", encIns);
501+
502+
try {
503+
Task task = runner.task("inference", args);
504+
task.waitFor();
505+
if (task.status == TaskStatus.CANCELED)
506+
throw new RuntimeException();
507+
else if (task.status == TaskStatus.FAILED)
508+
throw new RuntimeException(task.error);
509+
else if (task.status == TaskStatus.CRASHED) {
510+
this.runner.close();
511+
runner = null;
512+
throw new RuntimeException(task.error);
513+
} else if (task.outputs == null)
514+
throw new RuntimeException("No outputs generated");
515+
List<String> outputs = (List<String>) task.outputs.get("encoded");
516+
List<RandomAccessibleInterval<R>> rais = new ArrayList<RandomAccessibleInterval<R>>();
517+
for (String out : outputs) {
518+
String name = (String) Types.decode(out).get(MEM_NAME_KEY);
519+
SharedMemoryArray shm = SharedMemoryArray.read(name);
520+
RandomAccessibleInterval<R> rai = shm.getSharedRAI();
521+
rais.add(Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(rai), Util.getTypeFromInterval(Cast.unchecked(rai))));
522+
shm.close();
523+
}
524+
closeShmas();
525+
return rais;
526+
} catch (Exception e) {
527+
closeShmas();
528+
if (e instanceof RunModelException)
529+
throw (RunModelException) e;
530+
throw new RunModelException(Types.stackTrace(e));
531+
}
532+
}
533+
534+
private void closeInterprocess() throws RunModelException {
535+
try {
536+
Task task = runner.task("closeTensors");
537+
task.waitFor();
538+
if (task.status == TaskStatus.CANCELED)
539+
throw new RuntimeException();
540+
else if (task.status == TaskStatus.FAILED)
541+
throw new RuntimeException(task.error);
542+
else if (task.status == TaskStatus.CRASHED) {
543+
this.runner.close();
544+
runner = null;
545+
throw new RuntimeException(task.error);
546+
}
547+
} catch (Exception e) {
548+
if (e instanceof RunModelException)
549+
throw (RunModelException) e;
550+
throw new RunModelException(Types.stackTrace(e));
551+
}
552+
}
553+
554+
protected void closeFromInterp() {
555+
if (!PlatformDetection.isWindows())
556+
return;
557+
this.shmaNamesList.stream().forEach(nn -> {
558+
try {
559+
SharedMemoryArray.read(nn).close();
560+
} catch (IOException e) {
561+
e.printStackTrace();
562+
}
563+
});
564+
}
565+
566+
private void closeShmas() throws RunModelException {
401567
shmaInputList.forEach(shm -> {
402568
try { shm.close(); } catch (IOException e1) { e1.printStackTrace();}
403569
});
@@ -406,8 +572,9 @@ private void closeShmas() {
406572
try { shm.close(); } catch (IOException e1) { e1.printStackTrace();}
407573
});
408574
shmaOutputList = null;
409-
}
410-
575+
if (interprocessing)
576+
closeInterprocess();
577+
}
411578

412579
private <T extends RealType<T> & NativeType<T>> List<String> encodeInputs(List<Tensor<T>> inputTensors) {
413580
List<String> encodedInputTensors = new ArrayList<String>();

Diff for: src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/shm/ShmBuilder.java

+33-16
Original file line numberDiff line numberDiff line change
@@ -67,32 +67,49 @@ private ShmBuilder()
6767
* @throws IllegalArgumentException if the data type of the tensor is not supported
6868
* @throws IOException if there is any error creating the shared memory array
6969
*/
70-
@SuppressWarnings("unchecked")
7170
public static void build(Tensor<? extends TType> tensor, String memoryName) throws IllegalArgumentException, IOException
71+
{
72+
build(tensor, memoryName, true);
73+
}
74+
75+
/**
76+
* Create a {@link SharedMemoryArray} from a {@link Tensor}
77+
* @param tensor
78+
* the tensor to be passed into the other process through the shared memory
79+
* @param memoryName
80+
* the name of the memory region where the tensor is going to be copied
81+
* @param close
82+
* on Windows, whether to close the shma after creating it or reading it
83+
* @throws IllegalArgumentException if the data type of the tensor is not supported
84+
* @throws IllegalArgumentException if the data type of the tensor is not supported
85+
* @throws IOException if there is any error creating the shared memory array
86+
*/
87+
@SuppressWarnings("unchecked")
88+
public static void build(Tensor<? extends TType> tensor, String memoryName, boolean close) throws IllegalArgumentException, IOException
7289
{
7390
switch (tensor.dataType().name())
7491
{
7592
case TUint8.NAME:
76-
buildFromTensorUByte((Tensor<TUint8>) tensor, memoryName);
93+
buildFromTensorUByte((Tensor<TUint8>) tensor, memoryName, close);
7794
break;
7895
case TInt32.NAME:
79-
buildFromTensorInt((Tensor<TInt32>) tensor, memoryName);
96+
buildFromTensorInt((Tensor<TInt32>) tensor, memoryName, close);
8097
break;
8198
case TFloat32.NAME:
82-
buildFromTensorFloat((Tensor<TFloat32>) tensor, memoryName);
99+
buildFromTensorFloat((Tensor<TFloat32>) tensor, memoryName, close);
83100
break;
84101
case TFloat64.NAME:
85-
buildFromTensorDouble((Tensor<TFloat64>) tensor, memoryName);
102+
buildFromTensorDouble((Tensor<TFloat64>) tensor, memoryName, close);
86103
break;
87104
case TInt64.NAME:
88-
buildFromTensorLong((Tensor<TInt64>) tensor, memoryName);
105+
buildFromTensorLong((Tensor<TInt64>) tensor, memoryName, close);
89106
break;
90107
default:
91108
throw new IllegalArgumentException("Unsupported tensor type: " + tensor.dataType().name());
92109
}
93110
}
94111

95-
private static void buildFromTensorUByte(Tensor<TUint8> tensor, String memoryName) throws IOException
112+
private static void buildFromTensorUByte(Tensor<TUint8> tensor, String memoryName, boolean close) throws IOException
96113
{
97114
long[] arrayShape = tensor.shape().asArray();
98115
if (CommonUtils.int32Overflows(arrayShape, 1))
@@ -104,10 +121,10 @@ private static void buildFromTensorUByte(Tensor<TUint8> tensor, String memoryNam
104121
ByteBuffer buff2 = ByteBuffer.wrap(flat).order(ByteOrder.LITTLE_ENDIAN);
105122
tensor.rawData().read(flat, 0, buff.capacity());
106123
buff.put(buff2);
107-
if (PlatformDetection.isWindows()) shma.close();
124+
if (PlatformDetection.isWindows() && close) shma.close();
108125
}
109126

110-
private static void buildFromTensorInt(Tensor<TInt32> tensor, String memoryName) throws IOException
127+
private static void buildFromTensorInt(Tensor<TInt32> tensor, String memoryName, boolean close) throws IOException
111128
{
112129
long[] arrayShape = tensor.shape().asArray();
113130
if (CommonUtils.int32Overflows(arrayShape, 4))
@@ -120,10 +137,10 @@ private static void buildFromTensorInt(Tensor<TInt32> tensor, String memoryName)
120137
ByteBuffer buff2 = ByteBuffer.wrap(flat).order(ByteOrder.LITTLE_ENDIAN);
121138
tensor.rawData().read(flat, 0, buff.capacity());
122139
buff.put(buff2);
123-
if (PlatformDetection.isWindows()) shma.close();
140+
if (PlatformDetection.isWindows() && close) shma.close();
124141
}
125142

126-
private static void buildFromTensorFloat(Tensor<TFloat32> tensor, String memoryName) throws IOException
143+
private static void buildFromTensorFloat(Tensor<TFloat32> tensor, String memoryName, boolean close) throws IOException
127144
{
128145
long[] arrayShape = tensor.shape().asArray();
129146
if (CommonUtils.int32Overflows(arrayShape, 4))
@@ -136,10 +153,10 @@ private static void buildFromTensorFloat(Tensor<TFloat32> tensor, String memoryN
136153
ByteBuffer buff2 = ByteBuffer.wrap(flat).order(ByteOrder.LITTLE_ENDIAN);
137154
tensor.rawData().read(flat, 0, buff.capacity());
138155
buff.put(buff2);
139-
if (PlatformDetection.isWindows()) shma.close();
156+
if (PlatformDetection.isWindows() && close) shma.close();
140157
}
141158

142-
private static void buildFromTensorDouble(Tensor<TFloat64> tensor, String memoryName) throws IOException
159+
private static void buildFromTensorDouble(Tensor<TFloat64> tensor, String memoryName, boolean close) throws IOException
143160
{
144161
long[] arrayShape = tensor.shape().asArray();
145162
if (CommonUtils.int32Overflows(arrayShape, 8))
@@ -152,10 +169,10 @@ private static void buildFromTensorDouble(Tensor<TFloat64> tensor, String memory
152169
ByteBuffer buff2 = ByteBuffer.wrap(flat).order(ByteOrder.LITTLE_ENDIAN);
153170
tensor.rawData().read(flat, 0, buff.capacity());
154171
buff.put(buff2);
155-
if (PlatformDetection.isWindows()) shma.close();
172+
if (PlatformDetection.isWindows() && close) shma.close();
156173
}
157174

158-
private static void buildFromTensorLong(Tensor<TInt64> tensor, String memoryName) throws IOException
175+
private static void buildFromTensorLong(Tensor<TInt64> tensor, String memoryName, boolean close) throws IOException
159176
{
160177
long[] arrayShape = tensor.shape().asArray();
161178
if (CommonUtils.int32Overflows(arrayShape, 8))
@@ -169,6 +186,6 @@ private static void buildFromTensorLong(Tensor<TInt64> tensor, String memoryName
169186
ByteBuffer buff2 = ByteBuffer.wrap(flat).order(ByteOrder.LITTLE_ENDIAN);
170187
tensor.rawData().read(flat, 0, buff.capacity());
171188
buff.put(buff2);
172-
if (PlatformDetection.isWindows()) shma.close();
189+
if (PlatformDetection.isWindows() && close) shma.close();
173190
}
174191
}

0 commit comments

Comments
 (0)