Skip to content

Commit 24c5a81

Browse files
committed
fix issue referencing and copying data
1 parent c643162 commit 24c5a81

File tree

2 files changed

+45
-23
lines changed

2 files changed

+45
-23
lines changed

Diff for: src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/shm/ShmBuilder.java

+25-10
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,11 @@ private static void buildFromTensorUByte(Tensor<TUint8> tensor, String memoryNam
9898
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
9999
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
100100
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new UnsignedByteType(), false, true);
101-
ByteBuffer buff1 = shma.getDataBufferNoHeader();
102-
tensor.rawData().read(buff1.array(), 0, buff1.capacity());
101+
ByteBuffer buff = shma.getDataBufferNoHeader();
102+
byte[] flat = new byte[buff.capacity()];
103+
ByteBuffer buff2 = ByteBuffer.wrap(flat);
104+
tensor.rawData().read(flat, 0, buff.capacity());
105+
buff = buff2;
103106
if (PlatformDetection.isWindows()) shma.close();
104107
}
105108

@@ -111,8 +114,11 @@ private static void buildFromTensorInt(Tensor<TInt32> tensor, String memoryName)
111114
+ " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);
112115

113116
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new IntType(), false, true);
114-
ByteBuffer buff1 = shma.getDataBufferNoHeader();
115-
tensor.rawData().read(buff1.array(), 0, buff1.capacity());
117+
ByteBuffer buff = shma.getDataBufferNoHeader();
118+
byte[] flat = new byte[buff.capacity()];
119+
ByteBuffer buff2 = ByteBuffer.wrap(flat);
120+
tensor.rawData().read(flat, 0, buff.capacity());
121+
buff = buff2;
116122
if (PlatformDetection.isWindows()) shma.close();
117123
}
118124

@@ -124,8 +130,11 @@ private static void buildFromTensorFloat(Tensor<TFloat32> tensor, String memoryN
124130
+ " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4);
125131

126132
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new FloatType(), false, true);
127-
ByteBuffer buff1 = shma.getDataBufferNoHeader();
128-
tensor.rawData().read(buff1.array(), 0, buff1.capacity());
133+
ByteBuffer buff = shma.getDataBufferNoHeader();
134+
byte[] flat = new byte[buff.capacity()];
135+
ByteBuffer buff2 = ByteBuffer.wrap(flat);
136+
tensor.rawData().read(flat, 0, buff.capacity());
137+
buff = buff2;
129138
if (PlatformDetection.isWindows()) shma.close();
130139
}
131140

@@ -137,8 +146,11 @@ private static void buildFromTensorDouble(Tensor<TFloat64> tensor, String memory
137146
+ " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8);
138147

139148
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new DoubleType(), false, true);
140-
ByteBuffer buff1 = shma.getDataBufferNoHeader();
141-
tensor.rawData().read(buff1.array(), 0, buff1.capacity());
149+
ByteBuffer buff = shma.getDataBufferNoHeader();
150+
byte[] flat = new byte[buff.capacity()];
151+
ByteBuffer buff2 = ByteBuffer.wrap(flat);
152+
tensor.rawData().read(flat, 0, buff.capacity());
153+
buff = buff2;
142154
if (PlatformDetection.isWindows()) shma.close();
143155
}
144156

@@ -151,8 +163,11 @@ private static void buildFromTensorLong(Tensor<TInt64> tensor, String memoryName
151163

152164

153165
SharedMemoryArray shma = SharedMemoryArray.readOrCreate(memoryName, arrayShape, new LongType(), false, true);
154-
ByteBuffer buff1 = shma.getDataBufferNoHeader();
155-
tensor.rawData().read(buff1.array(), 0, buff1.capacity());
166+
ByteBuffer buff = shma.getDataBufferNoHeader();
167+
byte[] flat = new byte[buff.capacity()];
168+
ByteBuffer buff2 = ByteBuffer.wrap(flat);
169+
tensor.rawData().read(flat, 0, buff.capacity());
170+
buff = buff2;
156171
if (PlatformDetection.isWindows()) shma.close();
157172
}
158173
}

Diff for: src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/shm/TensorBuilder.java

+20-13
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,6 @@
2626
import net.imglib2.util.Cast;
2727

2828
import java.nio.ByteBuffer;
29-
import java.nio.DoubleBuffer;
30-
import java.nio.FloatBuffer;
31-
import java.nio.IntBuffer;
32-
import java.nio.LongBuffer;
3329
import java.util.Arrays;
3430

3531
import org.tensorflow.Tensor;
@@ -102,7 +98,10 @@ private static Tensor<TUint8> buildUByte(SharedMemoryArray tensor)
10298
if (!tensor.isNumpyFormat())
10399
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
104100
ByteBuffer buff = tensor.getDataBufferNoHeader();
105-
ByteDataBuffer dataBuffer = RawDataBufferFactory.create(buff.array(), false);
101+
byte[] flat = new byte[buff.capacity()];
102+
buff.get(flat);
103+
buff.rewind();
104+
ByteDataBuffer dataBuffer = RawDataBufferFactory.create(flat, false);
106105
Tensor<TUint8> ndarray = Tensor.of(TUint8.DTYPE, Shape.of(ogShape), dataBuffer);
107106
return ndarray;
108107
}
@@ -117,8 +116,10 @@ private static Tensor<TInt32> buildInt(SharedMemoryArray tensor)
117116
if (!tensor.isNumpyFormat())
118117
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
119118
ByteBuffer buff = tensor.getDataBufferNoHeader();
120-
IntBuffer intBuff = buff.asIntBuffer();
121-
IntDataBuffer dataBuffer = RawDataBufferFactory.create(intBuff.array(), false);
119+
int[] flat = new int[buff.capacity() / 4];
120+
buff.asIntBuffer().get(flat);
121+
buff.rewind();
122+
IntDataBuffer dataBuffer = RawDataBufferFactory.create(flat, false);
122123
Tensor<TInt32> ndarray = TInt32.tensorOf(Shape.of(ogShape), dataBuffer);
123124
return ndarray;
124125
}
@@ -133,8 +134,10 @@ private static Tensor<TInt64> buildLong(SharedMemoryArray tensor)
133134
if (!tensor.isNumpyFormat())
134135
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
135136
ByteBuffer buff = tensor.getDataBufferNoHeader();
136-
LongBuffer longBuff = buff.asLongBuffer();
137-
LongDataBuffer dataBuffer = RawDataBufferFactory.create(longBuff.array(), false);
137+
long[] flat = new long[buff.capacity() / 8];
138+
buff.asLongBuffer().get(flat);
139+
buff.rewind();
140+
LongDataBuffer dataBuffer = RawDataBufferFactory.create(flat, false);
138141
Tensor<TInt64> ndarray = TInt64.tensorOf(Shape.of(ogShape), dataBuffer);
139142
return ndarray;
140143
}
@@ -149,8 +152,10 @@ private static Tensor<TFloat32> buildFloat(SharedMemoryArray tensor)
149152
if (!tensor.isNumpyFormat())
150153
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
151154
ByteBuffer buff = tensor.getDataBufferNoHeader();
152-
FloatBuffer floatBuff = buff.asFloatBuffer();
153-
FloatDataBuffer dataBuffer = RawDataBufferFactory.create(floatBuff.array(), false);
155+
float[] flat = new float[buff.capacity() / 4];
156+
buff.asFloatBuffer().get(flat);
157+
buff.rewind();
158+
FloatDataBuffer dataBuffer = RawDataBufferFactory.create(flat, false);
154159
Tensor<TFloat32> ndarray = TFloat32.tensorOf(Shape.of(ogShape), dataBuffer);
155160
return ndarray;
156161
}
@@ -165,8 +170,10 @@ private static Tensor<TFloat64> buildDouble(SharedMemoryArray tensor)
165170
if (!tensor.isNumpyFormat())
166171
throw new IllegalArgumentException("Shared memory arrays must be saved in numpy format.");
167172
ByteBuffer buff = tensor.getDataBufferNoHeader();
168-
DoubleBuffer doubleBuff = buff.asDoubleBuffer();
169-
DoubleDataBuffer dataBuffer = RawDataBufferFactory.create(doubleBuff.array(), false);
173+
double[] flat = new double[buff.capacity() / 8];
174+
buff.asDoubleBuffer().get(flat);
175+
buff.rewind();
176+
DoubleDataBuffer dataBuffer = RawDataBufferFactory.create(flat, false);
170177
Tensor<TFloat64> ndarray = TFloat64.tensorOf(Shape.of(ogShape), dataBuffer);
171178
return ndarray;
172179
}

0 commit comments

Comments
 (0)