20
20
*/
21
21
package io .bioimage .modelrunner .pytorch .javacpp .tensor ;
22
22
23
+ import java .nio .ByteBuffer ;
24
+ import java .util .Arrays ;
25
+
23
26
import io .bioimage .modelrunner .tensor .Tensor ;
24
27
import io .bioimage .modelrunner .tensor .Utils ;
28
+ import net .imglib2 .Cursor ;
25
29
import net .imglib2 .RandomAccessibleInterval ;
26
30
import net .imglib2 .blocks .PrimitiveBlocks ;
27
31
import net .imglib2 .type .Type ;
28
32
import net .imglib2 .type .numeric .integer .ByteType ;
29
33
import net .imglib2 .type .numeric .integer .IntType ;
34
+ import net .imglib2 .type .numeric .integer .UnsignedByteType ;
30
35
import net .imglib2 .type .numeric .real .DoubleType ;
31
36
import net .imglib2 .type .numeric .real .FloatType ;
32
37
import net .imglib2 .util .Util ;
38
+ import net .imglib2 .view .Views ;
33
39
34
40
/**
35
41
* Class that manages the creation of JAvaCPP Pytorch tensors
@@ -101,16 +107,24 @@ public static <T extends Type<T>> org.bytedeco.pytorch.Tensor build(RandomAccess
101
107
private static org .bytedeco .pytorch .Tensor buildFromTensorByte (RandomAccessibleInterval <ByteType > tensor )
102
108
{
103
109
long [] ogShape = tensor .dimensionsAsLongArray ();
110
+ if (CommonUtils .int32Overflows (ogShape ))
111
+ throw new IllegalArgumentException ("Provided tensor with shape " + Arrays .toString (ogShape )
112
+ + " is too big. Max number of elements per tensor supported: " + Integer .MAX_VALUE );
104
113
tensor = Utils .transpose (tensor );
105
- PrimitiveBlocks < ByteType > blocks = PrimitiveBlocks .of ( tensor );
106
114
long [] tensorShape = tensor .dimensionsAsLongArray ();
107
115
int size = 1 ;
108
116
for (long ll : tensorShape ) size *= ll ;
109
117
final byte [] flatArr = new byte [size ];
110
118
int [] sArr = new int [tensorShape .length ];
111
119
for (int i = 0 ; i < sArr .length ; i ++)
112
120
sArr [i ] = (int ) tensorShape [i ];
113
- blocks .copy ( tensor .minAsLongArray (), flatArr , sArr );
121
+
122
+ Cursor <ByteType > cursor = Views .flatIterable (tensor ).cursor ();
123
+ int i = 0 ;
124
+ while (cursor .hasNext ()) {
125
+ cursor .fwd ();
126
+ flatArr [i ++] = cursor .get ().get ();
127
+ }
114
128
org .bytedeco .pytorch .Tensor ndarray = org .bytedeco .pytorch .Tensor .create (flatArr , ogShape );
115
129
return ndarray ;
116
130
}
@@ -126,17 +140,25 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleI
126
140
private static org .bytedeco .pytorch .Tensor buildFromTensorInt (RandomAccessibleInterval <IntType > tensor )
127
141
{
128
142
long [] ogShape = tensor .dimensionsAsLongArray ();
143
+ if (CommonUtils .int32Overflows (ogShape ))
144
+ throw new IllegalArgumentException ("Provided tensor with shape " + Arrays .toString (ogShape )
145
+ + " is too big. Max number of elements per tensor supported: " + Integer .MAX_VALUE );
129
146
tensor = Utils .transpose (tensor );
130
- PrimitiveBlocks < IntType > blocks = PrimitiveBlocks .of ( tensor );
131
147
long [] tensorShape = tensor .dimensionsAsLongArray ();
132
148
int size = 1 ;
133
149
for (long ll : tensorShape ) size *= ll ;
134
150
final int [] flatArr = new int [size ];
135
151
int [] sArr = new int [tensorShape .length ];
136
152
for (int i = 0 ; i < sArr .length ; i ++)
137
153
sArr [i ] = (int ) tensorShape [i ];
138
- blocks .copy ( tensor .minAsLongArray (), flatArr , sArr );
139
- org .bytedeco .pytorch .Tensor ndarray = org .bytedeco .pytorch .Tensor .create (flatArr , ogShape );
154
+
155
+ Cursor <IntType > cursor = Views .flatIterable (tensor ).cursor ();
156
+ int i = 0 ;
157
+ while (cursor .hasNext ()) {
158
+ cursor .fwd ();
159
+ flatArr [i ++] = cursor .get ().get ();
160
+ }
161
+ org .bytedeco .pytorch .Tensor ndarray = org .bytedeco .pytorch .Tensor .create (flatArr , ogShape );
140
162
return ndarray ;
141
163
}
142
164
@@ -151,17 +173,25 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleIn
151
173
private static org .bytedeco .pytorch .Tensor buildFromTensorFloat (RandomAccessibleInterval <FloatType > tensor )
152
174
{
153
175
long [] ogShape = tensor .dimensionsAsLongArray ();
176
+ if (CommonUtils .int32Overflows (ogShape ))
177
+ throw new IllegalArgumentException ("Provided tensor with shape " + Arrays .toString (ogShape )
178
+ + " is too big. Max number of elements per tensor supported: " + Integer .MAX_VALUE );
154
179
tensor = Utils .transpose (tensor );
155
- PrimitiveBlocks < FloatType > blocks = PrimitiveBlocks .of ( tensor );
156
180
long [] tensorShape = tensor .dimensionsAsLongArray ();
157
181
int size = 1 ;
158
182
for (long ll : tensorShape ) size *= ll ;
159
183
final float [] flatArr = new float [size ];
160
184
int [] sArr = new int [tensorShape .length ];
161
185
for (int i = 0 ; i < sArr .length ; i ++)
162
186
sArr [i ] = (int ) tensorShape [i ];
163
- blocks .copy ( tensor .minAsLongArray (), flatArr , sArr );
164
- org .bytedeco .pytorch .Tensor ndarray = org .bytedeco .pytorch .Tensor .create (flatArr , ogShape );
187
+
188
+ Cursor <FloatType > cursor = Views .flatIterable (tensor ).cursor ();
189
+ int i = 0 ;
190
+ while (cursor .hasNext ()) {
191
+ cursor .fwd ();
192
+ flatArr [i ++] = cursor .get ().get ();
193
+ }
194
+ org .bytedeco .pytorch .Tensor ndarray = org .bytedeco .pytorch .Tensor .create (flatArr , ogShape );
165
195
return ndarray ;
166
196
}
167
197
@@ -176,17 +206,25 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessible
176
206
private static org .bytedeco .pytorch .Tensor buildFromTensorDouble (RandomAccessibleInterval <DoubleType > tensor )
177
207
{
178
208
long [] ogShape = tensor .dimensionsAsLongArray ();
209
+ if (CommonUtils .int32Overflows (ogShape ))
210
+ throw new IllegalArgumentException ("Provided tensor with shape " + Arrays .toString (ogShape )
211
+ + " is too big. Max number of elements per tensor supported: " + Integer .MAX_VALUE );
179
212
tensor = Utils .transpose (tensor );
180
- PrimitiveBlocks < DoubleType > blocks = PrimitiveBlocks .of ( tensor );
181
213
long [] tensorShape = tensor .dimensionsAsLongArray ();
182
214
int size = 1 ;
183
215
for (long ll : tensorShape ) size *= ll ;
184
216
final double [] flatArr = new double [size ];
185
217
int [] sArr = new int [tensorShape .length ];
186
218
for (int i = 0 ; i < sArr .length ; i ++)
187
219
sArr [i ] = (int ) tensorShape [i ];
188
- blocks .copy ( tensor .minAsLongArray (), flatArr , sArr );
189
- org .bytedeco .pytorch .Tensor ndarray = org .bytedeco .pytorch .Tensor .create (flatArr , ogShape );
220
+
221
+ Cursor <DoubleType > cursor = Views .flatIterable (tensor ).cursor ();
222
+ int i = 0 ;
223
+ while (cursor .hasNext ()) {
224
+ cursor .fwd ();
225
+ flatArr [i ++] = cursor .get ().get ();
226
+ }
227
+ org .bytedeco .pytorch .Tensor ndarray = org .bytedeco .pytorch .Tensor .create (flatArr , ogShape );
190
228
return ndarray ;
191
229
}
192
230
}
0 commit comments