Skip to content

Commit 6630866

Browse files
committed
fix(kernel): 为 Gather 支持负的 indices 值
Signed-off-by: YdrMaster <[email protected]>
1 parent 9495516 commit 6630866

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

src/04kernel/cuda/src/gather.cu

+3-2
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@ namespace refactor::kernel::cuda {
2020
tid += step) {
2121
auto i = tid / batch,
2222
j = tid % batch;
23-
auto index = __ldg(indices + i % midSizeO);
23+
auto k = __ldg(indices + i % midSizeO);
24+
auto quot = k >= 0 ? i / midSizeO : i / midSizeO + 1;
2425
optimizedMemcpy(unit * tid + output,
25-
unit * (batch * (i / midSizeO * midSizeI + index) + j) + data,
26+
unit * (batch * (quot * midSizeI + k) + j) + data,
2627
unit);
2728
}
2829
}

src/04kernel/src/kernels/gather/cpu_kernel.cc

+2-1
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@ namespace refactor::kernel {
3333
int64_t k = info.idxType == DataType::I64
3434
? reinterpret_cast<int64_t const *>(inputs[1])[d.rem]
3535
: reinterpret_cast<int32_t const *>(inputs[1])[d.rem];
36+
auto quot = k >= 0 ? d.quot : d.quot + 1;
3637
std::memcpy(info.postfix * i + output,
37-
info.postfix * (d.quot * info.midSizeI + k) + data,
38+
info.postfix * (quot * info.midSizeI + k) + data,
3839
info.postfix);
3940
});
4041
};

0 commit comments

Comments
 (0)