Skip to content

Commit ac9f5c6

Browse files
committed
correct error creating tensors
1 parent 57f3077 commit ac9f5c6

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

src/main/java/io/bioimage/modelrunner/pytorch/javacpp/tensor/JavaCPPTensorBuilder.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ public static <T extends Type<T>> org.bytedeco.pytorch.Tensor build(RandomAccess
100100
*/
101101
private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleInterval<ByteType> tensor)
102102
{
103+
long[] ogShape = tensor.dimensionsAsLongArray();
103104
tensor = Utils.transpose(tensor);
104105
PrimitiveBlocks< ByteType > blocks = PrimitiveBlocks.of( tensor );
105106
long[] tensorShape = tensor.dimensionsAsLongArray();
@@ -110,7 +111,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleI
110111
for (int i = 0; i < sArr.length; i ++)
111112
sArr[i] = (int) tensorShape[i];
112113
blocks.copy( new long[tensorShape.length], flatArr, sArr );
113-
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, tensorShape);
114+
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape);
114115
return ndarray;
115116
}
116117

@@ -124,6 +125,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleI
124125
*/
125126
private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleInterval<IntType> tensor)
126127
{
128+
long[] ogShape = tensor.dimensionsAsLongArray();
127129
tensor = Utils.transpose(tensor);
128130
PrimitiveBlocks< IntType > blocks = PrimitiveBlocks.of( tensor );
129131
long[] tensorShape = tensor.dimensionsAsLongArray();
@@ -134,7 +136,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleIn
134136
for (int i = 0; i < sArr.length; i ++)
135137
sArr[i] = (int) tensorShape[i];
136138
blocks.copy( new long[tensorShape.length], flatArr, sArr );
137-
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, tensorShape);
139+
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape);
138140
return ndarray;
139141
}
140142

@@ -148,6 +150,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleIn
148150
*/
149151
private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessibleInterval<FloatType> tensor)
150152
{
153+
long[] ogShape = tensor.dimensionsAsLongArray();
151154
tensor = Utils.transpose(tensor);
152155
PrimitiveBlocks< FloatType > blocks = PrimitiveBlocks.of( tensor );
153156
long[] tensorShape = tensor.dimensionsAsLongArray();
@@ -158,7 +161,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessible
158161
for (int i = 0; i < sArr.length; i ++)
159162
sArr[i] = (int) tensorShape[i];
160163
blocks.copy( new long[tensorShape.length], flatArr, sArr );
161-
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, tensorShape);
164+
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape);
162165
return ndarray;
163166
}
164167

@@ -172,6 +175,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessible
172175
*/
173176
private static org.bytedeco.pytorch.Tensor buildFromTensorDouble(RandomAccessibleInterval<DoubleType> tensor)
174177
{
178+
long[] ogShape = tensor.dimensionsAsLongArray();
175179
tensor = Utils.transpose(tensor);
176180
PrimitiveBlocks< DoubleType > blocks = PrimitiveBlocks.of( tensor );
177181
long[] tensorShape = tensor.dimensionsAsLongArray();
@@ -182,7 +186,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorDouble(RandomAccessibl
182186
for (int i = 0; i < sArr.length; i ++)
183187
sArr[i] = (int) tensorShape[i];
184188
blocks.copy( new long[tensorShape.length], flatArr, sArr );
185-
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, tensorShape);
189+
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape);
186190
return ndarray;
187191
}
188192
}

0 commit comments

Comments
 (0)