File tree 2 files changed +5
-3
lines changed
2 files changed +5
-3
lines changed Original file line number Diff line number Diff line change @@ -20,9 +20,10 @@ namespace refactor::kernel::cuda {
20
20
tid += step) {
21
21
auto i = tid / batch,
22
22
j = tid % batch;
23
- auto index = __ldg (indices + i % midSizeO);
23
+ auto k = __ldg (indices + i % midSizeO);
24
+ auto quot = k >= 0 ? i / midSizeO : i / midSizeO + 1 ;
24
25
optimizedMemcpy (unit * tid + output,
25
- unit * (batch * (i / midSizeO * midSizeI + index ) + j) + data,
26
+ unit * (batch * (quot * midSizeI + k ) + j) + data,
26
27
unit);
27
28
}
28
29
}
Original file line number Diff line number Diff line change @@ -33,8 +33,9 @@ namespace refactor::kernel {
33
33
int64_t k = info.idxType == DataType::I64
34
34
? reinterpret_cast <int64_t const *>(inputs[1 ])[d.rem ]
35
35
: reinterpret_cast <int32_t const *>(inputs[1 ])[d.rem ];
36
+ auto quot = k >= 0 ? d.quot : d.quot + 1 ;
36
37
std::memcpy (info.postfix * i + output,
37
- info.postfix * (d. quot * info.midSizeI + k) + data,
38
+ info.postfix * (quot * info.midSizeI + k) + data,
38
39
info.postfix );
39
40
});
40
41
};
You can’t perform that action at this time.
0 commit comments