Skip to content

Commit d8c0725

Browse files
committed
improve small errors on shm communication
1 parent 6ad58b1 commit d8c0725

File tree

2 files changed

+5
-13
lines changed

2 files changed

+5
-13
lines changed

src/main/java/io/bioimage/modelrunner/tensorflow/v1/shm/ShmBuilder.java

+5-5
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ private static void buildFromTensorUByte(Tensor<UInt8> tensor, String memoryName
9393
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
9494
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
9595
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
96-
ByteBuffer buff = shma.getDataBuffer();
96+
ByteBuffer buff = shma.getDataBufferNoHeader();
9797
tensor.writeTo(buff);
9898
if (PlatformDetection.isWindows()) shma.close();
9999
}
@@ -106,7 +106,7 @@ private static void buildFromTensorInt(Tensor<Integer> tensor, String memoryName
106106
+ " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);
107107

108108
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true);
109-
ByteBuffer buff = shma.getDataBuffer();
109+
ByteBuffer buff = shma.getDataBufferNoHeader();
110110
tensor.writeTo(buff);
111111
if (PlatformDetection.isWindows()) shma.close();
112112
}
@@ -119,7 +119,7 @@ private static void buildFromTensorFloat(Tensor<Float> tensor, String memoryName
119119
+ " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4);
120120

121121
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true);
122-
ByteBuffer buff = shma.getDataBuffer();
122+
ByteBuffer buff = shma.getDataBufferNoHeader();
123123
tensor.writeTo(buff);
124124
if (PlatformDetection.isWindows()) shma.close();
125125
}
@@ -132,7 +132,7 @@ private static void buildFromTensorDouble(Tensor<Double> tensor, String memoryNa
132132
+ " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8);
133133

134134
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new DoubleType(), false, true);
135-
ByteBuffer buff = shma.getDataBuffer();
135+
ByteBuffer buff = shma.getDataBufferNoHeader();
136136
tensor.writeTo(buff);
137137
if (PlatformDetection.isWindows()) shma.close();
138138
}
@@ -146,7 +146,7 @@ private static void buildFromTensorLong(Tensor<Long> tensor, String memoryName)
146146

147147

148148
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new LongType(), false, true);
149-
ByteBuffer buff = shma.getDataBuffer();
149+
ByteBuffer buff = shma.getDataBufferNoHeader();
150150
tensor.writeTo(buff);
151151
if (PlatformDetection.isWindows()) shma.close();
152152
}

src/main/java/io/bioimage/modelrunner/tensorflow/v1/shm/TensorBuilder.java

-8
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,6 @@ private static Tensor<Integer> buildInt(SharedMemoryArray tensor)
105105
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
106106
ByteBuffer buff = tensor.getDataBufferNoHeader();
107107
IntBuffer intBuff = buff.asIntBuffer();
108-
int[] intArray = new int[intBuff.capacity()];
109-
intBuff.get(intArray);
110108
Tensor<Integer> ndarray = Tensor.create(ogShape, intBuff);
111109
return ndarray;
112110
}
@@ -122,8 +120,6 @@ private static Tensor<Long> buildLong(SharedMemoryArray tensor)
122120
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
123121
ByteBuffer buff = tensor.getDataBufferNoHeader();
124122
LongBuffer longBuff = buff.asLongBuffer();
125-
long[] longArray = new long[longBuff.capacity()];
126-
longBuff.get(longArray);
127123
Tensor<Long> ndarray = Tensor.create(ogShape, longBuff);
128124
return ndarray;
129125
}
@@ -139,8 +135,6 @@ private static Tensor<Float> buildFloat(SharedMemoryArray tensor)
139135
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
140136
ByteBuffer buff = tensor.getDataBufferNoHeader();
141137
FloatBuffer floatBuff = buff.asFloatBuffer();
142-
float[] floatArray = new float[floatBuff.capacity()];
143-
floatBuff.get(floatArray);
144138
Tensor<Float> ndarray = Tensor.create(ogShape, floatBuff);
145139
return ndarray;
146140
}
@@ -156,8 +150,6 @@ private static Tensor<Double> buildDouble(SharedMemoryArray tensor)
156150
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
157151
ByteBuffer buff = tensor.getDataBufferNoHeader();
158152
DoubleBuffer doubleBuff = buff.asDoubleBuffer();
159-
double[] doubleArray = new double[doubleBuff.capacity()];
160-
doubleBuff.get(doubleArray);
161153
Tensor<Double> ndarray = Tensor.create(ogShape, doubleBuff);
162154
return ndarray;
163155
}

0 commit comments

Comments
 (0)