28
28
import java .nio .DoubleBuffer ;
29
29
import java .nio .FloatBuffer ;
30
30
import java .nio .IntBuffer ;
31
+ import java .util .Arrays ;
31
32
33
+ import net .imglib2 .Cursor ;
32
34
import net .imglib2 .RandomAccessibleInterval ;
33
35
import net .imglib2 .blocks .PrimitiveBlocks ;
34
36
import net .imglib2 .img .Img ;
38
40
import net .imglib2 .type .numeric .real .DoubleType ;
39
41
import net .imglib2 .type .numeric .real .FloatType ;
40
42
import net .imglib2 .util .Util ;
41
-
43
+ import net . imglib2 . view . Views ;
42
44
import ai .onnxruntime .OnnxTensor ;
43
45
import ai .onnxruntime .OrtEnvironment ;
44
46
import ai .onnxruntime .OrtException ;
@@ -122,18 +124,25 @@ public static <T extends Type<T>> OnnxTensor build(RandomAccessibleInterval<T> r
122
124
*/
123
125
private static OnnxTensor buildByte (RandomAccessibleInterval <ByteType > tensor , OrtEnvironment env ) throws OrtException
124
126
{
127
+ long [] ogShape = tensor .dimensionsAsLongArray ();
128
+ if (CommonUtils .int32Overflows (ogShape ))
129
+ throw new IllegalArgumentException ("Provided tensor with shape " + Arrays .toString (ogShape )
130
+ + " is too big. Max number of elements per tensor supported: " + Integer .MAX_VALUE );
125
131
tensor = Utils .transpose (tensor );
126
- PrimitiveBlocks < ByteType > blocks = PrimitiveBlocks .of ( tensor );
127
132
long [] tensorShape = tensor .dimensionsAsLongArray ();
128
- if (CommonUtils .int32Overflows (tensorShape ))
129
- throw new IllegalArgumentException ("Tensor is too big to handle. Max number of elements allowed in a tensor: " + Integer .MAX_VALUE );
130
133
int size = 1 ;
131
134
for (long ll : tensorShape ) size *= ll ;
132
135
final byte [] flatArr = new byte [size ];
133
136
int [] sArr = new int [tensorShape .length ];
134
137
for (int i = 0 ; i < sArr .length ; i ++)
135
138
sArr [i ] = (int ) tensorShape [i ];
136
- blocks .copy ( tensor .minAsLongArray (), flatArr , sArr );
139
+
140
+ Cursor <ByteType > cursor = Views .flatIterable (tensor ).cursor ();
141
+ int i = 0 ;
142
+ while (cursor .hasNext ()) {
143
+ cursor .fwd ();
144
+ flatArr [i ++] = cursor .get ().getByte ();
145
+ }
137
146
ByteBuffer buff = ByteBuffer .wrap (flatArr );
138
147
OnnxTensor ndarray = OnnxTensor .createTensor (env , buff , tensorShape );
139
148
return ndarray ;
@@ -153,18 +162,25 @@ private static OnnxTensor buildByte(RandomAccessibleInterval<ByteType> tensor, O
153
162
*/
154
163
private static OnnxTensor buildInt (RandomAccessibleInterval <IntType > tensor , OrtEnvironment env ) throws OrtException
155
164
{
165
+ long [] ogShape = tensor .dimensionsAsLongArray ();
166
+ if (CommonUtils .int32Overflows (ogShape ))
167
+ throw new IllegalArgumentException ("Provided tensor with shape " + Arrays .toString (ogShape )
168
+ + " is too big. Max number of elements per tensor supported: " + Integer .MAX_VALUE );
156
169
tensor = Utils .transpose (tensor );
157
- PrimitiveBlocks < IntType > blocks = PrimitiveBlocks .of ( tensor );
158
170
long [] tensorShape = tensor .dimensionsAsLongArray ();
159
- if (CommonUtils .int32Overflows (tensorShape ))
160
- throw new IllegalArgumentException ("Tensor is too big to handle. Max number of elements allowed in a tensor: " + Integer .MAX_VALUE );
161
171
int size = 1 ;
162
172
for (long ll : tensorShape ) size *= ll ;
163
173
final int [] flatArr = new int [size ];
164
174
int [] sArr = new int [tensorShape .length ];
165
175
for (int i = 0 ; i < sArr .length ; i ++)
166
176
sArr [i ] = (int ) tensorShape [i ];
167
- blocks .copy ( tensor .minAsLongArray (), flatArr , sArr );
177
+
178
+ Cursor <IntType > cursor = Views .flatIterable (tensor ).cursor ();
179
+ int i = 0 ;
180
+ while (cursor .hasNext ()) {
181
+ cursor .fwd ();
182
+ flatArr [i ++] = cursor .get ().get ();
183
+ }
168
184
IntBuffer buff = IntBuffer .wrap (flatArr );
169
185
OnnxTensor ndarray = OnnxTensor .createTensor (env , buff , tensorShape );
170
186
return ndarray ;
@@ -184,20 +200,27 @@ private static OnnxTensor buildInt(RandomAccessibleInterval<IntType> tensor, Ort
184
200
*/
185
201
private static OnnxTensor buildFloat (RandomAccessibleInterval <FloatType > tensor , OrtEnvironment env ) throws OrtException
186
202
{
203
+ long [] ogShape = tensor .dimensionsAsLongArray ();
204
+ if (CommonUtils .int32Overflows (ogShape ))
205
+ throw new IllegalArgumentException ("Provided tensor with shape " + Arrays .toString (ogShape )
206
+ + " is too big. Max number of elements per tensor supported: " + Integer .MAX_VALUE );
187
207
tensor = Utils .transpose (tensor );
188
- PrimitiveBlocks < FloatType > blocks = PrimitiveBlocks .of ( tensor );
189
208
long [] tensorShape = tensor .dimensionsAsLongArray ();
190
- if (CommonUtils .int32Overflows (tensorShape ))
191
- throw new IllegalArgumentException ("Tensor is too big to handle. Max number of elements allowed in a tensor: " + Integer .MAX_VALUE );
192
209
int size = 1 ;
193
210
for (long ll : tensorShape ) size *= ll ;
194
211
final float [] flatArr = new float [size ];
195
212
int [] sArr = new int [tensorShape .length ];
196
213
for (int i = 0 ; i < sArr .length ; i ++)
197
214
sArr [i ] = (int ) tensorShape [i ];
198
- blocks .copy ( tensor .minAsLongArray (), flatArr , sArr );
215
+
216
+ Cursor <FloatType > cursor = Views .flatIterable (tensor ).cursor ();
217
+ int i = 0 ;
218
+ while (cursor .hasNext ()) {
219
+ cursor .fwd ();
220
+ flatArr [i ++] = cursor .get ().get ();
221
+ }
199
222
FloatBuffer buff = FloatBuffer .wrap (flatArr );
200
- OnnxTensor ndarray = OnnxTensor .createTensor (env , buff , tensorShape );
223
+ OnnxTensor ndarray = OnnxTensor .createTensor (env , buff , ogShape );
201
224
return ndarray ;
202
225
}
203
226
@@ -215,18 +238,25 @@ private static OnnxTensor buildFloat(RandomAccessibleInterval<FloatType> tensor,
215
238
*/
216
239
private static OnnxTensor buildDouble (RandomAccessibleInterval <DoubleType > tensor , OrtEnvironment env ) throws OrtException
217
240
{
241
+ long [] ogShape = tensor .dimensionsAsLongArray ();
242
+ if (CommonUtils .int32Overflows (ogShape ))
243
+ throw new IllegalArgumentException ("Provided tensor with shape " + Arrays .toString (ogShape )
244
+ + " is too big. Max number of elements per tensor supported: " + Integer .MAX_VALUE );
218
245
tensor = Utils .transpose (tensor );
219
- PrimitiveBlocks < DoubleType > blocks = PrimitiveBlocks .of ( tensor );
220
246
long [] tensorShape = tensor .dimensionsAsLongArray ();
221
- if (CommonUtils .int32Overflows (tensorShape ))
222
- throw new IllegalArgumentException ("Tensor is too big to handle. Max number of elements allowed in a tensor: " + Integer .MAX_VALUE );
223
247
int size = 1 ;
224
248
for (long ll : tensorShape ) size *= ll ;
225
249
final double [] flatArr = new double [size ];
226
250
int [] sArr = new int [tensorShape .length ];
227
251
for (int i = 0 ; i < sArr .length ; i ++)
228
252
sArr [i ] = (int ) tensorShape [i ];
229
- blocks .copy ( tensor .minAsLongArray (), flatArr , sArr );
253
+
254
+ Cursor <DoubleType > cursor = Views .flatIterable (tensor ).cursor ();
255
+ int i = 0 ;
256
+ while (cursor .hasNext ()) {
257
+ cursor .fwd ();
258
+ flatArr [i ++] = cursor .get ().get ();
259
+ }
230
260
DoubleBuffer buff = DoubleBuffer .wrap (flatArr );
231
261
OnnxTensor ndarray = OnnxTensor .createTensor (env , buff , tensorShape );
232
262
return ndarray ;
0 commit comments