Skip to content

Commit 5893e56

Browse files
committed
improve robustness
1 parent 8abcf8b commit 5893e56

File tree

2 files changed

+28
-11
lines changed

2 files changed

+28
-11
lines changed

src/main/java/io/bioimage/modelrunner/tensorflow/v2/api050/tensor/ImgLib2Builder.java

+18-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
package io.bioimage.modelrunner.tensorflow.v2.api050.tensor;
2222

2323
import io.bioimage.modelrunner.tensor.Utils;
24-
24+
import io.bioimage.modelrunner.utils.CommonUtils;
2525
import net.imglib2.RandomAccessibleInterval;
2626
import net.imglib2.img.array.ArrayImgs;
2727
import net.imglib2.type.Type;
@@ -31,6 +31,8 @@
3131
import net.imglib2.type.numeric.real.DoubleType;
3232
import net.imglib2.type.numeric.real.FloatType;
3333

34+
import java.util.Arrays;
35+
3436
import org.tensorflow.Tensor;
3537
import org.tensorflow.types.TFloat32;
3638
import org.tensorflow.types.TFloat64;
@@ -103,6 +105,9 @@ else if (tensor instanceof TInt64)
103105
private static RandomAccessibleInterval<UnsignedByteType> buildFromTensorUByte(TUint8 tensor)
104106
{
105107
long[] arrayShape = tensor.shape().asArray();
108+
if (CommonUtils.int32Overflows(arrayShape, 1))
109+
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
110+
+ " is too big. Max number of elements per ubyte output tensor supported: " + Integer.MAX_VALUE / 1);
106111
long[] tensorShape = new long[arrayShape.length];
107112
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
108113
int totalSize = 1;
@@ -123,6 +128,9 @@ private static RandomAccessibleInterval<UnsignedByteType> buildFromTensorUByte(T
123128
private static RandomAccessibleInterval<IntType> buildFromTensorInt(TInt32 tensor)
124129
{
125130
long[] arrayShape = tensor.shape().asArray();
131+
if (CommonUtils.int32Overflows(arrayShape, 4))
132+
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
133+
+ " is too big. Max number of elements per int output tensor supported: " + Integer.MAX_VALUE / 4);
126134
long[] tensorShape = new long[arrayShape.length];
127135
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
128136
int totalSize = 1;
@@ -143,6 +151,9 @@ private static RandomAccessibleInterval<IntType> buildFromTensorInt(TInt32 tenso
143151
private static RandomAccessibleInterval<FloatType> buildFromTensorFloat(TFloat32 tensor)
144152
{
145153
long[] arrayShape = tensor.shape().asArray();
154+
if (CommonUtils.int32Overflows(arrayShape, 4))
155+
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
156+
+ " is too big. Max number of elements per float output tensor supported: " + Integer.MAX_VALUE / 4);
146157
long[] tensorShape = new long[arrayShape.length];
147158
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
148159
int totalSize = 1;
@@ -163,6 +174,9 @@ private static RandomAccessibleInterval<FloatType> buildFromTensorFloat(TFloat32
163174
private static RandomAccessibleInterval<DoubleType> buildFromTensorDouble(TFloat64 tensor)
164175
{
165176
long[] arrayShape = tensor.shape().asArray();
177+
if (CommonUtils.int32Overflows(arrayShape, 8))
178+
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
179+
+ " is too big. Max number of elements per double output tensor supported: " + Integer.MAX_VALUE / 8);
166180
long[] tensorShape = new long[arrayShape.length];
167181
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
168182
int totalSize = 1;
@@ -183,6 +197,9 @@ private static RandomAccessibleInterval<DoubleType> buildFromTensorDouble(TFloat
183197
private static RandomAccessibleInterval<LongType> buildFromTensorLong(TInt64 tensor)
184198
{
185199
long[] arrayShape = tensor.shape().asArray();
200+
if (CommonUtils.int32Overflows(arrayShape, 8))
201+
throw new IllegalArgumentException("Model output tensor with shape " + Arrays.toString(arrayShape)
202+
+ " is too big. Max number of elements per long output tensor supported: " + Integer.MAX_VALUE / 8);
186203
long[] tensorShape = new long[arrayShape.length];
187204
for (int i = 0; i < arrayShape.length; i ++) tensorShape[i] = arrayShape[arrayShape.length - 1 - i];
188205
int totalSize = 1;

src/main/java/io/bioimage/modelrunner/tensorflow/v2/api050/tensor/TensorBuilder.java

+10-10
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,9 @@ public static TUint8 buildUByte(RandomAccessibleInterval<UnsignedByteType> tenso
134134
throws IllegalArgumentException
135135
{
136136
long[] ogShape = tensor.dimensionsAsLongArray();
137-
if (CommonUtils.int32Overflows(ogShape))
137+
if (CommonUtils.int32Overflows(ogShape, 1))
138138
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
139-
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
139+
+ " is too big. Max number of elements per ubyte tensor supported: " + Integer.MAX_VALUE);
140140
tensor = Utils.transpose(tensor);
141141
long[] tensorShape = tensor.dimensionsAsLongArray();
142142
int size = 1;
@@ -171,9 +171,9 @@ public static TInt32 buildInt(RandomAccessibleInterval<IntType> tensor)
171171
throws IllegalArgumentException
172172
{
173173
long[] ogShape = tensor.dimensionsAsLongArray();
174-
if (CommonUtils.int32Overflows(ogShape))
174+
if (CommonUtils.int32Overflows(ogShape, 4))
175175
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
176-
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
176+
+ " is too big. Max number of elements per int tensor supported: " + Integer.MAX_VALUE / 4);
177177
tensor = Utils.transpose(tensor);
178178
long[] tensorShape = tensor.dimensionsAsLongArray();
179179
int size = 1;
@@ -209,9 +209,9 @@ private static TInt64 buildLong(RandomAccessibleInterval<LongType> tensor)
209209
throws IllegalArgumentException
210210
{
211211
long[] ogShape = tensor.dimensionsAsLongArray();
212-
if (CommonUtils.int32Overflows(ogShape))
212+
if (CommonUtils.int32Overflows(ogShape, 8))
213213
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
214-
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
214+
+ " is too big. Max number of elements per long tensor supported: " + Integer.MAX_VALUE / 8);
215215
tensor = Utils.transpose(tensor);
216216
long[] tensorShape = tensor.dimensionsAsLongArray();
217217
int size = 1;
@@ -248,9 +248,9 @@ public static TFloat32 buildFloat(
248248
throws IllegalArgumentException
249249
{
250250
long[] ogShape = tensor.dimensionsAsLongArray();
251-
if (CommonUtils.int32Overflows(ogShape))
251+
if (CommonUtils.int32Overflows(ogShape, 4))
252252
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
253-
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
253+
+ " is too big. Max number of elements per float tensor supported: " + Integer.MAX_VALUE / 4);
254254
tensor = Utils.transpose(tensor);
255255
long[] tensorShape = tensor.dimensionsAsLongArray();
256256
int size = 1;
@@ -286,9 +286,9 @@ private static TFloat64 buildDouble(
286286
throws IllegalArgumentException
287287
{
288288
long[] ogShape = tensor.dimensionsAsLongArray();
289-
if (CommonUtils.int32Overflows(ogShape))
289+
if (CommonUtils.int32Overflows(ogShape, 8))
290290
throw new IllegalArgumentException("Provided tensor with shape " + Arrays.toString(ogShape)
291-
+ " is too big. Max number of elements per tensor supported: " + Integer.MAX_VALUE);
291+
+ " is too big. Max number of elements per double tensor supported: " + Integer.MAX_VALUE / 8);
292292
tensor = Utils.transpose(tensor);
293293
long[] tensorShape = tensor.dimensionsAsLongArray();
294294
int size = 1;

0 commit comments

Comments
 (0)