Skip to content

Commit 3251433

Browse files
committed
increase robsutness
1 parent 193d3a1 commit 3251433

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

src/main/java/io/bioimage/modelrunner/pytorch/tensor/ImgLib2Builder.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,12 @@
3434
import net.imglib2.type.numeric.real.DoubleType;
3535
import net.imglib2.type.numeric.real.FloatType;
3636
import net.imglib2.util.Cast;
37+
38+
import java.util.Arrays;
39+
3740
import ai.djl.ndarray.NDArray;
3841
import io.bioimage.modelrunner.tensor.Utils;
42+
import io.bioimage.modelrunner.utils.CommonUtils;
3943

4044
/**
4145
* A {@link RandomAccessibleInterval} builder for Pytorch {@link ai.djl.ndarray.NDArray} objects.
@@ -89,6 +93,9 @@ public static < T extends RealType< T > & NativeType< T > > RandomAccessibleInte
8993
*/
9094
private static RandomAccessibleInterval<UnsignedByteType> buildFromTensorUByte(NDArray tensor) {
9195
long[] arrayShape = tensor.getShape().getShape();
96+
if (CommonUtils.int32Overflows(arrayShape, 1))
97+
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
98+
+ " is too big. Max number of elements per output tensor supported: " + Integer.MAX_VALUE);
9299
long[] tensorShape = new long[arrayShape.length];
93100
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
94101
byte[] flatArr = tensor.toByteArray();
@@ -105,6 +112,9 @@ private static RandomAccessibleInterval<UnsignedByteType> buildFromTensorUByte(N
105112
*/
106113
private static RandomAccessibleInterval<ByteType> buildFromTensorByte(NDArray tensor) {
107114
long[] arrayShape = tensor.getShape().getShape();
115+
if (CommonUtils.int32Overflows(arrayShape, 1))
116+
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
117+
+ " is too big. Max number of elements per output tensor supported: " + Integer.MAX_VALUE);
108118
long[] tensorShape = new long[arrayShape.length];
109119
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
110120
byte[] flatArr = tensor.toByteArray();
@@ -121,6 +131,9 @@ private static RandomAccessibleInterval<ByteType> buildFromTensorByte(NDArray te
121131
*/
122132
private static RandomAccessibleInterval<IntType> buildFromTensorInt(NDArray tensor) {
123133
long[] arrayShape = tensor.getShape().getShape();
134+
if (CommonUtils.int32Overflows(arrayShape, 1))
135+
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
136+
+ " is too big. Max number of elements per output tensor supported: " + Integer.MAX_VALUE);
124137
long[] tensorShape = new long[arrayShape.length];
125138
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
126139
int[] flatArr = tensor.toIntArray();
@@ -137,6 +150,9 @@ private static RandomAccessibleInterval<IntType> buildFromTensorInt(NDArray tens
137150
*/
138151
private static RandomAccessibleInterval<FloatType> buildFromTensorFloat(NDArray tensor) {
139152
long[] arrayShape = tensor.getShape().getShape();
153+
if (CommonUtils.int32Overflows(arrayShape, 1))
154+
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
155+
+ " is too big. Max number of elements per output tensor supported: " + Integer.MAX_VALUE);
140156
long[] tensorShape = new long[arrayShape.length];
141157
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
142158
float[] flatArr = tensor.toFloatArray();
@@ -153,6 +169,9 @@ private static RandomAccessibleInterval<FloatType> buildFromTensorFloat(NDArray
153169
*/
154170
private static RandomAccessibleInterval<DoubleType> buildFromTensorDouble(NDArray tensor) {
155171
long[] arrayShape = tensor.getShape().getShape();
172+
if (CommonUtils.int32Overflows(arrayShape, 1))
173+
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
174+
+ " is too big. Max number of elements per output tensor supported: " + Integer.MAX_VALUE);
156175
long[] tensorShape = new long[arrayShape.length];
157176
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
158177
double[] flatArr = tensor.toDoubleArray();
@@ -169,6 +188,9 @@ private static RandomAccessibleInterval<DoubleType> buildFromTensorDouble(NDArra
169188
*/
170189
private static RandomAccessibleInterval<LongType> buildFromTensorLong(NDArray tensor) {
171190
long[] arrayShape = tensor.getShape().getShape();
191+
if (CommonUtils.int32Overflows(arrayShape, 1))
192+
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
193+
+ " is too big. Max number of elements per output tensor supported: " + Integer.MAX_VALUE);
172194
long[] tensorShape = new long[arrayShape.length];
173195
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
174196
long[] flatArr = tensor.toLongArray();

0 commit comments

Comments
 (0)