Skip to content

Commit 6c12f4a

Browse files
committed
add new method inference
1 parent 1124257 commit 6c12f4a

File tree

3 files changed

+213
-19
lines changed

3 files changed

+213
-19
lines changed

src/main/java/io/bioimage/modelrunner/tensorflow/v2/api050/JavaWorker.java

+12-3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ public class JavaWorker {
1919
private final String uuid;
2020

2121
private final Tensorflow2Interface ti;
22+
23+
private Map<String, Object> outputs = new HashMap<String, Object>();
2224

2325
private boolean cancelRequested = false;
2426

@@ -88,10 +90,17 @@ private void executeScript(String script, Map<String, Object> inputs) {
8890
try {
8991
if (script.equals("loadModel")) {
9092
ti.loadModel((String) inputs.get("modelFolder"), null);
91-
} else if (script.equals("inference")) {
93+
} else if (script.equals("run")) {
9294
ti.runFromShmas((List<String>) inputs.get("inputs"), (List<String>) inputs.get("outputs"));
95+
} else if (script.equals("inference")) {
96+
List<String> encodedOutputs = ti.inferenceFromShmas((List<String>) inputs.get("inputs"));
97+
HashMap<String, List<String>> out = new HashMap<String, List<String>>();
98+
out.put("encoded", encodedOutputs);
99+
outputs.put("outputs", out);
93100
} else if (script.equals("close")) {
94101
ti.closeModel();
102+
} else if (script.equals("closeTensors")) {
103+
ti.closeFromInterp();
95104
}
96105
} catch(Exception | Error ex) {
97106
this.fail(Types.stackTrace(ex));
@@ -109,7 +118,7 @@ private void reportLaunch() {
109118
}
110119

111120
private void reportCompletion() {
112-
respond(ResponseType.COMPLETION, null);
121+
respond(ResponseType.COMPLETION, outputs);
113122
}
114123

115124
private void update(String message, Integer current, Integer maximum) {
@@ -130,7 +139,7 @@ private void respond(ResponseType responseType, Map<String, Object> args) {
130139
Map<String, Object> response = new HashMap<String, Object>();
131140
response.put("task", uuid);
132141
response.put("responseType", responseType);
133-
if (args != null)
142+
if (args != null && args.keySet().size() > 0)
134143
response.putAll(args);
135144
try {
136145
System.out.println(Types.encode(response));

src/main/java/io/bioimage/modelrunner/tensorflow/v2/api050/Tensorflow2Interface.java

+170-1
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,51 @@ void run(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTensors)
295295
resultPatchTensors.get(i).close();
296296
}
297297
}
298+
299+
@Override
300+
public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<RandomAccessibleInterval<R>> inference(
301+
List<RandomAccessibleInterval<T>> inputs) throws RunModelException {
302+
if (interprocessing) {
303+
return runInterprocessing(inputs);
304+
}
305+
Session session = model.session();
306+
Session.Runner runner = session.runner();
307+
List<TType> inTensors = new ArrayList<TType>();
308+
int c = 0;
309+
for (RandomAccessibleInterval<T> tt : inputs) {
310+
TType inT = TensorBuilder.build(tt);
311+
inTensors.add(inT);
312+
String inputName = getModelInputName("input" + c, c ++);
313+
runner.feed(inputName, inT);
314+
}
315+
c = 0;
316+
List<String> outputInfo = sig.getOutputsMap().values().stream()
317+
.map(nn -> {
318+
String name = nn.getName();
319+
if (name.endsWith(":0"))
320+
return name.substring(0, name.length() - 2);
321+
return name;
322+
}).collect(Collectors.toList());
323+
324+
for (String name : outputInfo)
325+
runner = runner.fetch(name);
326+
// Run runner
327+
Result resultPatchTensors = runner.run();
328+
for (TType tt : inTensors)
329+
tt.close();
330+
List<RandomAccessibleInterval<R>> rais = new ArrayList<RandomAccessibleInterval<R>>();
331+
for (int i = 0; i < resultPatchTensors.size(); i++) {
332+
try {
333+
rais.add(ImgLib2Builder.build((TType) resultPatchTensors.get(i)));
334+
((TType) resultPatchTensors.get(i)).close();
335+
} catch (IllegalArgumentException ex) {
336+
for (int j = i; j < resultPatchTensors.size(); j++)
337+
((TType) resultPatchTensors.get(j)).close();
338+
throw new RunModelException(Types.stackTrace(ex));
339+
}
340+
}
341+
return rais;
342+
}
298343

299344
protected void runFromShmas(List<String> inputs, List<String> outputs) throws IOException {
300345
Session session = model.session();
@@ -333,6 +378,46 @@ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IO
333378
}
334379
}
335380

381+
protected List<String> inferenceFromShmas(List<String> inputs) throws IOException, RunModelException {
382+
Session session = model.session();
383+
Session.Runner runner = session.runner();
384+
List<TType> inTensors = new ArrayList<TType>();
385+
int c = 0;
386+
for (String ee : inputs) {
387+
Map<String, Object> decoded = Types.decode(ee);
388+
SharedMemoryArray shma = SharedMemoryArray.read((String) decoded.get(MEM_NAME_KEY));
389+
TType inT = io.bioimage.modelrunner.tensorflow.v2.api050.shm.TensorBuilder.build(shma);
390+
if (PlatformDetection.isWindows()) shma.close();
391+
inTensors.add(inT);
392+
String inputName = getModelInputName((String) decoded.get(NAME_KEY), c ++);
393+
runner.feed(inputName, inT);
394+
}
395+
List<String> outputInfo = sig.getOutputsMap().values().stream()
396+
.map(nn -> {
397+
String name = nn.getName();
398+
if (name.endsWith(":0"))
399+
return name.substring(0, name.length() - 2);
400+
return name;
401+
}).collect(Collectors.toList());
402+
403+
for (String name : outputInfo)
404+
runner = runner.fetch(name);
405+
// Run runner
406+
Result resultPatchTensors = runner.run();
407+
for (TType tt : inTensors)
408+
tt.close();
409+
410+
shmaNamesList = new ArrayList<String>();
411+
for (int i = 0; i < resultPatchTensors.size(); i ++) {
412+
String name = SharedMemoryArray.createShmName();
413+
ShmBuilder.build((TType) resultPatchTensors.get(i), name, false);
414+
shmaNamesList.add(name);
415+
}
416+
for (int i = 0; i < resultPatchTensors.size(); i ++)
417+
resultPatchTensors.get(i).close();
418+
return shmaNamesList;
419+
}
420+
336421
/**
337422
* MEthod only used in MacOS Intel and Windows systems that makes all the arrangements
338423
* to create another process, communicate the model info and tensors to the other
@@ -389,7 +474,89 @@ else if (task.status == TaskStatus.CRASHED) {
389474
closeShmas();
390475
}
391476

392-
private void closeShmas() {
477+
private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
478+
List<RandomAccessibleInterval<R>> runInterprocessing(List<RandomAccessibleInterval<T>> inputs) throws RunModelException {
479+
shmaInputList = new ArrayList<SharedMemoryArray>();
480+
List<String> encIns = new ArrayList<String>();
481+
Gson gson = new Gson();
482+
for (RandomAccessibleInterval<T> tt : inputs) {
483+
SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt, false, true);
484+
shmaInputList.add(shma);
485+
HashMap<String, Object> map = new HashMap<String, Object>();
486+
map.put(SHAPE_KEY, tt.dimensionsAsLongArray());
487+
map.put(DTYPE_KEY, CommonUtils.getDataTypeFromRAI(tt));
488+
map.put(IS_INPUT_KEY, true);
489+
map.put(MEM_NAME_KEY, shma.getName());
490+
encIns.add(gson.toJson(map));
491+
}
492+
LinkedHashMap<String, Object> args = new LinkedHashMap<String, Object>();
493+
args.put("inputs", encIns);
494+
495+
try {
496+
Task task = runner.task("inference", args);
497+
task.waitFor();
498+
if (task.status == TaskStatus.CANCELED)
499+
throw new RuntimeException();
500+
else if (task.status == TaskStatus.FAILED)
501+
throw new RuntimeException(task.error);
502+
else if (task.status == TaskStatus.CRASHED) {
503+
this.runner.close();
504+
runner = null;
505+
throw new RuntimeException(task.error);
506+
} else if (task.outputs == null)
507+
throw new RuntimeException("No outputs generated");
508+
List<String> outputs = (List<String>) task.outputs.get("encoded");
509+
List<RandomAccessibleInterval<R>> rais = new ArrayList<RandomAccessibleInterval<R>>();
510+
for (String out : outputs) {
511+
String name = (String) Types.decode(out).get(MEM_NAME_KEY);
512+
SharedMemoryArray shm = SharedMemoryArray.read(name);
513+
RandomAccessibleInterval<R> rai = shm.getSharedRAI();
514+
rais.add(Tensor.createCopyOfRaiInWantedDataType(Cast.unchecked(rai), Util.getTypeFromInterval(Cast.unchecked(rai))));
515+
shm.close();
516+
}
517+
closeShmas();
518+
return rais;
519+
} catch (Exception e) {
520+
closeShmas();
521+
if (e instanceof RunModelException)
522+
throw (RunModelException) e;
523+
throw new RunModelException(Types.stackTrace(e));
524+
}
525+
}
526+
527+
private void closeInterprocess() throws RunModelException {
528+
try {
529+
Task task = runner.task("closeTensors");
530+
task.waitFor();
531+
if (task.status == TaskStatus.CANCELED)
532+
throw new RuntimeException();
533+
else if (task.status == TaskStatus.FAILED)
534+
throw new RuntimeException(task.error);
535+
else if (task.status == TaskStatus.CRASHED) {
536+
this.runner.close();
537+
runner = null;
538+
throw new RuntimeException(task.error);
539+
}
540+
} catch (Exception e) {
541+
if (e instanceof RunModelException)
542+
throw (RunModelException) e;
543+
throw new RunModelException(Types.stackTrace(e));
544+
}
545+
}
546+
547+
protected void closeFromInterp() {
548+
if (!PlatformDetection.isWindows())
549+
return;
550+
this.shmaNamesList.stream().forEach(nn -> {
551+
try {
552+
SharedMemoryArray.read(nn).close();
553+
} catch (IOException e) {
554+
e.printStackTrace();
555+
}
556+
});
557+
}
558+
559+
private void closeShmas() throws RunModelException {
393560
shmaInputList.forEach(shm -> {
394561
try { shm.close(); } catch (IOException e1) { e1.printStackTrace();}
395562
});
@@ -398,6 +565,8 @@ private void closeShmas() {
398565
try { shm.close(); } catch (IOException e1) { e1.printStackTrace();}
399566
});
400567
shmaOutputList = null;
568+
if (interprocessing)
569+
closeInterprocess();
401570
}
402571

403572
private <T extends RealType<T> & NativeType<T>> List<String> encodeInputs(List<Tensor<T>> inputTensors) {

src/main/java/io/bioimage/modelrunner/tensorflow/v2/api050/shm/ShmBuilder.java

+31-15
Original file line numberDiff line numberDiff line change
@@ -67,34 +67,50 @@ private ShmBuilder()
6767
* @throws IOException if there is any error creating the shared memory array
6868
*/
6969
public static void build(TType tensor, String memoryName) throws IllegalArgumentException, IOException
70+
{
71+
build(tensor, memoryName, true);
72+
}
73+
74+
/**
75+
* Create a {@link SharedMemoryArray} from a {@link TType} tensor
76+
* @param tensor
77+
* the tensor to be passed into the other process through the shared memory
78+
* @param memoryName
79+
* the name of the memory region where the tensor is going to be copied
80+
* @param close
81+
* on Windows, whether to close the shma after creating it or reading it
82+
* @throws IllegalArgumentException if the data type of the tensor is not supported
83+
* @throws IOException if there is any error creating the shared memory array
84+
*/
85+
public static void build(TType tensor, String memoryName, boolean close) throws IllegalArgumentException, IOException
7086
{
7187
if (tensor instanceof TUint8)
7288
{
73-
buildFromTensorUByte((TUint8) tensor, memoryName);
89+
buildFromTensorUByte((TUint8) tensor, memoryName, close);
7490
}
7591
else if (tensor instanceof TInt32)
7692
{
77-
buildFromTensorInt((TInt32) tensor, memoryName);
93+
buildFromTensorInt((TInt32) tensor, memoryName, close);
7894
}
7995
else if (tensor instanceof TFloat32)
8096
{
81-
buildFromTensorFloat((TFloat32) tensor, memoryName);
97+
buildFromTensorFloat((TFloat32) tensor, memoryName, close);
8298
}
8399
else if (tensor instanceof TFloat64)
84100
{
85-
buildFromTensorDouble((TFloat64) tensor, memoryName);
101+
buildFromTensorDouble((TFloat64) tensor, memoryName, close);
86102
}
87103
else if (tensor instanceof TInt64)
88104
{
89-
buildFromTensorLong((TInt64) tensor, memoryName);
105+
buildFromTensorLong((TInt64) tensor, memoryName, close);
90106
}
91107
else
92108
{
93109
throw new IllegalArgumentException("Unsupported tensor type: " + tensor.dataType().name());
94110
}
95111
}
96112

97-
private static void buildFromTensorUByte(TUint8 tensor, String memoryName) throws IOException
113+
private static void buildFromTensorUByte(TUint8 tensor, String memoryName, boolean close) throws IOException
98114
{
99115
long[] arrayShape = tensor.shape().asArray();
100116
if (CommonUtils.int32Overflows(arrayShape, 1))
@@ -106,10 +122,10 @@ private static void buildFromTensorUByte(TUint8 tensor, String memoryName) throw
106122
ByteBuffer buff2 = ByteBuffer.wrap(flat).order(ByteOrder.LITTLE_ENDIAN);
107123
tensor.asRawTensor().data().read(flat, 0, buff.capacity());
108124
buff.put(buff2);
109-
if (PlatformDetection.isWindows()) shma.close();
125+
if (PlatformDetection.isWindows() && close) shma.close();
110126
}
111127

112-
private static void buildFromTensorInt(TInt32 tensor, String memoryName) throws IOException
128+
private static void buildFromTensorInt(TInt32 tensor, String memoryName, boolean close) throws IOException
113129
{
114130
long[] arrayShape = tensor.shape().asArray();
115131
if (CommonUtils.int32Overflows(arrayShape, 4))
@@ -122,10 +138,10 @@ private static void buildFromTensorInt(TInt32 tensor, String memoryName) throws
122138
ByteBuffer buff2 = ByteBuffer.wrap(flat).order(ByteOrder.LITTLE_ENDIAN);
123139
tensor.asRawTensor().data().read(flat, 0, buff.capacity());
124140
buff.put(buff2);
125-
if (PlatformDetection.isWindows()) shma.close();
141+
if (PlatformDetection.isWindows() && close) shma.close();
126142
}
127143

128-
private static void buildFromTensorFloat(TFloat32 tensor, String memoryName) throws IOException
144+
private static void buildFromTensorFloat(TFloat32 tensor, String memoryName, boolean close) throws IOException
129145
{
130146
long[] arrayShape = tensor.shape().asArray();
131147
if (CommonUtils.int32Overflows(arrayShape, 4))
@@ -138,10 +154,10 @@ private static void buildFromTensorFloat(TFloat32 tensor, String memoryName) thr
138154
ByteBuffer buff2 = ByteBuffer.wrap(flat).order(ByteOrder.LITTLE_ENDIAN);
139155
tensor.asRawTensor().data().read(flat, 0, buff.capacity());
140156
buff.put(buff2);
141-
if (PlatformDetection.isWindows()) shma.close();
157+
if (PlatformDetection.isWindows() && close) shma.close();
142158
}
143159

144-
private static void buildFromTensorDouble(TFloat64 tensor, String memoryName) throws IOException
160+
private static void buildFromTensorDouble(TFloat64 tensor, String memoryName, boolean close) throws IOException
145161
{
146162
long[] arrayShape = tensor.shape().asArray();
147163
if (CommonUtils.int32Overflows(arrayShape, 8))
@@ -154,10 +170,10 @@ private static void buildFromTensorDouble(TFloat64 tensor, String memoryName) th
154170
ByteBuffer buff2 = ByteBuffer.wrap(flat).order(ByteOrder.LITTLE_ENDIAN);
155171
tensor.asRawTensor().data().read(flat, 0, buff.capacity());
156172
buff.put(buff2);
157-
if (PlatformDetection.isWindows()) shma.close();
173+
if (PlatformDetection.isWindows() && close) shma.close();
158174
}
159175

160-
private static void buildFromTensorLong(TInt64 tensor, String memoryName) throws IOException
176+
private static void buildFromTensorLong(TInt64 tensor, String memoryName, boolean close) throws IOException
161177
{
162178
long[] arrayShape = tensor.shape().asArray();
163179
if (CommonUtils.int32Overflows(arrayShape, 8))
@@ -171,6 +187,6 @@ private static void buildFromTensorLong(TInt64 tensor, String memoryName) throws
171187
ByteBuffer buff2 = ByteBuffer.wrap(flat).order(ByteOrder.LITTLE_ENDIAN);
172188
tensor.asRawTensor().data().read(flat, 0, buff.capacity());
173189
buff.put(buff2);
174-
if (PlatformDetection.isWindows()) shma.close();
190+
if (PlatformDetection.isWindows() && close) shma.close();
175191
}
176192
}

0 commit comments

Comments
 (0)