|
20 | 20 | */
|
21 | 21 | package io.bioimage.modelrunner.pytorch.javacpp.shm;
|
22 | 22 |
|
| 23 | +import io.bioimage.modelrunner.pytorch.javacpp.tensor.ImgLib2Builder; |
23 | 24 | import io.bioimage.modelrunner.system.PlatformDetection;
|
24 | 25 | import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
|
25 | 26 | import io.bioimage.modelrunner.utils.CommonUtils;
|
26 | 27 |
|
27 | 28 | import java.io.IOException;
|
28 | 29 | import java.nio.ByteBuffer;
|
| 30 | +import java.nio.FloatBuffer; |
29 | 31 | import java.util.Arrays;
|
30 | 32 |
|
31 | 33 | import org.bytedeco.pytorch.Tensor;
|
32 | 34 |
|
33 | 35 | import net.imglib2.type.numeric.integer.IntType;
|
34 | 36 | import net.imglib2.type.numeric.integer.LongType;
|
| 37 | +import net.imglib2.RandomAccessibleInterval; |
35 | 38 | import net.imglib2.type.numeric.integer.ByteType;
|
36 | 39 | import net.imglib2.type.numeric.real.DoubleType;
|
37 | 40 | import net.imglib2.type.numeric.real.FloatType;
|
@@ -96,7 +99,8 @@ private static void buildFromTensorInt(Tensor tensor, String memoryName) throws
|
96 | 99 | throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
|
97 | 100 | + " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);
|
98 | 101 | SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true);
|
99 |
| - shma.getDataBufferNoHeader().put(tensor.asByteBuffer()); |
| 102 | + RandomAccessibleInterval<?> rai = shma.getSharedRAI(); |
| 103 | + rai = ImgLib2Builder.build(tensor); |
100 | 104 | if (PlatformDetection.isWindows()) shma.close();
|
101 | 105 | }
|
102 | 106 |
|
|
0 commit comments