Skip to content

Commit 90e4910

Browse files
committed
improve interprocessing communication
1 parent a4eeaf9 commit 90e4910

File tree

2 files changed

+15
-49
lines changed

2 files changed

+15
-49
lines changed

src/main/java/io/bioimage/modelrunner/tensorflow/v2/api050/shm/ShmBuilder.java

+10-36
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import java.nio.ByteBuffer;
2929
import java.util.Arrays;
3030

31-
import org.tensorflow.Tensor;
3231
import org.tensorflow.types.TFloat32;
3332
import org.tensorflow.types.TFloat64;
3433
import org.tensorflow.types.TInt32;
@@ -101,13 +100,8 @@ private static void buildFromTensorUByte(TUint8 tensor, String memoryName) throw
101100
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
102101
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
103102
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
104-
ByteBuffer buff = shma.getDataBuffer();
105-
int totalSize = 1;
106-
for (long i : arrayShape) {totalSize *= i;}
107-
byte[] flatArr = new byte[buff.capacity()];
108-
buff.get(flatArr);
109-
tensor.asRawTensor().data().read(flatArr, flatArr.length - totalSize, totalSize);
110-
shma.setBuffer(ByteBuffer.wrap(flatArr));
103+
ByteBuffer buff = shma.getDataBufferNoHeader();
104+
tensor.asRawTensor().data().read(buff.array(), 0, buff.capacity());
111105
if (PlatformDetection.isWindows()) shma.close();
112106
}
113107

@@ -119,13 +113,8 @@ private static void buildFromTensorInt(TInt32 tensor, String memoryName) throws
119113
+ " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);
120114

121115
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true);
122-
ByteBuffer buff = shma.getDataBuffer();
123-
int totalSize = 4;
124-
for (long i : arrayShape) {totalSize *= i;}
125-
byte[] flatArr = new byte[buff.capacity()];
126-
buff.get(flatArr);
127-
tensor.asRawTensor().data().read(flatArr, flatArr.length - totalSize, totalSize);
128-
shma.setBuffer(ByteBuffer.wrap(flatArr));
116+
ByteBuffer buff = shma.getDataBufferNoHeader();
117+
tensor.asRawTensor().data().read(buff.array(), 0, buff.capacity());
129118
if (PlatformDetection.isWindows()) shma.close();
130119
}
131120

@@ -137,13 +126,8 @@ private static void buildFromTensorFloat(TFloat32 tensor, String memoryName) thr
137126
+ " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4);
138127

139128
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true);
140-
ByteBuffer buff = shma.getDataBuffer();
141-
int totalSize = 4;
142-
for (long i : arrayShape) {totalSize *= i;}
143-
byte[] flatArr = new byte[buff.capacity()];
144-
buff.get(flatArr);
145-
tensor.asRawTensor().data().read(flatArr, flatArr.length - totalSize, totalSize);
146-
shma.setBuffer(ByteBuffer.wrap(flatArr));
129+
ByteBuffer buff = shma.getDataBufferNoHeader();
130+
tensor.asRawTensor().data().read(buff.array(), 0, buff.capacity());
147131
if (PlatformDetection.isWindows()) shma.close();
148132
}
149133

@@ -155,13 +139,8 @@ private static void buildFromTensorDouble(TFloat64 tensor, String memoryName) th
155139
+ " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8);
156140

157141
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new DoubleType(), false, true);
158-
ByteBuffer buff = shma.getDataBuffer();
159-
int totalSize = 8;
160-
for (long i : arrayShape) {totalSize *= i;}
161-
byte[] flatArr = new byte[buff.capacity()];
162-
buff.get(flatArr);
163-
tensor.asRawTensor().data().read(flatArr, flatArr.length - totalSize, totalSize);
164-
shma.setBuffer(ByteBuffer.wrap(flatArr));
142+
ByteBuffer buff = shma.getDataBufferNoHeader();
143+
tensor.asRawTensor().data().read(buff.array(), 0, buff.capacity());
165144
if (PlatformDetection.isWindows()) shma.close();
166145
}
167146

@@ -174,13 +153,8 @@ private static void buildFromTensorLong(TInt64 tensor, String memoryName) throws
174153

175154

176155
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new LongType(), false, true);
177-
ByteBuffer buff = shma.getDataBuffer();
178-
int totalSize = 8;
179-
for (long i : arrayShape) {totalSize *= i;}
180-
byte[] flatArr = new byte[buff.capacity()];
181-
buff.get(flatArr);
182-
tensor.asRawTensor().data().read(flatArr, flatArr.length - totalSize, totalSize);
183-
shma.setBuffer(ByteBuffer.wrap(flatArr));
156+
ByteBuffer buff = shma.getDataBufferNoHeader();
157+
tensor.asRawTensor().data().read(buff.array(), 0, buff.capacity());
184158
if (PlatformDetection.isWindows()) shma.close();
185159
}
186160
}

src/main/java/io/bioimage/modelrunner/tensorflow/v2/api050/shm/TensorBuilder.java

+5-13
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,7 @@ private static TInt32 buildInt(SharedMemoryArray tensor)
118118
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
119119
ByteBuffer buff = tensor.getDataBufferNoHeader();
120120
IntBuffer intBuff = buff.asIntBuffer();
121-
int[] intArray = new int[intBuff.capacity()];
122-
intBuff.get(intArray);
123-
IntDataBuffer dataBuffer = RawDataBufferFactory.create(intArray, false);
121+
IntDataBuffer dataBuffer = RawDataBufferFactory.create(intBuff.array(), false);
124122
TInt32 ndarray = TInt32.tensorOf(Shape.of(ogShape),
125123
dataBuffer);
126124
return ndarray;
@@ -137,9 +135,7 @@ private static TInt64 buildLong(SharedMemoryArray tensor)
137135
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
138136
ByteBuffer buff = tensor.getDataBufferNoHeader();
139137
LongBuffer longBuff = buff.asLongBuffer();
140-
long[] longArray = new long[longBuff.capacity()];
141-
longBuff.get(longArray);
142-
LongDataBuffer dataBuffer = RawDataBufferFactory.create(longArray, false);
138+
LongDataBuffer dataBuffer = RawDataBufferFactory.create(longBuff.array(), false);
143139
TInt64 ndarray = TInt64.tensorOf(Shape.of(ogShape),
144140
dataBuffer);
145141
return ndarray;
@@ -156,9 +152,7 @@ private static TFloat32 buildFloat(SharedMemoryArray tensor)
156152
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
157153
ByteBuffer buff = tensor.getDataBufferNoHeader();
158154
FloatBuffer floatBuff = buff.asFloatBuffer();
159-
float[] floatArray = new float[floatBuff.capacity()];
160-
floatBuff.get(floatArray);
161-
FloatDataBuffer dataBuffer = RawDataBufferFactory.create(floatArray, false);
155+
FloatDataBuffer dataBuffer = RawDataBufferFactory.create(floatBuff.array(), false);
162156
TFloat32 ndarray = TFloat32.tensorOf(Shape.of(ogShape), dataBuffer);
163157
return ndarray;
164158
}
@@ -173,10 +167,8 @@ private static TFloat64 buildDouble(SharedMemoryArray tensor)
173167
if (!tensor.isNumpyFormat())
174168
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
175169
ByteBuffer buff = tensor.getDataBufferNoHeader();
176-
DoubleBuffer doubleBuff = buff.asDoubleBuffer();
177-
double[] doubleArray = new double[doubleBuff.capacity()];
178-
doubleBuff.get(doubleArray);
179-
DoubleDataBuffer dataBuffer = RawDataBufferFactory.create(doubleArray, false);
170+
DoubleBuffer floatBuff = buff.asDoubleBuffer();
171+
DoubleDataBuffer dataBuffer = RawDataBufferFactory.create(floatBuff.array(), false);
180172
TFloat64 ndarray = TFloat64.tensorOf(Shape.of(ogShape), dataBuffer);
181173
return ndarray;
182174
}

0 commit comments

Comments
 (0)