22
22
23
23
24
24
import io .bioimage .modelrunner .tensor .Utils ;
25
-
25
+ import io . bioimage . modelrunner . utils . CommonUtils ;
26
26
import net .imglib2 .RandomAccessibleInterval ;
27
27
import net .imglib2 .img .array .ArrayImgs ;
28
28
import net .imglib2 .type .Type ;
32
32
import net .imglib2 .type .numeric .real .DoubleType ;
33
33
import net .imglib2 .type .numeric .real .FloatType ;
34
34
35
+ import java .util .Arrays ;
36
+
35
37
import org .tensorflow .Tensor ;
36
38
import org .tensorflow .types .TFloat32 ;
37
39
import org .tensorflow .types .TFloat64 ;
@@ -95,6 +97,9 @@ public static <T extends Type<T>> RandomAccessibleInterval<T> build(Tensor<? ext
95
97
private static RandomAccessibleInterval <UnsignedByteType > buildFromTensorUByte (Tensor <TUint8 > tensor )
96
98
{
97
99
long [] arrayShape = tensor .shape ().asArray ();
100
+ if (CommonUtils .int32Overflows (arrayShape , 1 ))
101
+ throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
102
+ + " is too big. Max number of elements per double ubyte tensor supported: " + Integer .MAX_VALUE / 1 );
98
103
long [] tensorShape = new long [arrayShape .length ];
99
104
for (int i = 0 ; i < arrayShape .length ; i ++) tensorShape [i ] = arrayShape [arrayShape .length - 1 - i ];
100
105
int totalSize = 1 ;
@@ -115,6 +120,9 @@ private static RandomAccessibleInterval<UnsignedByteType> buildFromTensorUByte(T
115
120
private static RandomAccessibleInterval <IntType > buildFromTensorInt (Tensor <TInt32 > tensor )
116
121
{
117
122
long [] arrayShape = tensor .shape ().asArray ();
123
+ if (CommonUtils .int32Overflows (arrayShape , 4 ))
124
+ throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
125
+ + " is too big. Max number of elements per int output tensor supported: " + Integer .MAX_VALUE / 4 );
118
126
long [] tensorShape = new long [arrayShape .length ];
119
127
for (int i = 0 ; i < arrayShape .length ; i ++) tensorShape [i ] = arrayShape [arrayShape .length - 1 - i ];
120
128
int totalSize = 1 ;
@@ -135,6 +143,9 @@ private static RandomAccessibleInterval<IntType> buildFromTensorInt(Tensor<TInt3
135
143
private static RandomAccessibleInterval <FloatType > buildFromTensorFloat (Tensor <TFloat32 > tensor )
136
144
{
137
145
long [] arrayShape = tensor .shape ().asArray ();
146
+ if (CommonUtils .int32Overflows (arrayShape , 4 ))
147
+ throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
148
+ + " is too big. Max number of elements per float output tensor supported: " + Integer .MAX_VALUE / 4 );
138
149
long [] tensorShape = new long [arrayShape .length ];
139
150
for (int i = 0 ; i < arrayShape .length ; i ++) tensorShape [i ] = arrayShape [arrayShape .length - 1 - i ];
140
151
int totalSize = 1 ;
@@ -155,6 +166,9 @@ private static RandomAccessibleInterval<FloatType> buildFromTensorFloat(Tensor<T
155
166
private static RandomAccessibleInterval <DoubleType > buildFromTensorDouble (Tensor <TFloat64 > tensor )
156
167
{
157
168
long [] arrayShape = tensor .shape ().asArray ();
169
+ if (CommonUtils .int32Overflows (arrayShape , 8 ))
170
+ throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
171
+ + " is too big. Max number of elements per double output tensor supported: " + Integer .MAX_VALUE / 8 );
158
172
long [] tensorShape = new long [arrayShape .length ];
159
173
for (int i = 0 ; i < arrayShape .length ; i ++) tensorShape [i ] = arrayShape [arrayShape .length - 1 - i ];
160
174
int totalSize = 1 ;
@@ -175,6 +189,9 @@ private static RandomAccessibleInterval<DoubleType> buildFromTensorDouble(Tensor
175
189
private static RandomAccessibleInterval <LongType > buildFromTensorLong (Tensor <TInt64 > tensor )
176
190
{
177
191
long [] arrayShape = tensor .shape ().asArray ();
192
+ if (CommonUtils .int32Overflows (arrayShape , 8 ))
193
+ throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
194
+ + " is too big. Max number of elements per long output tensor supported: " + Integer .MAX_VALUE / 8 );
178
195
long [] tensorShape = new long [arrayShape .length ];
179
196
for (int i = 0 ; i < arrayShape .length ; i ++) tensorShape [i ] = arrayShape [arrayShape .length - 1 - i ];
180
197
int totalSize = 1 ;
0 commit comments