Skip to content

Commit 1acfd4f

Browse files
committed
avoid transposition
1 parent 154fb2d commit 1acfd4f

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

src/main/java/io/bioimage/modelrunner/pytorch/PytorchInterface.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ private <T extends RealType<T> & NativeType<T>> List<String> encodeInputs(List<T
390390
List<String> encodedInputTensors = new ArrayList<String>();
391391
Gson gson = new Gson();
392392
for (Tensor<?> tt : inputTensors) {
393-
SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt.getData(), false, true);
393+
SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt.getData(), true, true);
394394
shmaInputList.add(shma);
395395
HashMap<String, Object> map = new HashMap<String, Object>();
396396
map.put(NAME_KEY, tt.getName());
@@ -415,7 +415,7 @@ List<String> encodeOutputs(List<Tensor<T>> outputTensors) {
415415
if (!tt.isEmpty()) {
416416
map.put(SHAPE_KEY, tt.getShape());
417417
map.put(DTYPE_KEY, CommonUtils.getDataTypeFromRAI(tt.getData()));
418-
SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt.getData(), false, true);
418+
SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tt.getData(), true, true);
419419
shmaOutputList.add(shma);
420420
map.put(MEM_NAME_KEY, shma.getName());
421421
} else if (PlatformDetection.isWindows()){

src/main/java/io/bioimage/modelrunner/pytorch/shm/ShmBuilder.java

+5-5
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ private static void buildFromTensorUByte(NDArray tensor, String memoryName) thro
8888
if (CommonUtils.int32Overflows(arrayShape, 1))
8989
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
9090
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
91-
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
91+
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), true, true);
9292
shma.getDataBufferNoHeader().put(tensor.toByteArray());
9393
if (PlatformDetection.isWindows()) shma.close();
9494
}
@@ -100,7 +100,7 @@ private static void buildFromTensorInt(NDArray tensor, String memoryName) throws
100100
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
101101
+ " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);
102102

103-
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true);
103+
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), true, true);
104104
shma.getDataBufferNoHeader().put(tensor.toByteArray());
105105
if (PlatformDetection.isWindows()) shma.close();
106106
}
@@ -112,7 +112,7 @@ private static void buildFromTensorFloat(NDArray tensor, String memoryName) thro
112112
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
113113
+ " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4);
114114

115-
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true);
115+
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), true, true);
116116
shma.getDataBufferNoHeader().put(tensor.toByteArray());
117117
if (PlatformDetection.isWindows()) shma.close();
118118
}
@@ -124,7 +124,7 @@ private static void buildFromTensorDouble(NDArray tensor, String memoryName) thr
124124
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
125125
+ " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8);
126126

127-
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new DoubleType(), false, true);
127+
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new DoubleType(), true, true);
128128
shma.getDataBufferNoHeader().put(tensor.toByteArray());
129129
if (PlatformDetection.isWindows()) shma.close();
130130
}
@@ -137,7 +137,7 @@ private static void buildFromTensorLong(NDArray tensor, String memoryName) throw
137137
+ " is too big. Max number of elements per long output tensor supported: " + Integer.MAX_VALUE / 8);
138138

139139

140-
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new LongType(), false, true);
140+
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new LongType(), true, true);
141141
shma.getDataBufferNoHeader().put(tensor.toByteArray());
142142
if (PlatformDetection.isWindows()) shma.close();
143143
}

0 commit comments

Comments
 (0)