@@ -100,6 +100,7 @@ public static <T extends Type<T>> org.bytedeco.pytorch.Tensor build(RandomAccess
100
100
*/
101
101
private static org .bytedeco .pytorch .Tensor buildFromTensorByte (RandomAccessibleInterval <ByteType > tensor )
102
102
{
103
+ long [] ogShape = tensor .dimensionsAsLongArray ();
103
104
tensor = Utils .transpose (tensor );
104
105
PrimitiveBlocks < ByteType > blocks = PrimitiveBlocks .of ( tensor );
105
106
long [] tensorShape = tensor .dimensionsAsLongArray ();
@@ -110,7 +111,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleI
110
111
for (int i = 0 ; i < sArr .length ; i ++)
111
112
sArr [i ] = (int ) tensorShape [i ];
112
113
blocks .copy ( new long [tensorShape .length ], flatArr , sArr );
113
- org .bytedeco .pytorch .Tensor ndarray = org .bytedeco .pytorch .Tensor .create (flatArr , tensorShape );
114
+ org .bytedeco .pytorch .Tensor ndarray = org .bytedeco .pytorch .Tensor .create (flatArr , ogShape );
114
115
return ndarray ;
115
116
}
116
117
@@ -124,6 +125,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleI
124
125
*/
125
126
private static org .bytedeco .pytorch .Tensor buildFromTensorInt (RandomAccessibleInterval <IntType > tensor )
126
127
{
128
+ long [] ogShape = tensor .dimensionsAsLongArray ();
127
129
tensor = Utils .transpose (tensor );
128
130
PrimitiveBlocks < IntType > blocks = PrimitiveBlocks .of ( tensor );
129
131
long [] tensorShape = tensor .dimensionsAsLongArray ();
@@ -134,7 +136,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleIn
134
136
for (int i = 0 ; i < sArr .length ; i ++)
135
137
sArr [i ] = (int ) tensorShape [i ];
136
138
blocks .copy ( new long [tensorShape .length ], flatArr , sArr );
137
- org .bytedeco .pytorch .Tensor ndarray = org .bytedeco .pytorch .Tensor .create (flatArr , tensorShape );
139
+ org .bytedeco .pytorch .Tensor ndarray = org .bytedeco .pytorch .Tensor .create (flatArr , ogShape );
138
140
return ndarray ;
139
141
}
140
142
@@ -148,6 +150,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleIn
148
150
*/
149
151
private static org .bytedeco .pytorch .Tensor buildFromTensorFloat (RandomAccessibleInterval <FloatType > tensor )
150
152
{
153
+ long [] ogShape = tensor .dimensionsAsLongArray ();
151
154
tensor = Utils .transpose (tensor );
152
155
PrimitiveBlocks < FloatType > blocks = PrimitiveBlocks .of ( tensor );
153
156
long [] tensorShape = tensor .dimensionsAsLongArray ();
@@ -158,7 +161,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessible
158
161
for (int i = 0 ; i < sArr .length ; i ++)
159
162
sArr [i ] = (int ) tensorShape [i ];
160
163
blocks .copy ( new long [tensorShape .length ], flatArr , sArr );
161
- org .bytedeco .pytorch .Tensor ndarray = org .bytedeco .pytorch .Tensor .create (flatArr , tensorShape );
164
+ org .bytedeco .pytorch .Tensor ndarray = org .bytedeco .pytorch .Tensor .create (flatArr , ogShape );
162
165
return ndarray ;
163
166
}
164
167
@@ -172,6 +175,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessible
172
175
*/
173
176
private static org .bytedeco .pytorch .Tensor buildFromTensorDouble (RandomAccessibleInterval <DoubleType > tensor )
174
177
{
178
+ long [] ogShape = tensor .dimensionsAsLongArray ();
175
179
tensor = Utils .transpose (tensor );
176
180
PrimitiveBlocks < DoubleType > blocks = PrimitiveBlocks .of ( tensor );
177
181
long [] tensorShape = tensor .dimensionsAsLongArray ();
@@ -182,7 +186,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorDouble(RandomAccessibl
182
186
for (int i = 0 ; i < sArr .length ; i ++)
183
187
sArr [i ] = (int ) tensorShape [i ];
184
188
blocks .copy ( new long [tensorShape .length ], flatArr , sArr );
185
- org .bytedeco .pytorch .Tensor ndarray = org .bytedeco .pytorch .Tensor .create (flatArr , tensorShape );
189
+ org .bytedeco .pytorch .Tensor ndarray = org .bytedeco .pytorch .Tensor .create (flatArr , ogShape );
186
190
return ndarray ;
187
191
}
188
192
}
0 commit comments