34
34
import net .imglib2 .type .numeric .real .DoubleType ;
35
35
import net .imglib2 .type .numeric .real .FloatType ;
36
36
import net .imglib2 .util .Cast ;
37
+
38
+ import java .util .Arrays ;
39
+
37
40
import ai .djl .ndarray .NDArray ;
38
41
import io .bioimage .modelrunner .tensor .Utils ;
42
+ import io .bioimage .modelrunner .utils .CommonUtils ;
39
43
40
44
/**
41
45
* A {@link RandomAccessibleInterval} builder for Pytorch {@link ai.djl.ndarray.NDArray} objects.
@@ -89,6 +93,9 @@ public static < T extends RealType< T > & NativeType< T > > RandomAccessibleInte
89
93
*/
90
94
private static RandomAccessibleInterval <UnsignedByteType > buildFromTensorUByte (NDArray tensor ) {
91
95
long [] arrayShape = tensor .getShape ().getShape ();
96
+ if (CommonUtils .int32Overflows (arrayShape , 1 ))
97
+ throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
98
+ + " is too big. Max number of elements per output tensor supported: " + Integer .MAX_VALUE );
92
99
long [] tensorShape = new long [arrayShape .length ];
93
100
for (int i = 0 ; i < arrayShape .length ; i ++) tensorShape [i ] = arrayShape [arrayShape .length - 1 - i ];
94
101
byte [] flatArr = tensor .toByteArray ();
@@ -105,6 +112,9 @@ private static RandomAccessibleInterval<UnsignedByteType> buildFromTensorUByte(N
105
112
*/
106
113
private static RandomAccessibleInterval <ByteType > buildFromTensorByte (NDArray tensor ) {
107
114
long [] arrayShape = tensor .getShape ().getShape ();
115
+ if (CommonUtils .int32Overflows (arrayShape , 1 ))
116
+ throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
117
+ + " is too big. Max number of elements per output tensor supported: " + Integer .MAX_VALUE );
108
118
long [] tensorShape = new long [arrayShape .length ];
109
119
for (int i = 0 ; i < arrayShape .length ; i ++) tensorShape [i ] = arrayShape [arrayShape .length - 1 - i ];
110
120
byte [] flatArr = tensor .toByteArray ();
@@ -121,6 +131,9 @@ private static RandomAccessibleInterval<ByteType> buildFromTensorByte(NDArray te
121
131
*/
122
132
private static RandomAccessibleInterval <IntType > buildFromTensorInt (NDArray tensor ) {
123
133
long [] arrayShape = tensor .getShape ().getShape ();
134
+ if (CommonUtils .int32Overflows (arrayShape , 1 ))
135
+ throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
136
+ + " is too big. Max number of elements per output tensor supported: " + Integer .MAX_VALUE );
124
137
long [] tensorShape = new long [arrayShape .length ];
125
138
for (int i = 0 ; i < arrayShape .length ; i ++) tensorShape [i ] = arrayShape [arrayShape .length - 1 - i ];
126
139
int [] flatArr = tensor .toIntArray ();
@@ -137,6 +150,9 @@ private static RandomAccessibleInterval<IntType> buildFromTensorInt(NDArray tens
137
150
*/
138
151
private static RandomAccessibleInterval <FloatType > buildFromTensorFloat (NDArray tensor ) {
139
152
long [] arrayShape = tensor .getShape ().getShape ();
153
+ if (CommonUtils .int32Overflows (arrayShape , 1 ))
154
+ throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
155
+ + " is too big. Max number of elements per output tensor supported: " + Integer .MAX_VALUE );
140
156
long [] tensorShape = new long [arrayShape .length ];
141
157
for (int i = 0 ; i < arrayShape .length ; i ++) tensorShape [i ] = arrayShape [arrayShape .length - 1 - i ];
142
158
float [] flatArr = tensor .toFloatArray ();
@@ -153,6 +169,9 @@ private static RandomAccessibleInterval<FloatType> buildFromTensorFloat(NDArray
153
169
*/
154
170
private static RandomAccessibleInterval <DoubleType > buildFromTensorDouble (NDArray tensor ) {
155
171
long [] arrayShape = tensor .getShape ().getShape ();
172
+ if (CommonUtils .int32Overflows (arrayShape , 1 ))
173
+ throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
174
+ + " is too big. Max number of elements per output tensor supported: " + Integer .MAX_VALUE );
156
175
long [] tensorShape = new long [arrayShape .length ];
157
176
for (int i = 0 ; i < arrayShape .length ; i ++) tensorShape [i ] = arrayShape [arrayShape .length - 1 - i ];
158
177
double [] flatArr = tensor .toDoubleArray ();
@@ -169,6 +188,9 @@ private static RandomAccessibleInterval<DoubleType> buildFromTensorDouble(NDArra
169
188
*/
170
189
private static RandomAccessibleInterval <LongType > buildFromTensorLong (NDArray tensor ) {
171
190
long [] arrayShape = tensor .getShape ().getShape ();
191
+ if (CommonUtils .int32Overflows (arrayShape , 1 ))
192
+ throw new IllegalArgumentException ("Model output tensor with shape " + Arrays .toString (arrayShape )
193
+ + " is too big. Max number of elements per output tensor supported: " + Integer .MAX_VALUE );
172
194
long [] tensorShape = new long [arrayShape .length ];
173
195
for (int i = 0 ; i < arrayShape .length ; i ++) tensorShape [i ] = arrayShape [arrayShape .length - 1 - i ];
174
196
long [] flatArr = tensor .toLongArray ();
0 commit comments