Skip to content

Commit ac25ac6

Browse files
committed
GOOD WORKING INTERFACE
1 parent a2809d1 commit ac25ac6

File tree

2 files changed

+5
-49
lines changed

2 files changed

+5
-49
lines changed

Diff for: src/main/java/io/bioimage/modelrunner/pytorch/PytorchInterface.java

-42
Original file line numberDiff line numberDiff line change
@@ -468,48 +468,6 @@ else if (task.status == TaskStatus.CRASHED)
468468
model = null;
469469
}
470470

471-
/** TODO remove
472-
* Create the arguments needed to execute Pytorch in another
473-
* process with the corresponding tensors
474-
* @return the command used to call the separate process
475-
* @throws IOException if the command needed to execute interprocessing is too long
476-
* @throws URISyntaxException if there is any error with the URIs retrieved from the classes
477-
*/
478-
private List<String> getProcessCommandsWithoutArgs2() throws IOException, URISyntaxException {
479-
String javaHome = System.getProperty("java.home");
480-
String javaBin = javaHome + File.separator + "bin" + File.separator + "java";
481-
482-
String modelrunnerPath = getPathFromClass(DeepLearningEngineInterface.class);
483-
String imglib2Path = getPathFromClass(NativeType.class);
484-
String gsonPath = getPathFromClass(Gson.class);
485-
String jnaPath = getPathFromClass(com.sun.jna.Library.class);
486-
String jnaPlatformPath = getPathFromClass(com.sun.jna.platform.FileUtils.class);
487-
if (modelrunnerPath == null || (modelrunnerPath.endsWith("DeepLearningEngineInterface.class")
488-
&& !modelrunnerPath.contains(File.pathSeparator)))
489-
modelrunnerPath = System.getProperty("java.class.path");
490-
String classpath = modelrunnerPath + File.pathSeparator + imglib2Path + File.pathSeparator;
491-
classpath = classpath + gsonPath + File.pathSeparator;
492-
classpath = classpath + jnaPath + File.pathSeparator;
493-
classpath = classpath + jnaPlatformPath + File.pathSeparator;
494-
ProtectionDomain protectionDomain = PytorchInterface.class.getProtectionDomain();
495-
String codeSource = protectionDomain.getCodeSource().getLocation().getPath();
496-
String f_name = URLDecoder.decode(codeSource, StandardCharsets.UTF_8.toString());
497-
f_name = new File(f_name).getAbsolutePath();
498-
for (File ff : new File(f_name).getParentFile().listFiles()) {
499-
if (ff.getName().startsWith(JAR_FILE_NAME) && !ff.getAbsolutePath().equals(f_name))
500-
continue;
501-
classpath += ff.getAbsolutePath() + File.pathSeparator;
502-
}
503-
String className = PytorchInterface.class.getName();
504-
List<String> command = new LinkedList<String>();
505-
command.add(padSpecialJavaBin(javaBin));
506-
command.add("-cp");
507-
command.add(classpath);
508-
command.add(className);
509-
command.add(modelSource);
510-
return command;
511-
}
512-
513471
/**
514472
* Create the arguments needed to execute tensorflow 2 in another
515473
* process with the corresponding tensors

Diff for: src/main/java/io/bioimage/modelrunner/pytorch/shm/ShmBuilder.java

+5-7
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import io.bioimage.modelrunner.utils.CommonUtils;
2626

2727
import java.io.IOException;
28-
import java.nio.ByteBuffer;
2928
import java.util.Arrays;
3029

3130
import ai.djl.ndarray.NDArray;
@@ -61,7 +60,6 @@ private ShmBuilder()
6160
*/
6261
public static void build(NDArray tensor, String memoryName) throws IllegalArgumentException, IOException
6362
{
64-
System.out.println(tensor.getDataType().asNumpy());
6563
switch (tensor.getDataType())
6664
{
6765
case UINT8:
@@ -91,7 +89,7 @@ private static void buildFromTensorUByte(NDArray tensor, String memoryName) thro
9189
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
9290
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
9391
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
94-
shma.setBuffer(ByteBuffer.wrap(tensor.toByteArray()));
92+
shma.getDataBufferNoHeader().put(tensor.toByteArray());
9593
if (PlatformDetection.isWindows()) shma.close();
9694
}
9795

@@ -103,7 +101,7 @@ private static void buildFromTensorInt(NDArray tensor, String memoryName) throws
103101
+ " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);
104102

105103
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true);
106-
shma.setBuffer(ByteBuffer.wrap(tensor.toByteArray()));
104+
shma.getDataBufferNoHeader().put(tensor.toByteArray());
107105
if (PlatformDetection.isWindows()) shma.close();
108106
}
109107

@@ -115,7 +113,7 @@ private static void buildFromTensorFloat(NDArray tensor, String memoryName) thro
115113
+ " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4);
116114

117115
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true);
118-
shma.setBuffer(ByteBuffer.wrap(tensor.toByteArray()));
116+
shma.getDataBufferNoHeader().put(tensor.toByteArray());
119117
if (PlatformDetection.isWindows()) shma.close();
120118
}
121119

@@ -127,7 +125,7 @@ private static void buildFromTensorDouble(NDArray tensor, String memoryName) thr
127125
+ " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8);
128126

129127
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new DoubleType(), false, true);
130-
shma.setBuffer(ByteBuffer.wrap(tensor.toByteArray()));
128+
shma.getDataBufferNoHeader().put(tensor.toByteArray());
131129
if (PlatformDetection.isWindows()) shma.close();
132130
}
133131

@@ -140,7 +138,7 @@ private static void buildFromTensorLong(NDArray tensor, String memoryName) throw
140138

141139

142140
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new LongType(), false, true);
143-
shma.setBuffer(ByteBuffer.wrap(tensor.toByteArray()));
141+
shma.getDataBufferNoHeader().put(tensor.toByteArray());
144142
if (PlatformDetection.isWindows()) shma.close();
145143
}
146144
}

0 commit comments

Comments
 (0)