21
21
package io .bioimage .modelrunner .tensorflow .v2 .api050 .tensor ;
22
22
23
23
import io .bioimage .modelrunner .tensor .Utils ;
24
-
24
+ import io . bioimage . modelrunner . utils . CommonUtils ;
25
25
import net .imglib2 .RandomAccessibleInterval ;
26
26
import net .imglib2 .img .array .ArrayImgs ;
27
27
import net .imglib2 .type .Type ;
31
31
import net .imglib2 .type .numeric .real .DoubleType ;
32
32
import net .imglib2 .type .numeric .real .FloatType ;
33
33
34
+ import java .util .Arrays ;
35
+
34
36
import org .tensorflow .Tensor ;
35
37
import org .tensorflow .types .TFloat32 ;
36
38
import org .tensorflow .types .TFloat64 ;
@@ -103,6 +105,9 @@ else if (tensor instanceof TInt64)
103
105
private static RandomAccessibleInterval <UnsignedByteType > buildFromTensorUByte (TUint8 tensor )
104
106
{
105
107
long [] arrayShape = tensor .shape ().asArray ();
108
+ if (CommonUtils .int32Overflows (arrayShape , 1 ))
109
+ throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
110
+ + " is too big. Max number of elements per ubyte output tensor supported: " + Integer .MAX_VALUE / 1 );
106
111
long [] tensorShape = new long [arrayShape .length ];
107
112
for (int i = 0 ; i < arrayShape .length ; i ++) tensorShape [i ] = arrayShape [arrayShape .length - 1 - i ];
108
113
int totalSize = 1 ;
@@ -123,6 +128,9 @@ private static RandomAccessibleInterval<UnsignedByteType> buildFromTensorUByte(T
123
128
private static RandomAccessibleInterval <IntType > buildFromTensorInt (TInt32 tensor )
124
129
{
125
130
long [] arrayShape = tensor .shape ().asArray ();
131
+ if (CommonUtils .int32Overflows (arrayShape , 4 ))
132
+ throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
133
+ + " is too big. Max number of elements per int output tensor supported: " + Integer .MAX_VALUE / 4 );
126
134
long [] tensorShape = new long [arrayShape .length ];
127
135
for (int i = 0 ; i < arrayShape .length ; i ++) tensorShape [i ] = arrayShape [arrayShape .length - 1 - i ];
128
136
int totalSize = 1 ;
@@ -143,6 +151,9 @@ private static RandomAccessibleInterval<IntType> buildFromTensorInt(TInt32 tenso
143
151
private static RandomAccessibleInterval <FloatType > buildFromTensorFloat (TFloat32 tensor )
144
152
{
145
153
long [] arrayShape = tensor .shape ().asArray ();
154
+ if (CommonUtils .int32Overflows (arrayShape , 4 ))
155
+ throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
156
+ + " is too big. Max number of elements per float output tensor supported: " + Integer .MAX_VALUE / 4 );
146
157
long [] tensorShape = new long [arrayShape .length ];
147
158
for (int i = 0 ; i < arrayShape .length ; i ++) tensorShape [i ] = arrayShape [arrayShape .length - 1 - i ];
148
159
int totalSize = 1 ;
@@ -163,6 +174,9 @@ private static RandomAccessibleInterval<FloatType> buildFromTensorFloat(TFloat32
163
174
private static RandomAccessibleInterval <DoubleType > buildFromTensorDouble (TFloat64 tensor )
164
175
{
165
176
long [] arrayShape = tensor .shape ().asArray ();
177
+ if (CommonUtils .int32Overflows (arrayShape , 8 ))
178
+ throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
179
+ + " is too big. Max number of elements per double output tensor supported: " + Integer .MAX_VALUE / 8 );
166
180
long [] tensorShape = new long [arrayShape .length ];
167
181
for (int i = 0 ; i < arrayShape .length ; i ++) tensorShape [i ] = arrayShape [arrayShape .length - 1 - i ];
168
182
int totalSize = 1 ;
@@ -183,6 +197,9 @@ private static RandomAccessibleInterval<DoubleType> buildFromTensorDouble(TFloat
183
197
private static RandomAccessibleInterval <LongType > buildFromTensorLong (TInt64 tensor )
184
198
{
185
199
long [] arrayShape = tensor .shape ().asArray ();
200
+ if (CommonUtils .int32Overflows (arrayShape , 8 ))
201
+ throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
202
+ + " is too big. Max number of elements per long output tensor supported: " + Integer .MAX_VALUE / 8 );
186
203
long [] tensorShape = new long [arrayShape .length ];
187
204
for (int i = 0 ; i < arrayShape .length ; i ++) tensorShape [i ] = arrayShape [arrayShape .length - 1 - i ];
188
205
int totalSize = 1 ;
0 commit comments