Skip to content

Commit f0c2e91

Browse files
committed
improve robustness
1 parent 57ec6bf commit f0c2e91

File tree

2 files changed

+28
-11
lines changed

2 files changed

+28
-11
lines changed

src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/tensor/ImgLib2Builder.java

+18-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323

2424
import io.bioimage.modelrunner.tensor.Utils;
25-
25+
import io.bioimage.modelrunner.utils.CommonUtils;
2626
import net.imglib2.RandomAccessibleInterval;
2727
import net.imglib2.img.array.ArrayImgs;
2828
import net.imglib2.type.Type;
@@ -32,6 +32,8 @@
3232
import net.imglib2.type.numeric.real.DoubleType;
3333
import net.imglib2.type.numeric.real.FloatType;
3434

35+
import java.util.Arrays;
36+
3537
import org.tensorflow.Tensor;
3638
import org.tensorflow.types.TFloat32;
3739
import org.tensorflow.types.TFloat64;
@@ -95,6 +97,9 @@ public static <T extends Type<T>> RandomAccessibleInterval<T> build(Tensor<? ext
9597
private static RandomAccessibleInterval<UnsignedByteType> buildFromTensorUByte(Tensor<TUint8> tensor)
9698
{
9799
long[] arrayShape = tensor.shape().asArray();
100+
if (CommonUtils.int32Overflows(arrayShape, 1))
101+
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
102+
+ " is too big. Max number of elements per double ubyte tensor supported: " + Integer.MAX_VALUE / 1);
98103
long[] tensorShape = new long[arrayShape.length];
99104
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
100105
int totalSize = 1;
@@ -115,6 +120,9 @@ private static RandomAccessibleInterval<UnsignedByteType> buildFromTensorUByte(T
115120
private static RandomAccessibleInterval<IntType> buildFromTensorInt(Tensor<TInt32> tensor)
116121
{
117122
long[] arrayShape = tensor.shape().asArray();
123+
if (CommonUtils.int32Overflows(arrayShape, 4))
124+
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
125+
+ " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);
118126
long[] tensorShape = new long[arrayShape.length];
119127
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
120128
int totalSize = 1;
@@ -135,6 +143,9 @@ private static RandomAccessibleInterval<IntType> buildFromTensorInt(Tensor<TInt3
135143
private static RandomAccessibleInterval<FloatType> buildFromTensorFloat(Tensor<TFloat32> tensor)
136144
{
137145
long[] arrayShape = tensor.shape().asArray();
146+
if (CommonUtils.int32Overflows(arrayShape, 4))
147+
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
148+
+ " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4);
138149
long[] tensorShape = new long[arrayShape.length];
139150
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
140151
int totalSize = 1;
@@ -155,6 +166,9 @@ private static RandomAccessibleInterval<FloatType> buildFromTensorFloat(Tensor<T
155166
private static RandomAccessibleInterval<DoubleType> buildFromTensorDouble(Tensor<TFloat64> tensor)
156167
{
157168
long[] arrayShape = tensor.shape().asArray();
169+
if (CommonUtils.int32Overflows(arrayShape, 8))
170+
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
171+
+ " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8);
158172
long[] tensorShape = new long[arrayShape.length];
159173
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
160174
int totalSize = 1;
@@ -175,6 +189,9 @@ private static RandomAccessibleInterval<DoubleType> buildFromTensorDouble(Tensor
175189
private static RandomAccessibleInterval<LongType> buildFromTensorLong(Tensor<TInt64> tensor)
176190
{
177191
long[] arrayShape = tensor.shape().asArray();
192+
if (CommonUtils.int32Overflows(arrayShape, 8))
193+
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
194+
+ " is too big. Max number of elements per long output tensor supported: " + Integer.MAX_VALUE / 8);
178195
long[] tensorShape = new long[arrayShape.length];
179196
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
180197
int totalSize = 1;

src/main/java/io/bioimage/modelrunner/tensorflow/v2/api020/tensor/TensorBuilder.java

+10-10
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,9 @@ private static Tensor<TUint8> buildUByte(
135135
throws IllegalArgumentException
136136
{
137137
long[] ogShape = tensor.dimensionsAsLongArray();
138-
if (CommonUtils.int32Overflows(ogShape))
138+
if (CommonUtils.int32Overflows(ogShape, 1))
139139
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
140-
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
140+
+ " is too big. Max number of elements per ubyte tensor supported: " + Integer.MAX_VALUE);
141141
tensor = Utils.transpose(tensor);
142142
long[] tensorShape = tensor.dimensionsAsLongArray();
143143
int size = 1;
@@ -172,9 +172,9 @@ private static Tensor<TInt32> buildInt(
172172
RandomAccessibleInterval<IntType> tensor) throws IllegalArgumentException
173173
{
174174
long[] ogShape = tensor.dimensionsAsLongArray();
175-
if (CommonUtils.int32Overflows(ogShape))
175+
if (CommonUtils.int32Overflows(ogShape, 4))
176176
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
177-
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
177+
+ " is too big. Max number of elements per int tensor supported: " + Integer.MAX_VALUE / 4);
178178
tensor = Utils.transpose(tensor);
179179
long[] tensorShape = tensor.dimensionsAsLongArray();
180180
int size = 1;
@@ -210,9 +210,9 @@ private static Tensor<TInt64> buildLong(
210210
throws IllegalArgumentException
211211
{
212212
long[] ogShape = tensor.dimensionsAsLongArray();
213-
if (CommonUtils.int32Overflows(ogShape))
213+
if (CommonUtils.int32Overflows(ogShape, 8))
214214
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
215-
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
215+
+ " is too big. Max number of elements per long tensor supported: " + Integer.MAX_VALUE / 8);
216216
tensor = Utils.transpose(tensor);
217217
long[] tensorShape = tensor.dimensionsAsLongArray();
218218
int size = 1;
@@ -248,9 +248,9 @@ private static Tensor<TFloat32> buildFloat(
248248
throws IllegalArgumentException
249249
{
250250
long[] ogShape = tensor.dimensionsAsLongArray();
251-
if (CommonUtils.int32Overflows(ogShape))
251+
if (CommonUtils.int32Overflows(ogShape, 4))
252252
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
253-
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
253+
+ " is too big. Max number of elements per float tensor supported: " + Integer.MAX_VALUE / 4);
254254
tensor = Utils.transpose(tensor);
255255
long[] tensorShape = tensor.dimensionsAsLongArray();
256256
int size = 1;
@@ -286,9 +286,9 @@ private static Tensor<TFloat64> buildDouble(
286286
throws IllegalArgumentException
287287
{
288288
long[] ogShape = tensor.dimensionsAsLongArray();
289-
if (CommonUtils.int32Overflows(ogShape))
289+
if (CommonUtils.int32Overflows(ogShape, 8))
290290
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
291-
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
291+
+ " is too big. Max number of elements per double tensor supported: " + Integer.MAX_VALUE / 8);
292292
tensor = Utils.transpose(tensor);
293293
long[] tensorShape = tensor.dimensionsAsLongArray();
294294
int size = 1;

0 commit comments

Comments
 (0)