Skip to content

Commit fe83843

Browse files
HunterTracerzenghongtai
and
zenghongtai
authored
add SCATTER_API definition for scatter_mul in scatter.cpp & scatter.h (#344)
Co-authored-by: zenghongtai <[email protected]>
1 parent 111ffc4 commit fe83843

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

csrc/scatter.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,10 @@ scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
239239
return ScatterSum::apply(src, index, dim, optional_out, dim_size)[0];
240240
}
241241

242-
torch::Tensor scatter_mul(torch::Tensor src, torch::Tensor index, int64_t dim,
243-
torch::optional<torch::Tensor> optional_out,
244-
torch::optional<int64_t> dim_size) {
242+
SCATTER_API torch::Tensor
243+
scatter_mul(torch::Tensor src, torch::Tensor index, int64_t dim,
244+
torch::optional<torch::Tensor> optional_out,
245+
torch::optional<int64_t> dim_size) {
245246
return ScatterMul::apply(src, index, dim, optional_out, dim_size)[0];
246247
}
247248

csrc/scatter.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
1515
torch::optional<torch::Tensor> optional_out,
1616
torch::optional<int64_t> dim_size);
1717

18+
SCATTER_API torch::Tensor
19+
scatter_mul(torch::Tensor src, torch::Tensor index, int64_t dim,
20+
torch::optional<torch::Tensor> optional_out,
21+
torch::optional<int64_t> dim_size);
22+
1823
SCATTER_API torch::Tensor
1924
scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
2025
torch::optional<torch::Tensor> optional_out,

0 commit comments

Comments
 (0)