Skip to content

Commit 19a94a9

Browse files
committed
fix assertion 'index out of bounds' in case dim_size is omitted
1 parent 5c9462e commit 19a94a9

File tree

3 files changed: 5 additions and 13 deletions.

csrc/cuda/scatter_cuda.cu

Lines changed: 1 addition & 4 deletions
@@ -84,10 +84,7 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
     else if (index.numel() == 0)
       sizes[dim] = 0;
     else {
-      auto d_size = index.max().data_ptr<int64_t>();
-      auto h_size = (int64_t *)malloc(sizeof(int64_t));
-      cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost);
-      sizes[dim] = 1 + *h_size;
+      sizes[dim] = 1 + index.max().cpu().data_ptr<int64_t>()[0];
     }
     out = torch::empty(sizes, src.options());
   }

csrc/cuda/segment_coo_cuda.cu

Lines changed: 1 addition & 4 deletions
@@ -186,10 +186,7 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
     else {
       auto tmp = index.select(dim, index.size(dim) - 1);
       tmp = tmp.numel() > 1 ? tmp.max() : tmp;
-      auto d_size = tmp.data_ptr<int64_t>();
-      auto h_size = (int64_t *)malloc(sizeof(int64_t));
-      cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost);
-      sizes[dim] = 1 + *h_size;
+      sizes[dim] = 1 + tmp.cpu().data_ptr<int64_t>()[0];
     }
     out = torch::zeros(sizes, src.options());
   }

csrc/cuda/segment_csr_cuda.cu

Lines changed: 3 additions & 5 deletions
@@ -245,12 +245,10 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
   } else {
     auto sizes = src.sizes().vec();
     if (src.numel() > 0) {
-      auto d_size = indptr.flatten()[-1].data_ptr<int64_t>();
-      auto h_size = (int64_t *)malloc(sizeof(int64_t));
-      cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost);
-      sizes[dim] = *h_size;
-    } else
+      sizes[dim] = indptr.flatten()[-1].cpu().data_ptr<int64_t>()[0];
+    } else {
       sizes[dim] = 0;
+    }
     out = torch::empty(sizes, src.options());
   }

0 commit comments

Comments
 (0)