2020 */
2121package io .bioimage .modelrunner .pytorch .javacpp .tensor ;
2222
23+ import java .nio .ByteBuffer ;
24+ import java .util .Arrays ;
25+
2326import io .bioimage .modelrunner .tensor .Tensor ;
2427import io .bioimage .modelrunner .tensor .Utils ;
28+ import net .imglib2 .Cursor ;
2529import net .imglib2 .RandomAccessibleInterval ;
2630import net .imglib2 .blocks .PrimitiveBlocks ;
2731import net .imglib2 .type .Type ;
2832import net .imglib2 .type .numeric .integer .ByteType ;
2933import net .imglib2 .type .numeric .integer .IntType ;
34+ import net .imglib2 .type .numeric .integer .UnsignedByteType ;
3035import net .imglib2 .type .numeric .real .DoubleType ;
3136import net .imglib2 .type .numeric .real .FloatType ;
3237import net .imglib2 .util .Util ;
38+ import net .imglib2 .view .Views ;
3339
3440/**
3541 * Class that manages the creation of JAvaCPP Pytorch tensors
@@ -101,16 +107,24 @@ public static <T extends Type<T>> org.bytedeco.pytorch.Tensor build(RandomAccess
101107 private static org .bytedeco .pytorch .Tensor buildFromTensorByte (RandomAccessibleInterval <ByteType > tensor )
102108 {
103109 long [] ogShape = tensor .dimensionsAsLongArray ();
110+ if (CommonUtils .int32Overflows (ogShape ))
111+ throw new IllegalArgumentException ("Provided tensor with shape " + Arrays .toString (ogShape )
112+ + " is too big. Max number of elements per tensor supported: " + Integer .MAX_VALUE );
104113 tensor = Utils .transpose (tensor );
105- PrimitiveBlocks < ByteType > blocks = PrimitiveBlocks .of ( tensor );
106114 long [] tensorShape = tensor .dimensionsAsLongArray ();
107115 int size = 1 ;
108116 for (long ll : tensorShape ) size *= ll ;
109117 final byte [] flatArr = new byte [size ];
110118 int [] sArr = new int [tensorShape .length ];
111119 for (int i = 0 ; i < sArr .length ; i ++)
112120 sArr [i ] = (int ) tensorShape [i ];
113- blocks .copy ( tensor .minAsLongArray (), flatArr , sArr );
121+
122+ Cursor <ByteType > cursor = Views .flatIterable (tensor ).cursor ();
123+ int i = 0 ;
124+ while (cursor .hasNext ()) {
125+ cursor .fwd ();
126+ flatArr [i ++] = cursor .get ().get ();
127+ }
114128 org .bytedeco .pytorch .Tensor ndarray = org .bytedeco .pytorch .Tensor .create (flatArr , ogShape );
115129 return ndarray ;
116130 }
@@ -126,17 +140,25 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleI
126140 private static org .bytedeco .pytorch .Tensor buildFromTensorInt (RandomAccessibleInterval <IntType > tensor )
127141 {
128142 long [] ogShape = tensor .dimensionsAsLongArray ();
143+ if (CommonUtils .int32Overflows (ogShape ))
144+ throw new IllegalArgumentException ("Provided tensor with shape " + Arrays .toString (ogShape )
145+ + " is too big. Max number of elements per tensor supported: " + Integer .MAX_VALUE );
129146 tensor = Utils .transpose (tensor );
130- PrimitiveBlocks < IntType > blocks = PrimitiveBlocks .of ( tensor );
131147 long [] tensorShape = tensor .dimensionsAsLongArray ();
132148 int size = 1 ;
133149 for (long ll : tensorShape ) size *= ll ;
134150 final int [] flatArr = new int [size ];
135151 int [] sArr = new int [tensorShape .length ];
136152 for (int i = 0 ; i < sArr .length ; i ++)
137153 sArr [i ] = (int ) tensorShape [i ];
138- blocks .copy ( tensor .minAsLongArray (), flatArr , sArr );
139- org .bytedeco .pytorch .Tensor ndarray = org .bytedeco .pytorch .Tensor .create (flatArr , ogShape );
154+
155+ Cursor <IntType > cursor = Views .flatIterable (tensor ).cursor ();
156+ int i = 0 ;
157+ while (cursor .hasNext ()) {
158+ cursor .fwd ();
159+ flatArr [i ++] = cursor .get ().get ();
160+ }
161+ org .bytedeco .pytorch .Tensor ndarray = org .bytedeco .pytorch .Tensor .create (flatArr , ogShape );
140162 return ndarray ;
141163 }
142164
@@ -151,17 +173,25 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleIn
151173 private static org .bytedeco .pytorch .Tensor buildFromTensorFloat (RandomAccessibleInterval <FloatType > tensor )
152174 {
153175 long [] ogShape = tensor .dimensionsAsLongArray ();
176+ if (CommonUtils .int32Overflows (ogShape ))
177+ throw new IllegalArgumentException ("Provided tensor with shape " + Arrays .toString (ogShape )
178+ + " is too big. Max number of elements per tensor supported: " + Integer .MAX_VALUE );
154179 tensor = Utils .transpose (tensor );
155- PrimitiveBlocks < FloatType > blocks = PrimitiveBlocks .of ( tensor );
156180 long [] tensorShape = tensor .dimensionsAsLongArray ();
157181 int size = 1 ;
158182 for (long ll : tensorShape ) size *= ll ;
159183 final float [] flatArr = new float [size ];
160184 int [] sArr = new int [tensorShape .length ];
161185 for (int i = 0 ; i < sArr .length ; i ++)
162186 sArr [i ] = (int ) tensorShape [i ];
163- blocks .copy ( tensor .minAsLongArray (), flatArr , sArr );
164- org .bytedeco .pytorch .Tensor ndarray = org .bytedeco .pytorch .Tensor .create (flatArr , ogShape );
187+
188+ Cursor <FloatType > cursor = Views .flatIterable (tensor ).cursor ();
189+ int i = 0 ;
190+ while (cursor .hasNext ()) {
191+ cursor .fwd ();
192+ flatArr [i ++] = cursor .get ().get ();
193+ }
194+ org .bytedeco .pytorch .Tensor ndarray = org .bytedeco .pytorch .Tensor .create (flatArr , ogShape );
165195 return ndarray ;
166196 }
167197
@@ -176,17 +206,25 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessible
176206 private static org .bytedeco .pytorch .Tensor buildFromTensorDouble (RandomAccessibleInterval <DoubleType > tensor )
177207 {
178208 long [] ogShape = tensor .dimensionsAsLongArray ();
209+ if (CommonUtils .int32Overflows (ogShape ))
210+ throw new IllegalArgumentException ("Provided tensor with shape " + Arrays .toString (ogShape )
211+ + " is too big. Max number of elements per tensor supported: " + Integer .MAX_VALUE );
179212 tensor = Utils .transpose (tensor );
180- PrimitiveBlocks < DoubleType > blocks = PrimitiveBlocks .of ( tensor );
181213 long [] tensorShape = tensor .dimensionsAsLongArray ();
182214 int size = 1 ;
183215 for (long ll : tensorShape ) size *= ll ;
184216 final double [] flatArr = new double [size ];
185217 int [] sArr = new int [tensorShape .length ];
186218 for (int i = 0 ; i < sArr .length ; i ++)
187219 sArr [i ] = (int ) tensorShape [i ];
188- blocks .copy ( tensor .minAsLongArray (), flatArr , sArr );
189- org .bytedeco .pytorch .Tensor ndarray = org .bytedeco .pytorch .Tensor .create (flatArr , ogShape );
220+
221+ Cursor <DoubleType > cursor = Views .flatIterable (tensor ).cursor ();
222+ int i = 0 ;
223+ while (cursor .hasNext ()) {
224+ cursor .fwd ();
225+ flatArr [i ++] = cursor .get ().get ();
226+ }
227+ org .bytedeco .pytorch .Tensor ndarray = org .bytedeco .pytorch .Tensor .create (flatArr , ogShape );
190228 return ndarray ;
191229 }
192230}
0 commit comments