Skip to content

Commit a937f90

Browse files
committed
correct major bug that was preventing copying from tensor to shm
1 parent f33dc43 commit a937f90

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

Diff for: src/main/java/io/bioimage/modelrunner/pytorch/javacpp/PytorchJavaCPPInterface.java

+3
Original file line numberDiff line numberDiff line change
@@ -253,15 +253,18 @@ protected void runFromShmas(List<String> inputs, List<String> outputs) throws IO
253253
IValue output = model.forward(inputsVector);
254254
TensorVector outputTensorVector = null;
255255
if (output.isTensorList()) {
256+
System.out.println("entered 1");
256257
outputTensorVector = output.toTensorVector();
257258
} else {
259+
System.out.println("entered 2");
258260
outputTensorVector = new TensorVector();
259261
outputTensorVector.put(output.toTensor());
260262
}
261263

262264
// Fill the agnostic output tensors list with data from the inference result
263265
int c = 0;
264266
for (String ee : outputs) {
267+
System.out.println(ee);
265268
Map<String, Object> decoded = Types.decode(ee);
266269
ShmBuilder.build(outputTensorVector.get(c ++), (String) decoded.get(MEM_NAME_KEY));
267270
}

Diff for: src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java

+6-1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import io.bioimage.modelrunner.utils.CommonUtils;
2626

2727
import java.io.IOException;
28+
import java.nio.ByteBuffer;
2829
import java.util.Arrays;
2930

3031
import org.bytedeco.pytorch.Tensor;
@@ -106,7 +107,11 @@ private static void buildFromTensorFloat(Tensor tensor, String memoryName) throw
106107
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
107108
+ " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4);
108109
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true);
109-
shma.getDataBufferNoHeader().put(tensor.asByteBuffer());
110+
long flatSize = 1;
111+
for (long l : arrayShape) {flatSize *= l;}
112+
ByteBuffer byteBuffer = ByteBuffer.allocate((int) (flatSize * Float.BYTES));
113+
tensor.data_ptr_float().get(byteBuffer.asFloatBuffer().array());
114+
shma.getDataBufferNoHeader().put(byteBuffer);
110115
if (PlatformDetection.isWindows()) shma.close();
111116
}
112117

0 commit comments

Comments
 (0)