Skip to content

Commit f33dc43

Browse files
committed
correct gigabug creating tensors
1 parent d951617 commit f33dc43

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

Diff for: src/main/java/io/bioimage/modelrunner/pytorch/javacpp/shm/ShmBuilder.java

+10-6
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,11 @@
2929

3030
import org.bytedeco.pytorch.Tensor;
3131

32-
import net.imglib2.type.numeric.integer.UnsignedByteType;
32+
import net.imglib2.type.numeric.integer.IntType;
33+
import net.imglib2.type.numeric.integer.LongType;
34+
import net.imglib2.type.numeric.integer.ByteType;
35+
import net.imglib2.type.numeric.real.DoubleType;
36+
import net.imglib2.type.numeric.real.FloatType;
3337

3438
/**
3539
* A utility class that converts {@link Tensor}s into {@link SharedMemoryArray}s for
@@ -79,7 +83,7 @@ private static void buildFromTensorByte(Tensor tensor, String memoryName) throws
7983
if (CommonUtils.int32Overflows(arrayShape, 1))
8084
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
8185
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
82-
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
86+
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new ByteType(), false, true);
8387
shma.getDataBufferNoHeader().put(tensor.asByteBuffer());
8488
if (PlatformDetection.isWindows()) shma.close();
8589
}
@@ -90,7 +94,7 @@ private static void buildFromTensorInt(Tensor tensor, String memoryName) throws
9094
if (CommonUtils.int32Overflows(arrayShape, 4))
9195
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
9296
+ " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);
93-
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
97+
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true);
9498
shma.getDataBufferNoHeader().put(tensor.asByteBuffer());
9599
if (PlatformDetection.isWindows()) shma.close();
96100
}
@@ -101,7 +105,7 @@ private static void buildFromTensorFloat(Tensor tensor, String memoryName) throw
101105
if (CommonUtils.int32Overflows(arrayShape, 4))
102106
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
103107
+ " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4);
104-
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
108+
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true);
105109
shma.getDataBufferNoHeader().put(tensor.asByteBuffer());
106110
if (PlatformDetection.isWindows()) shma.close();
107111
}
@@ -112,7 +116,7 @@ private static void buildFromTensorDouble(Tensor tensor, String memoryName) thro
112116
if (CommonUtils.int32Overflows(arrayShape, 8))
113117
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
114118
+ " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8);
115-
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
119+
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new DoubleType(), false, true);
116120
shma.getDataBufferNoHeader().put(tensor.asByteBuffer());
117121
if (PlatformDetection.isWindows()) shma.close();
118122
}
@@ -123,7 +127,7 @@ private static void buildFromTensorLong(Tensor tensor, String memoryName) throws
123127
if (CommonUtils.int32Overflows(arrayShape, 8))
124128
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
125129
+ " is too big. Max number of elements per long output tensor supported: " + Integer.MAX_VALUE / 8);
126-
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
130+
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new LongType(), false, true);
127131
shma.getDataBufferNoHeader().put(tensor.asByteBuffer());
128132
if (PlatformDetection.isWindows()) shma.close();
129133
}

0 commit comments

Comments
 (0)