Commit a0a2b75

khushi-411 authored and pytorchmergebot committed
[jiterator] sigmoid_backward: complex (pytorch#74948)
Summary: Follows pytorch#74748. cc kshitij12345!

Pull Request resolved: pytorch#74948
Reviewed By: malfet
Differential Revision: D35445949
Pulled By: ngimel
fbshipit-source-id: acd921a210fd15cc78046e37532173162daa38b8
(cherry picked from commit 7883b8d)
1 parent 53c2fc6 commit a0a2b75

File tree

1 file changed: +22 −3 lines changed


aten/src/ATen/native/cuda/BinaryMiscBackwardOpsKernels.cu (+22, −3)
@@ -8,22 +8,41 @@
 #include <ATen/native/DispatchStub.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/cuda/Loops.cuh>
+#include <ATen/native/cuda/JitLoops.cuh>
 
 // NOTE: CUDA on Windows requires that the enclosing function
 // of a __device__ lambda not have internal linkage.
 
 namespace at {
 namespace native {
 
+const char sigmoid_backward_name[] = "sigmoid_backward";
 void sigmoid_backward_kernel_cuda(TensorIteratorBase& iter) {
-  if(isComplexType(iter.dtype())) {
-    AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "sigmoid_backward_cuda", [&]() {
+  auto dtype = iter.dtype();
+  if(isComplexType(dtype)) {
+#if AT_USE_JITERATOR()
+    static const auto sigmoid_backward_string = jiterator_stringify(
+        template <typename T>
+        T sigmoid_backward(T a, T b) {
+          return a * std::conj((T{1.} - b) * b);
+        }
+    ); // sigmoid_backward_string
+    AT_DISPATCH_COMPLEX_TYPES(dtype, "sigmoid_backward_cuda", [&]() {
+      jitted_gpu_kernel<
+          /*name=*/ sigmoid_backward_name,
+          /*return_dtype=*/ scalar_t,
+          /*common_dtype=*/ scalar_t,
+          /*arity=*/ 2>(iter, sigmoid_backward_string);
+    });
+#else
+    AT_DISPATCH_COMPLEX_TYPES(dtype, "sigmoid_backward_cuda", [&]() {
       gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
         return a * std::conj((scalar_t{1.} - b) * b);
       });
     });
+#endif
   } else {
-    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "sigmoid_backward_cuda", [&]() {
+    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, dtype, "sigmoid_backward_cuda", [&]() {
       gpu_kernel(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
         return a * (scalar_t(1.) - b) * b;
       });
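For reference (not part of the diff): both branches implement the backward of sigmoid in terms of the forward output y = sigmoid(x), and the extra std::conj in the complex branch reflects PyTorch's complex-autograd convention of propagating conjugate Wirtinger derivatives. The #else branch keeps the pre-existing gpu_kernel lambda so builds without the jiterator behave as before; the jiterated path instead compiles the complex kernel at first use, the jiterator's usual trade of a one-time runtime cost for a smaller binary. As math:

    % With y = \sigma(x), the derivative of sigmoid in terms of its output:
    \frac{\partial y}{\partial x} = \sigma(x)\,(1 - \sigma(x)) = y\,(1 - y)
    % Real branch:    grad_in = grad_out * (1 - y) * y
    % Complex branch (conjugate Wirtinger convention, hence the conj):
    \mathrm{grad}_{\mathrm{in}} = \mathrm{grad}_{\mathrm{out}} \cdot \overline{(1 - y)\,y}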

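A minimal smoke test of the new complex path might look like the sketch below. It is not part of the commit, assumes a CUDA build, and passes regardless of whether AT_USE_JITERATOR() selected the jitted kernel or the fallback, since both compute the same expression.

    import torch

    # The complex branch of sigmoid_backward_kernel_cuda is only taken for
    # complex dtypes on CUDA.
    x = torch.randn(8, dtype=torch.complex64, device="cuda", requires_grad=True)
    y = torch.sigmoid(x)

    # Backward through sigmoid calls sigmoid_backward with a = grad_output
    # and b = y (the forward output).
    grad_out = torch.randn_like(y)
    y.backward(grad_out)

    # Expected gradient per the kernel body: a * conj((1 - b) * b).
    expected = grad_out * ((1 - y) * y).conj()
    torch.testing.assert_close(x.grad, expected.detach())

For a stricter numerical check, torch.autograd.gradcheck with a complex128 input would exercise the same kernel path at double precision.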