2121package io .bioimage .modelrunner .tensorflow .v2 .api050 .tensor ;
2222
2323import io .bioimage .modelrunner .tensor .Utils ;
24-
24+ import io . bioimage . modelrunner . utils . CommonUtils ;
2525import net .imglib2 .RandomAccessibleInterval ;
2626import net .imglib2 .img .array .ArrayImgs ;
2727import net .imglib2 .type .Type ;
3131import net .imglib2 .type .numeric .real .DoubleType ;
3232import net .imglib2 .type .numeric .real .FloatType ;
3333
34+ import java .util .Arrays ;
35+
3436import org .tensorflow .Tensor ;
3537import org .tensorflow .types .TFloat32 ;
3638import org .tensorflow .types .TFloat64 ;
@@ -103,6 +105,9 @@ else if (tensor instanceof TInt64)
103105 private static RandomAccessibleInterval <UnsignedByteType > buildFromTensorUByte (TUint8 tensor )
104106 {
105107 long [] arrayShape = tensor .shape ().asArray ();
108+ if (CommonUtils .int32Overflows (arrayShape , 1 ))
109+ throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
110+ + " is too big. Max number of elements per ubyte output tensor supported: " + Integer .MAX_VALUE / 1 );
106111 long [] tensorShape = new long [arrayShape .length ];
107112 for (int i = 0 ; i < arrayShape .length ; i ++) tensorShape [i ] = arrayShape [arrayShape .length - 1 - i ];
108113 int totalSize = 1 ;
@@ -123,6 +128,9 @@ private static RandomAccessibleInterval<UnsignedByteType> buildFromTensorUByte(T
123128 private static RandomAccessibleInterval <IntType > buildFromTensorInt (TInt32 tensor )
124129 {
125130 long [] arrayShape = tensor .shape ().asArray ();
131+ if (CommonUtils .int32Overflows (arrayShape , 4 ))
132+ throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
133+ + " is too big. Max number of elements per int output tensor supported: " + Integer .MAX_VALUE / 4 );
126134 long [] tensorShape = new long [arrayShape .length ];
127135 for (int i = 0 ; i < arrayShape .length ; i ++) tensorShape [i ] = arrayShape [arrayShape .length - 1 - i ];
128136 int totalSize = 1 ;
@@ -143,6 +151,9 @@ private static RandomAccessibleInterval<IntType> buildFromTensorInt(TInt32 tenso
143151 private static RandomAccessibleInterval <FloatType > buildFromTensorFloat (TFloat32 tensor )
144152 {
145153 long [] arrayShape = tensor .shape ().asArray ();
154+ if (CommonUtils .int32Overflows (arrayShape , 4 ))
155+ throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
156+ + " is too big. Max number of elements per float output tensor supported: " + Integer .MAX_VALUE / 4 );
146157 long [] tensorShape = new long [arrayShape .length ];
147158 for (int i = 0 ; i < arrayShape .length ; i ++) tensorShape [i ] = arrayShape [arrayShape .length - 1 - i ];
148159 int totalSize = 1 ;
@@ -163,6 +174,9 @@ private static RandomAccessibleInterval<FloatType> buildFromTensorFloat(TFloat32
163174 private static RandomAccessibleInterval <DoubleType > buildFromTensorDouble (TFloat64 tensor )
164175 {
165176 long [] arrayShape = tensor .shape ().asArray ();
177+ if (CommonUtils .int32Overflows (arrayShape , 8 ))
178+ throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
179+ + " is too big. Max number of elements per double output tensor supported: " + Integer .MAX_VALUE / 8 );
166180 long [] tensorShape = new long [arrayShape .length ];
167181 for (int i = 0 ; i < arrayShape .length ; i ++) tensorShape [i ] = arrayShape [arrayShape .length - 1 - i ];
168182 int totalSize = 1 ;
@@ -183,6 +197,9 @@ private static RandomAccessibleInterval<DoubleType> buildFromTensorDouble(TFloat
183197 private static RandomAccessibleInterval <LongType > buildFromTensorLong (TInt64 tensor )
184198 {
185199 long [] arrayShape = tensor .shape ().asArray ();
200+ if (CommonUtils .int32Overflows (arrayShape , 8 ))
201+ throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
202+ + " is too big. Max number of elements per long output tensor supported: " + Integer .MAX_VALUE / 8 );
186203 long [] tensorShape = new long [arrayShape .length ];
187204 for (int i = 0 ; i < arrayShape .length ; i ++) tensorShape [i ] = arrayShape [arrayShape .length - 1 - i ];
188205 int totalSize = 1 ;
0 commit comments