Skip to content

Commit 4c41326

Browse files
committed
correct erroneus creation of onnx tensors
1 parent e86afab commit 4c41326

File tree

1 file changed

+48
-18
lines changed

1 file changed

+48
-18
lines changed

src/main/java/io/bioimage/modelrunner/onnx/tensor/TensorBuilder.java

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@
2828
import java.nio.DoubleBuffer;
2929
import java.nio.FloatBuffer;
3030
import java.nio.IntBuffer;
31+
import java.util.Arrays;
3132

33+
import net.imglib2.Cursor;
3234
import net.imglib2.RandomAccessibleInterval;
3335
import net.imglib2.blocks.PrimitiveBlocks;
3436
import net.imglib2.img.Img;
@@ -38,7 +40,7 @@
3840
import net.imglib2.type.numeric.real.DoubleType;
3941
import net.imglib2.type.numeric.real.FloatType;
4042
import net.imglib2.util.Util;
41-
43+
import net.imglib2.view.Views;
4244
import ai.onnxruntime.OnnxTensor;
4345
import ai.onnxruntime.OrtEnvironment;
4446
import ai.onnxruntime.OrtException;
@@ -122,18 +124,25 @@ public static <T extends Type<T>> OnnxTensor build(RandomAccessibleInterval<T> r
122124
*/
123125
private static OnnxTensor buildByte(RandomAccessibleInterval<ByteType> tensor, OrtEnvironment env) throws OrtException
124126
{
127+
long[] ogShape = tensor.dimensionsAsLongArray();
128+
if (CommonUtils.int32Overflows(ogShape))
129+
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
130+
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
125131
tensor = Utils.transpose(tensor);
126-
PrimitiveBlocks< ByteType > blocks = PrimitiveBlocks.of( tensor );
127132
long[] tensorShape = tensor.dimensionsAsLongArray();
128-
if (CommonUtils.int32Overflows(tensorShape))
129-
throw new IllegalArgumentException("Tensor is too big to handle. Max number of elements allowed in a tensor: " + Integer.MAX_VALUE);
130133
int size = 1;
131134
for (long ll : tensorShape) size *= ll;
132135
final byte[] flatArr = new byte[size];
133136
int[] sArr = new int[tensorShape.length];
134137
for (int i = 0; i < sArr.length; i ++)
135138
sArr[i] = (int) tensorShape[i];
136-
blocks.copy( tensor.minAsLongArray(), flatArr, sArr );
139+
140+
Cursor<ByteType> cursor = Views.flatIterable(tensor).cursor();
141+
int i = 0;
142+
while (cursor.hasNext()) {
143+
cursor.fwd();
144+
flatArr[i ++] = cursor.get().getByte();
145+
}
137146
ByteBuffer buff = ByteBuffer.wrap(flatArr);
138147
OnnxTensor ndarray = OnnxTensor.createTensor(env, buff, tensorShape);
139148
return ndarray;
@@ -153,18 +162,25 @@ private static OnnxTensor buildByte(RandomAccessibleInterval<ByteType> tensor, O
153162
*/
154163
private static OnnxTensor buildInt(RandomAccessibleInterval<IntType> tensor, OrtEnvironment env) throws OrtException
155164
{
165+
long[] ogShape = tensor.dimensionsAsLongArray();
166+
if (CommonUtils.int32Overflows(ogShape))
167+
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
168+
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
156169
tensor = Utils.transpose(tensor);
157-
PrimitiveBlocks< IntType > blocks = PrimitiveBlocks.of( tensor );
158170
long[] tensorShape = tensor.dimensionsAsLongArray();
159-
if (CommonUtils.int32Overflows(tensorShape))
160-
throw new IllegalArgumentException("Tensor is too big to handle. Max number of elements allowed in a tensor: " + Integer.MAX_VALUE);
161171
int size = 1;
162172
for (long ll : tensorShape) size *= ll;
163173
final int[] flatArr = new int[size];
164174
int[] sArr = new int[tensorShape.length];
165175
for (int i = 0; i < sArr.length; i ++)
166176
sArr[i] = (int) tensorShape[i];
167-
blocks.copy( tensor.minAsLongArray(), flatArr, sArr );
177+
178+
Cursor<IntType> cursor = Views.flatIterable(tensor).cursor();
179+
int i = 0;
180+
while (cursor.hasNext()) {
181+
cursor.fwd();
182+
flatArr[i ++] = cursor.get().get();
183+
}
168184
IntBuffer buff = IntBuffer.wrap(flatArr);
169185
OnnxTensor ndarray = OnnxTensor.createTensor(env, buff, tensorShape);
170186
return ndarray;
@@ -184,20 +200,27 @@ private static OnnxTensor buildInt(RandomAccessibleInterval<IntType> tensor, Ort
184200
*/
185201
private static OnnxTensor buildFloat(RandomAccessibleInterval<FloatType> tensor, OrtEnvironment env) throws OrtException
186202
{
203+
long[] ogShape = tensor.dimensionsAsLongArray();
204+
if (CommonUtils.int32Overflows(ogShape))
205+
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
206+
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
187207
tensor = Utils.transpose(tensor);
188-
PrimitiveBlocks< FloatType > blocks = PrimitiveBlocks.of( tensor );
189208
long[] tensorShape = tensor.dimensionsAsLongArray();
190-
if (CommonUtils.int32Overflows(tensorShape))
191-
throw new IllegalArgumentException("Tensor is too big to handle. Max number of elements allowed in a tensor: " + Integer.MAX_VALUE);
192209
int size = 1;
193210
for (long ll : tensorShape) size *= ll;
194211
final float[] flatArr = new float[size];
195212
int[] sArr = new int[tensorShape.length];
196213
for (int i = 0; i < sArr.length; i ++)
197214
sArr[i] = (int) tensorShape[i];
198-
blocks.copy( tensor.minAsLongArray(), flatArr, sArr );
215+
216+
Cursor<FloatType> cursor = Views.flatIterable(tensor).cursor();
217+
int i = 0;
218+
while (cursor.hasNext()) {
219+
cursor.fwd();
220+
flatArr[i ++] = cursor.get().get();
221+
}
199222
FloatBuffer buff = FloatBuffer.wrap(flatArr);
200-
OnnxTensor ndarray = OnnxTensor.createTensor(env, buff, tensorShape);
223+
OnnxTensor ndarray = OnnxTensor.createTensor(env, buff, ogShape);
201224
return ndarray;
202225
}
203226

@@ -215,18 +238,25 @@ private static OnnxTensor buildFloat(RandomAccessibleInterval<FloatType> tensor,
215238
*/
216239
private static OnnxTensor buildDouble(RandomAccessibleInterval<DoubleType> tensor, OrtEnvironment env) throws OrtException
217240
{
241+
long[] ogShape = tensor.dimensionsAsLongArray();
242+
if (CommonUtils.int32Overflows(ogShape))
243+
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
244+
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
218245
tensor = Utils.transpose(tensor);
219-
PrimitiveBlocks< DoubleType > blocks = PrimitiveBlocks.of( tensor );
220246
long[] tensorShape = tensor.dimensionsAsLongArray();
221-
if (CommonUtils.int32Overflows(tensorShape))
222-
throw new IllegalArgumentException("Tensor is too big to handle. Max number of elements allowed in a tensor: " + Integer.MAX_VALUE);
223247
int size = 1;
224248
for (long ll : tensorShape) size *= ll;
225249
final double[] flatArr = new double[size];
226250
int[] sArr = new int[tensorShape.length];
227251
for (int i = 0; i < sArr.length; i ++)
228252
sArr[i] = (int) tensorShape[i];
229-
blocks.copy( tensor.minAsLongArray(), flatArr, sArr );
253+
254+
Cursor<DoubleType> cursor = Views.flatIterable(tensor).cursor();
255+
int i = 0;
256+
while (cursor.hasNext()) {
257+
cursor.fwd();
258+
flatArr[i ++] = cursor.get().get();
259+
}
230260
DoubleBuffer buff = DoubleBuffer.wrap(flatArr);
231261
OnnxTensor ndarray = OnnxTensor.createTensor(env, buff, tensorShape);
232262
return ndarray;

0 commit comments

Comments
 (0)