Skip to content

Commit 42bc209

Browse files
committed
correct errors in interprocessing communication
1 parent 6375b7c commit 42bc209

File tree

2 files changed

+14
-54
lines changed

2 files changed

+14
-54
lines changed

Diff for: src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/shm/ShmBuilder.java

+10-35
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,8 @@ private static void buildFromTensorUByte(Tensor<TUint8> tensor, String memoryNam
9898
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
9999
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
100100
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
101-
ByteBuffer buff = shma.getDataBuffer();
102-
int totalSize = 1;
103-
for (long i : arrayShape) {totalSize *= i;}
104-
byte[] flatArr = new byte[buff.capacity()];
105-
buff.get(flatArr);
106-
tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize);
107-
shma.setBuffer(ByteBuffer.wrap(flatArr));
101+
ByteBuffer buff1 = shma.getDataBufferNoHeader();
102+
tensor.rawData().read(buff1.array(), 0, buff1.capacity());
108103
if (PlatformDetection.isWindows()) shma.close();
109104
}
110105

@@ -116,13 +111,8 @@ private static void buildFromTensorInt(Tensor<TInt32> tensor, String memoryName)
116111
+ " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);
117112

118113
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true);
119-
ByteBuffer buff = shma.getDataBuffer();
120-
int totalSize = 4;
121-
for (long i : arrayShape) {totalSize *= i;}
122-
byte[] flatArr = new byte[buff.capacity()];
123-
buff.get(flatArr);
124-
tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize);
125-
shma.setBuffer(ByteBuffer.wrap(flatArr));
114+
ByteBuffer buff1 = shma.getDataBufferNoHeader();
115+
tensor.rawData().read(buff1.array(), 0, buff1.capacity());
126116
if (PlatformDetection.isWindows()) shma.close();
127117
}
128118

@@ -134,13 +124,8 @@ private static void buildFromTensorFloat(Tensor<TFloat32> tensor, String memoryN
134124
+ " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4);
135125

136126
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true);
137-
ByteBuffer buff = shma.getDataBuffer();
138-
int totalSize = 4;
139-
for (long i : arrayShape) {totalSize *= i;}
140-
byte[] flatArr = new byte[buff.capacity()];
141-
buff.get(flatArr);
142-
tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize);
143-
shma.setBuffer(ByteBuffer.wrap(flatArr));
127+
ByteBuffer buff1 = shma.getDataBufferNoHeader();
128+
tensor.rawData().read(buff1.array(), 0, buff1.capacity());
144129
if (PlatformDetection.isWindows()) shma.close();
145130
}
146131

@@ -152,13 +137,8 @@ private static void buildFromTensorDouble(Tensor<TFloat64> tensor, String memory
152137
+ " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8);
153138

154139
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new DoubleType(), false, true);
155-
ByteBuffer buff = shma.getDataBuffer();
156-
int totalSize = 8;
157-
for (long i : arrayShape) {totalSize *= i;}
158-
byte[] flatArr = new byte[buff.capacity()];
159-
buff.get(flatArr);
160-
tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize);
161-
shma.setBuffer(ByteBuffer.wrap(flatArr));
140+
ByteBuffer buff1 = shma.getDataBufferNoHeader();
141+
tensor.rawData().read(buff1.array(), 0, buff1.capacity());
162142
if (PlatformDetection.isWindows()) shma.close();
163143
}
164144

@@ -171,13 +151,8 @@ private static void buildFromTensorLong(Tensor<TInt64> tensor, String memoryName
171151

172152

173153
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new LongType(), false, true);
174-
ByteBuffer buff = shma.getDataBuffer();
175-
int totalSize = 8;
176-
for (long i : arrayShape) {totalSize *= i;}
177-
byte[] flatArr = new byte[buff.capacity()];
178-
buff.get(flatArr);
179-
tensor.rawData().read(flatArr, flatArr.length - totalSize, totalSize);
180-
shma.setBuffer(ByteBuffer.wrap(flatArr));
154+
ByteBuffer buff1 = shma.getDataBufferNoHeader();
155+
tensor.rawData().read(buff1.array(), 0, buff1.capacity());
181156
if (PlatformDetection.isWindows()) shma.close();
182157
}
183158
}

Diff for: src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/shm/TensorBuilder.java

+4-19
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,6 @@
2323

2424
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
2525
import io.bioimage.modelrunner.utils.CommonUtils;
26-
import net.imglib2.RandomAccessibleInterval;
27-
import net.imglib2.img.Img;
28-
import net.imglib2.type.numeric.integer.IntType;
29-
import net.imglib2.type.numeric.integer.LongType;
30-
import net.imglib2.type.numeric.integer.UnsignedByteType;
31-
import net.imglib2.type.numeric.real.DoubleType;
32-
import net.imglib2.type.numeric.real.FloatType;
3326
import net.imglib2.util.Cast;
3427

3528
import java.nio.ByteBuffer;
@@ -125,9 +118,7 @@ private static Tensor<TInt32> buildInt(SharedMemoryArray tensor)
125118
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
126119
ByteBuffer buff = tensor.getDataBufferNoHeader();
127120
IntBuffer intBuff = buff.asIntBuffer();
128-
int[] intArray = new int[intBuff.capacity()];
129-
intBuff.get(intArray);
130-
IntDataBuffer dataBuffer = RawDataBufferFactory.create(intArray, false);
121+
IntDataBuffer dataBuffer = RawDataBufferFactory.create(intBuff.array(), false);
131122
Tensor<TInt32> ndarray = TInt32.tensorOf(Shape.of(ogShape), dataBuffer);
132123
return ndarray;
133124
}
@@ -143,9 +134,7 @@ private static Tensor<TInt64> buildLong(SharedMemoryArray tensor)
143134
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
144135
ByteBuffer buff = tensor.getDataBufferNoHeader();
145136
LongBuffer longBuff = buff.asLongBuffer();
146-
long[] longArray = new long[longBuff.capacity()];
147-
longBuff.get(longArray);
148-
LongDataBuffer dataBuffer = RawDataBufferFactory.create(longArray, false);
137+
LongDataBuffer dataBuffer = RawDataBufferFactory.create(longBuff.array(), false);
149138
Tensor<TInt64> ndarray = TInt64.tensorOf(Shape.of(ogShape), dataBuffer);
150139
return ndarray;
151140
}
@@ -161,9 +150,7 @@ private static Tensor<TFloat32> buildFloat(SharedMemoryArray tensor)
161150
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
162151
ByteBuffer buff = tensor.getDataBufferNoHeader();
163152
FloatBuffer floatBuff = buff.asFloatBuffer();
164-
float[] floatArray = new float[floatBuff.capacity()];
165-
floatBuff.get(floatArray);
166-
FloatDataBuffer dataBuffer = RawDataBufferFactory.create(floatArray, false);
153+
FloatDataBuffer dataBuffer = RawDataBufferFactory.create(floatBuff.array(), false);
167154
Tensor<TFloat32> ndarray = TFloat32.tensorOf(Shape.of(ogShape), dataBuffer);
168155
return ndarray;
169156
}
@@ -179,9 +166,7 @@ private static Tensor<TFloat64> buildDouble(SharedMemoryArray tensor)
179166
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
180167
ByteBuffer buff = tensor.getDataBufferNoHeader();
181168
DoubleBuffer doubleBuff = buff.asDoubleBuffer();
182-
double[] doubleArray = new double[doubleBuff.capacity()];
183-
doubleBuff.get(doubleArray);
184-
DoubleDataBuffer dataBuffer = RawDataBufferFactory.create(doubleArray, false);
169+
DoubleDataBuffer dataBuffer = RawDataBufferFactory.create(doubleBuff.array(), false);
185170
Tensor<TFloat64> ndarray = TFloat64.tensorOf(Shape.of(ogShape), dataBuffer);
186171
return ndarray;
187172
}

0 commit comments

Comments
 (0)