Skip to content

Commit fbd3db9

Browse files
committed
add tools to move to persistent interprocessing
1 parent a55ce82 commit fbd3db9

File tree

7 files changed

+483
-1511
lines changed

7 files changed

+483
-1511
lines changed

Diff for: src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/Tensorflow2Interface.java

+14-26
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,9 @@
3737
import io.bioimage.modelrunner.system.PlatformDetection;
3838
import io.bioimage.modelrunner.tensor.Tensor;
3939
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
40+
import io.bioimage.modelrunner.tensorflow.v2.api020.shm.ShmBuilder;
4041
import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.ImgLib2Builder;
4142
import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.TensorBuilder;
42-
import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.mappedbuffer.ImgLib2ToMappedBuffer;
43-
import io.bioimage.modelrunner.tensorflow.v2.api020.tensor.mappedbuffer.MappedBufferToImgLib2;
4443
import io.bioimage.modelrunner.utils.CommonUtils;
4544
import io.bioimage.modelrunner.utils.Constants;
4645
import io.bioimage.modelrunner.utils.ZipUtils;
@@ -50,24 +49,14 @@
5049
import net.imglib2.util.Cast;
5150
import net.imglib2.util.Util;
5251

53-
import java.io.BufferedReader;
5452
import java.io.File;
5553
import java.io.IOException;
56-
import java.io.InputStreamReader;
57-
import java.io.RandomAccessFile;
5854
import java.io.UnsupportedEncodingException;
5955
import java.net.URISyntaxException;
6056
import java.net.URL;
6157
import java.net.URLDecoder;
62-
import java.nio.ByteBuffer;
63-
import java.nio.MappedByteBuffer;
64-
import java.nio.channels.FileChannel;
6558
import java.nio.charset.StandardCharsets;
66-
import java.nio.file.Files;
67-
import java.nio.file.Paths;
6859
import java.security.ProtectionDomain;
69-
import java.time.LocalDateTime;
70-
import java.time.format.DateTimeFormatter;
7160
import java.util.ArrayList;
7261
import java.util.HashMap;
7362
import java.util.LinkedHashMap;
@@ -81,7 +70,6 @@
8170
import org.tensorflow.proto.framework.MetaGraphDef;
8271
import org.tensorflow.proto.framework.SignatureDef;
8372
import org.tensorflow.proto.framework.TensorInfo;
84-
import org.tensorflow.types.family.TType;
8573

8674
/**
8775
* Class to that communicates with the dl-model runner, see
@@ -290,28 +278,28 @@ void run(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTensors)
290278
Session session = model.session();
291279
Session.Runner runner = session.runner();
292280
List<String> inputListNames = new ArrayList<String>();
293-
List<TType> inTensors = new ArrayList<TType>();
281+
List<org.tensorflow.Tensor<?>> inTensors = new ArrayList<org.tensorflow.Tensor<?>>();
294282
int c = 0;
295-
for (Tensor<?> tt : inputTensors) {
283+
for (Tensor<T> tt : inputTensors) {
296284
inputListNames.add(tt.getName());
297-
TType inT = TensorBuilder.build(tt);
285+
org.tensorflow.Tensor<?> inT = TensorBuilder.build(tt);
298286
inTensors.add(inT);
299287
String inputName = getModelInputName(tt.getName(), c ++);
300288
runner.feed(inputName, inT);
301289
}
302290
c = 0;
303-
for (Tensor<?> tt : outputTensors)
291+
for (Tensor<R> tt : outputTensors)
304292
runner = runner.fetch(getModelOutputName(tt.getName(), c ++));
305293
// Run runner
306-
List<org.tensorflow.Tensor> resultPatchTensors = runner.run();
294+
List<org.tensorflow.Tensor<?>> resultPatchTensors = runner.run();
307295

308296
// Fill the agnostic output tensors list with data from the inference result
309297
fillOutputTensors(resultPatchTensors, outputTensors);
310298
// Close the remaining resources
311-
for (TType tt : inTensors) {
299+
for (org.tensorflow.Tensor<?> tt : inTensors) {
312300
tt.close();
313301
}
314-
for (org.tensorflow.Tensor tt : resultPatchTensors) {
302+
for (org.tensorflow.Tensor<?> tt : resultPatchTensors) {
315303
tt.close();
316304
}
317305
}
@@ -320,12 +308,12 @@ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IO
320308
Session session = model.session();
321309
Session.Runner runner = session.runner();
322310

323-
List<TType> inTensors = new ArrayList<TType>();
311+
List<org.tensorflow.Tensor<?>> inTensors = new ArrayList<org.tensorflow.Tensor<?>>();
324312
int c = 0;
325313
for (String ee : inputs) {
326314
Map<String, Object> decoded = Types.decode(ee);
327315
SharedMemoryArray shma = SharedMemoryArray.read((String) decoded.get(MEM_NAME_KEY));
328-
TType inT = io.bioimage.modelrunner.tensorflow.v2.api030.shm.TensorBuilder.build(shma);
316+
org.tensorflow.Tensor<?> inT = io.bioimage.modelrunner.tensorflow.v2.api020.shm.TensorBuilder.build(shma);
329317
if (PlatformDetection.isWindows()) shma.close();
330318
inTensors.add(inT);
331319
String inputName = getModelInputName((String) decoded.get(NAME_KEY), c ++);
@@ -336,19 +324,19 @@ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IO
336324
for (String ee : outputs)
337325
runner = runner.fetch(getModelOutputName((String) Types.decode(ee).get(NAME_KEY), c ++));
338326
// Run runner
339-
List<org.tensorflow.Tensor> resultPatchTensors = runner.run();
327+
List<org.tensorflow.Tensor<?>> resultPatchTensors = runner.run();
340328

341329
// Fill the agnostic output tensors list with data from the inference result
342330
c = 0;
343331
for (String ee : outputs) {
344332
Map<String, Object> decoded = Types.decode(ee);
345-
ShmBuilder.build((TType) resultPatchTensors.get(c ++), (String) decoded.get(MEM_NAME_KEY));
333+
ShmBuilder.build((org.tensorflow.Tensor<?>) resultPatchTensors.get(c ++), (String) decoded.get(MEM_NAME_KEY));
346334
}
347335
// Close the remaining resources
348-
for (TType tt : inTensors) {
336+
for (org.tensorflow.Tensor<?> tt : inTensors) {
349337
tt.close();
350338
}
351-
for (org.tensorflow.Tensor tt : resultPatchTensors) {
339+
for (org.tensorflow.Tensor<?> tt : resultPatchTensors) {
352340
tt.close();
353341
}
354342
}

0 commit comments

Comments
 (0)