Skip to content

Commit f50a2ad

Browse files
committed
stable working version
1 parent e6f5bd5 commit f50a2ad

File tree

2 files changed

+38
-10
lines changed

2 files changed

+38
-10
lines changed

src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java

-3
Original file line numberDiff line numberDiff line change
@@ -253,18 +253,15 @@ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IO
253253
IValue output = model.forward(inputsVector);
254254
TensorVector outputTensorVector = null;
255255
if (output.isTensorList()) {
256-
System.out.println("entered 1");
257256
outputTensorVector = output.toTensorVector();
258257
} else {
259-
System.out.println("entered 2");
260258
outputTensorVector = new TensorVector();
261259
outputTensorVector.put(output.toTensor());
262260
}
263261

264262
// Fill the agnostic output tensors list with data from the inference result
265263
int c = 0;
266264
for (String ee : outputs) {
267-
System.out.println(ee);
268265
Map<String, Object> decoded = Types.decode(ee);
269266
ShmBuilder.build(outputTensorVector.get(c ++), (String) decoded.get(MEM_NAME_KEY));
270267
}

src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java

+38-7
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,22 @@
2020
*/
2121
package io.bioimage.modelrunner.pytorch.javacpp.shm;
2222

23-
import io.bioimage.modelrunner.pytorch.javacpp.tensor.ImgLib2Builder;
2423
import io.bioimage.modelrunner.system.PlatformDetection;
2524
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
2625
import io.bioimage.modelrunner.utils.CommonUtils;
2726

2827
import java.io.IOException;
2928
import java.nio.ByteBuffer;
29+
import java.nio.DoubleBuffer;
3030
import java.nio.FloatBuffer;
31+
import java.nio.IntBuffer;
32+
import java.nio.LongBuffer;
3133
import java.util.Arrays;
3234

3335
import org.bytedeco.pytorch.Tensor;
3436

3537
import net.imglib2.type.numeric.integer.IntType;
3638
import net.imglib2.type.numeric.integer.LongType;
37-
import net.imglib2.RandomAccessibleInterval;
3839
import net.imglib2.type.numeric.integer.ByteType;
3940
import net.imglib2.type.numeric.real.DoubleType;
4041
import net.imglib2.type.numeric.real.FloatType;
@@ -88,7 +89,14 @@ private static void buildFromTensorByte(Tensor tensor, String memoryName) throws
8889
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
8990
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
9091
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new ByteType(), false, true);
91-
shma.getDataBufferNoHeader().put(tensor.asByteBuffer());
92+
long flatSize = 1;
93+
for (long l : arrayShape) {flatSize *= l;}
94+
byte[] flat = new byte[(int) flatSize];
95+
ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize));
96+
tensor.data_ptr_byte().get(flat);
97+
byteBuffer.put(flat);
98+
byteBuffer.rewind();
99+
shma.getDataBufferNoHeader().put(byteBuffer);
92100
if (PlatformDetection.isWindows()) shma.close();
93101
}
94102

@@ -99,8 +107,15 @@ private static void buildFromTensorInt(Tensor tensor, String memoryName) throws
99107
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
100108
+ " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);
101109
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true);
102-
RandomAccessibleInterval<?> rai = shma.getSharedRAI();
103-
rai = ImgLib2Builder.build(tensor);
110+
long flatSize = 1;
111+
for (long l : arrayShape) {flatSize *= l;}
112+
int[] flat = new int[(int) flatSize];
113+
ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize * Integer.BYTES));
114+
IntBuffer floatBuffer = byteBuffer.asIntBuffer();
115+
tensor.data_ptr_int().get(flat);
116+
floatBuffer.put(flat);
117+
byteBuffer.rewind();
118+
shma.getDataBufferNoHeader().put(byteBuffer);
104119
if (PlatformDetection.isWindows()) shma.close();
105120
}
106121

@@ -130,7 +145,15 @@ private static void buildFromTensorDouble(Tensor tensor, String memoryName) thro
130145
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
131146
+ " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8);
132147
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new DoubleType(), false, true);
133-
shma.getDataBufferNoHeader().put(tensor.asByteBuffer());
148+
long flatSize = 1;
149+
for (long l : arrayShape) {flatSize *= l;}
150+
double[] flat = new double[(int) flatSize];
151+
ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize * Double.BYTES));
152+
DoubleBuffer floatBuffer = byteBuffer.asDoubleBuffer();
153+
tensor.data_ptr_double().get(flat);
154+
floatBuffer.put(flat);
155+
byteBuffer.rewind();
156+
shma.getDataBufferNoHeader().put(byteBuffer);
134157
if (PlatformDetection.isWindows()) shma.close();
135158
}
136159

@@ -141,7 +164,15 @@ private static void buildFromTensorLong(Tensor tensor, String memoryName) throws
141164
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
142165
+ " is too big. Max number of elements per long output tensor supported: " + Integer.MAX_VALUE / 8);
143166
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new LongType(), false, true);
144-
shma.getDataBufferNoHeader().put(tensor.asByteBuffer());
167+
long flatSize = 1;
168+
for (long l : arrayShape) {flatSize *= l;}
169+
long[] flat = new long[(int) flatSize];
170+
ByteBuffer byteBuffer = ByteBuffer.allocateDirect((int) (flatSize * Long.BYTES));
171+
LongBuffer floatBuffer = byteBuffer.asLongBuffer();
172+
tensor.data_ptr_long().get(flat);
173+
floatBuffer.put(flat);
174+
byteBuffer.rewind();
175+
shma.getDataBufferNoHeader().put(byteBuffer);
145176
if (PlatformDetection.isWindows()) shma.close();
146177
}
147178
}

0 commit comments

Comments
 (0)