Skip to content

Commit 4b3b65e

Browse files
committed
correct error creating tensors
1 parent ac9f5c6 commit 4b3b65e

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/main/java/io/bioimage/modelrunner/pytorch/javacpp/tensor/JavaCPPTensorBuilder.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorByte(RandomAccessibleI
110110
int[] sArr = new int[tensorShape.length];
111111
for (int i = 0; i < sArr.length; i ++)
112112
sArr[i] = (int) tensorShape[i];
113-
blocks.copy( new long[tensorShape.length], flatArr, sArr );
113+
blocks.copy( tensor.minAsLongArray(), flatArr, sArr );
114114
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape);
115115
return ndarray;
116116
}
@@ -135,7 +135,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorInt(RandomAccessibleIn
135135
int[] sArr = new int[tensorShape.length];
136136
for (int i = 0; i < sArr.length; i ++)
137137
sArr[i] = (int) tensorShape[i];
138-
blocks.copy( new long[tensorShape.length], flatArr, sArr );
138+
blocks.copy( tensor.minAsLongArray(), flatArr, sArr );
139139
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape);
140140
return ndarray;
141141
}
@@ -160,7 +160,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorFloat(RandomAccessible
160160
int[] sArr = new int[tensorShape.length];
161161
for (int i = 0; i < sArr.length; i ++)
162162
sArr[i] = (int) tensorShape[i];
163-
blocks.copy( new long[tensorShape.length], flatArr, sArr );
163+
blocks.copy( tensor.minAsLongArray(), flatArr, sArr );
164164
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape);
165165
return ndarray;
166166
}
@@ -185,7 +185,7 @@ private static org.bytedeco.pytorch.Tensor buildFromTensorDouble(RandomAccessibl
185185
int[] sArr = new int[tensorShape.length];
186186
for (int i = 0; i < sArr.length; i ++)
187187
sArr[i] = (int) tensorShape[i];
188-
blocks.copy( new long[tensorShape.length], flatArr, sArr );
188+
blocks.copy( tensor.minAsLongArray(), flatArr, sArr );
189189
org.bytedeco.pytorch.Tensor ndarray = org.bytedeco.pytorch.Tensor.create(flatArr, ogShape);
190190
return ndarray;
191191
}

0 commit comments

Comments
 (0)