|
8 | 8 | #include <ATen/native/DispatchStub.h>
|
9 | 9 | #include <ATen/native/TensorIterator.h>
|
10 | 10 | #include <ATen/native/cuda/Loops.cuh>
|
| 11 | +#include <ATen/native/cuda/JitLoops.cuh> |
11 | 12 |
|
12 | 13 | // NOTE: CUDA on Windows requires that the enclosing function
|
13 | 14 | // of a __device__ lambda not have internal linkage.
|
14 | 15 |
|
15 | 16 | namespace at {
|
16 | 17 | namespace native {
|
17 | 18 |
|
| 19 | +const char sigmoid_backward_name[] = "sigmoid_backward"; |
18 | 20 | void sigmoid_backward_kernel_cuda(TensorIteratorBase& iter) {
|
19 |
| - if(isComplexType(iter.dtype())) { |
20 |
| - AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "sigmoid_backward_cuda", [&]() { |
| 21 | + auto dtype = iter.dtype(); |
| 22 | + if(isComplexType(dtype)) { |
| 23 | +#if AT_USE_JITERATOR() |
| 24 | + static const auto sigmoid_backward_string = jiterator_stringify( |
| 25 | + template <typename T> |
| 26 | + T sigmoid_backward(T a, T b) { |
| 27 | + return a * std::conj((T{1.} - b) * b); |
| 28 | + } |
| 29 | + ); // sigmoid_backward_string |
| 30 | + AT_DISPATCH_COMPLEX_TYPES(dtype, "sigmoid_backward_cuda", [&]() { |
| 31 | + jitted_gpu_kernel< |
| 32 | + /*name=*/ sigmoid_backward_name, |
| 33 | + /*return_dtype=*/ scalar_t, |
| 34 | + /*common_dtype=*/ scalar_t, |
| 35 | + /*arity=*/ 2>(iter, sigmoid_backward_string); |
| 36 | + }); |
| 37 | +#else |
| 38 | + AT_DISPATCH_COMPLEX_TYPES(dtype, "sigmoid_backward_cuda", [&]() { |
21 | 39 | gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
|
22 | 40 | return a * std::conj((scalar_t{1.} - b) * b);
|
23 | 41 | });
|
24 | 42 | });
|
| 43 | +#endif |
25 | 44 | } else {
|
26 |
| - AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "sigmoid_backward_cuda", [&]() { |
| 45 | + AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, dtype, "sigmoid_backward_cuda", [&]() { |
27 | 46 | gpu_kernel(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
|
28 | 47 | return a * (scalar_t(1.) - b) * b;
|
29 | 48 | });
|
|
0 commit comments