Skip to content

Commit d7febcb

Browse files
committed
increase robustness creating pytorch tensors
1 parent 0b09e00 commit d7febcb

File tree

1 file changed

+49
-11
lines changed

1 file changed

+49
-11
lines changed

Diff for: src/main/java/io/bioimage/modelrunner/pytorch/javacpp/tensor/JavaCPPTensorBuilder.java

+49-11
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,22 @@
2020
*/
2121
package io.bioimage.modelrunner.pytorch.javacpp.tensor;
2222

23+
import java.nio.ByteBuffer;
24+
import java.util.Arrays;
25+
2326
import io.bioimage.modelrunner.tensor.Tensor;
2427
import io.bioimage.modelrunner.tensor.Utils;
28+
import net.imglib2.Cursor;
2529
import net.imglib2.RandomAccessibleInterval;
2630
import net.imglib2.blocks.PrimitiveBlocks;
2731
import net.imglib2.type.Type;
2832
import net.imglib2.type.numeric.integer.ByteType;
2933
import net.imglib2.type.numeric.integer.IntType;
34+
import net.imglib2.type.numeric.integer.UnsignedByteType;
3035
import net.imglib2.type.numeric.real.DoubleType;
3136
import net.imglib2.type.numeric.real.FloatType;
3237
import net.imglib2.util.Util;
38+
import net.imglib2.view.Views;
3339

3440
/**
3541
* Class that manages the creation of JAvaCPP Pytorch tensors
@@ -101,16 +107,24 @@ public static <T extends Type<T>> org.bytedeco.pytorch.Tensor build(RandomAccess
101107
private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleInterval<ByteType> tensor)
102108
{
103109
long[] ogShape = tensor.dimensionsAsLongArray();
110+
if (CommonUtils.int32Overflows(ogShape))
111+
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
112+
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
104113
tensor = Utils.transpose(tensor);
105-
PrimitiveBlocks< ByteType > blocks = PrimitiveBlocks.of( tensor );
106114
long[] tensorShape = tensor.dimensionsAsLongArray();
107115
int size = 1;
108116
for (long ll : tensorShape) size *= ll;
109117
final byte[] flatArr = new byte[size];
110118
int[] sArr = new int[tensorShape.length];
111119
for (int i = 0; i < sArr.length; i ++)
112120
sArr[i] = (int) tensorShape[i];
113-
blocks.copy( tensor.minAsLongArray(), flatArr, sArr );
121+
122+
Cursor<ByteType> cursor = Views.flatIterable(tensor).cursor();
123+
int i = 0;
124+
while (cursor.hasNext()) {
125+
cursor.fwd();
126+
flatArr[i ++] = cursor.get().get();
127+
}
114128
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape);
115129
return ndarray;
116130
}
@@ -126,17 +140,25 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleI
126140
private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleInterval<IntType> tensor)
127141
{
128142
long[] ogShape = tensor.dimensionsAsLongArray();
143+
if (CommonUtils.int32Overflows(ogShape))
144+
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
145+
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
129146
tensor = Utils.transpose(tensor);
130-
PrimitiveBlocks< IntType > blocks = PrimitiveBlocks.of( tensor );
131147
long[] tensorShape = tensor.dimensionsAsLongArray();
132148
int size = 1;
133149
for (long ll : tensorShape) size *= ll;
134150
final int[] flatArr = new int[size];
135151
int[] sArr = new int[tensorShape.length];
136152
for (int i = 0; i < sArr.length; i ++)
137153
sArr[i] = (int) tensorShape[i];
138-
blocks.copy( tensor.minAsLongArray(), flatArr, sArr );
139-
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape);
154+
155+
Cursor<IntType> cursor = Views.flatIterable(tensor).cursor();
156+
int i = 0;
157+
while (cursor.hasNext()) {
158+
cursor.fwd();
159+
flatArr[i ++] = cursor.get().get();
160+
}
161+
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape);
140162
return ndarray;
141163
}
142164

@@ -151,17 +173,25 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleIn
151173
private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessibleInterval<FloatType> tensor)
152174
{
153175
long[] ogShape = tensor.dimensionsAsLongArray();
176+
if (CommonUtils.int32Overflows(ogShape))
177+
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
178+
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
154179
tensor = Utils.transpose(tensor);
155-
PrimitiveBlocks< FloatType > blocks = PrimitiveBlocks.of( tensor );
156180
long[] tensorShape = tensor.dimensionsAsLongArray();
157181
int size = 1;
158182
for (long ll : tensorShape) size *= ll;
159183
final float[] flatArr = new float[size];
160184
int[] sArr = new int[tensorShape.length];
161185
for (int i = 0; i < sArr.length; i ++)
162186
sArr[i] = (int) tensorShape[i];
163-
blocks.copy( tensor.minAsLongArray(), flatArr, sArr );
164-
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape);
187+
188+
Cursor<FloatType> cursor = Views.flatIterable(tensor).cursor();
189+
int i = 0;
190+
while (cursor.hasNext()) {
191+
cursor.fwd();
192+
flatArr[i ++] = cursor.get().get();
193+
}
194+
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape);
165195
return ndarray;
166196
}
167197

@@ -176,17 +206,25 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessible
176206
private static org.bytedeco.pytorch.Tensor buildFromTensorDouble(RandomAccessibleInterval<DoubleType> tensor)
177207
{
178208
long[] ogShape = tensor.dimensionsAsLongArray();
209+
if (CommonUtils.int32Overflows(ogShape))
210+
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
211+
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
179212
tensor = Utils.transpose(tensor);
180-
PrimitiveBlocks< DoubleType > blocks = PrimitiveBlocks.of( tensor );
181213
long[] tensorShape = tensor.dimensionsAsLongArray();
182214
int size = 1;
183215
for (long ll : tensorShape) size *= ll;
184216
final double[] flatArr = new double[size];
185217
int[] sArr = new int[tensorShape.length];
186218
for (int i = 0; i < sArr.length; i ++)
187219
sArr[i] = (int) tensorShape[i];
188-
blocks.copy( tensor.minAsLongArray(), flatArr, sArr );
189-
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape);
220+
221+
Cursor<DoubleType> cursor = Views.flatIterable(tensor).cursor();
222+
int i = 0;
223+
while (cursor.hasNext()) {
224+
cursor.fwd();
225+
flatArr[i ++] = cursor.get().get();
226+
}
227+
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape);
190228
return ndarray;
191229
}
192230
}

0 commit comments

Comments
 (0)