Commit 4b311a9

kshitij12345 authored and pytorchmergebot committed
[complex32] conj
Reference: pytorch#74537

Required for `complex32` support in the `fft` module.

Tested with `test_dtypes` and `test_complex_half_reference_testing` in test_ops.py.

Pull Request resolved: pytorch#76132
Approved by: https://github.com/anjali411
1 parent 25d5b63 commit 4b311a9

File tree: 5 files changed (+56, -10 lines)
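
For context, a minimal sketch of the behavior this commit enables (assumes a build that includes this change; `at::conj_physical` and the lazy `conj` view are existing ATen APIs):

```cpp
#include <ATen/ATen.h>

int main() {
  // A complex32 (ComplexHalf) tensor on CPU. conj() only flips the lazy
  // conjugate bit; conj_physical() eagerly negates the imaginary parts,
  // which dispatches into the conj_kernel extended below.
  auto t = at::empty({4}, at::dtype(at::kComplexHalf));
  auto lazy = t.conj();               // view; no kernel runs
  auto eager = at::conj_physical(t);  // runs "conj_cpu" for kComplexHalf
  return 0;
}
```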

aten/src/ATen/cpu/vec/vec_base.h (+1, -1)

```diff
@@ -133,7 +133,7 @@ struct Vectorized {
   static constexpr size_type size() {
     return VECTOR_WIDTH / sizeof(T);
   }
-  Vectorized() : values{0} {}
+  Vectorized() : values{static_cast<T>(0)} {}
   Vectorized(T val) {
     for (int i = 0; i != size(); i++) {
       values[i] = val;
```
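
Why this one-liner: with `T = c10::complex<at::Half>` now flowing through `Vectorized<T>`, brace-initializing the `values` array with a bare `0` presumably no longer has a usable int-to-`T` conversion, while `static_cast<T>(0)` requests the conversion explicitly. A contrived, self-contained sketch of that failure mode, using a hypothetical type `H` as a stand-in for `c10::complex<at::Half>`:

```cpp
// H models a type constructible from a number only explicitly
// (hypothetical, simplified; not the real c10::complex<at::Half>).
struct H {
  H() : v(0.0f) {}
  explicit H(float f) : v(f) {}
  float v;
};

template <typename T>
struct Vec {
  T values[4];
  // Vec() : values{0} {}               // fails for T = H: copy-list-init
  //                                     // cannot call an explicit ctor
  Vec() : values{static_cast<T>(0)} {}  // explicit cast works for both
};

int main() {
  Vec<float> a;  // fine before and after the change
  Vec<H> b;      // only compiles with the static_cast form
  (void)a;
  (void)b;
  return 0;
}
```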

aten/src/ATen/native/cpu/UnaryOpsKernel.cpp (+2, -2)

```diff
@@ -211,8 +211,8 @@ static void imag_kernel(TensorIteratorBase& iter) {
 
 // NB: Ignores the negative bit on tensors
 void conj_kernel(TensorIteratorBase& iter) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
-      kBool, kBFloat16, kHalf, iter.common_dtype(), "conj_cpu", [&]() {
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
+      kBool, kBFloat16, kHalf, kComplexHalf, iter.common_dtype(), "conj_cpu", [&]() {
     cpu_kernel_vec(
         iter,
         [=](scalar_t a) -> scalar_t { return conj_impl(a); },
```
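
The only change here is swapping `AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3` for the `_AND4` variant with `kComplexHalf` in the list, so the CPU kernel body is now also instantiated with `scalar_t = c10::complex<at::Half>`. Conceptually, the dispatch macros (defined in ATen/Dispatch.h) switch on the runtime dtype and invoke the body with `scalar_t` bound to the matching C++ type; a simplified, hypothetical model of that mechanism, with `std::complex` standing in for `c10::complex`:

```cpp
#include <complex>
#include <stdexcept>

// Hypothetical two-type stand-in for at::ScalarType.
enum class ScalarType { Float, ComplexHalf };

template <typename F>
void dispatch_sketch(ScalarType t, F&& body) {
  switch (t) {
    case ScalarType::Float:
      body(float{});                // body sees scalar_t = float
      break;
    case ScalarType::ComplexHalf:   // the case the _AND4 variant adds
      body(std::complex<float>{});
      break;
    default:
      throw std::runtime_error("\"conj_cpu\" not implemented for this dtype");
  }
}

int main() {
  dispatch_sketch(ScalarType::ComplexHalf, [](auto v) {
    using scalar_t = decltype(v);   // like the lambda in conj_kernel
    scalar_t a{};                   // e.g. conj_impl(a) in the real kernel
    (void)a;
  });
  return 0;
}
```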

aten/src/ATen/native/cpu/zmath.h (+5)

```diff
@@ -123,6 +123,11 @@ inline TYPE conj_impl (TYPE z) {
   return z; //No-Op
 }
 
+template<>
+inline c10::complex<at::Half> conj_impl <c10::complex<at::Half>> (c10::complex<at::Half> z) {
+  return c10::complex<at::Half>{z.real(), -z.imag()};
+}
+
 template<>
 inline c10::complex<float> conj_impl <c10::complex<float>> (c10::complex<float> z) {
   return c10::complex<float>(z.real(), -z.imag());
```
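
The new block mirrors the existing `c10::complex<float>` (and `c10::complex<double>`) specializations: the primary `conj_impl` template is a no-op for real types, and each complex element type gets a specialization that negates the imaginary part. A self-contained sketch of the same pattern, with `std::complex` in place of `c10::complex`:

```cpp
#include <cassert>
#include <complex>

// Primary template: conjugation of a real number is a no-op.
template <typename TYPE>
inline TYPE conj_impl(TYPE z) {
  return z;
}

// Per-type specialization, as this commit adds for c10::complex<at::Half>.
template <>
inline std::complex<float> conj_impl<std::complex<float>>(std::complex<float> z) {
  return {z.real(), -z.imag()};
}

int main() {
  assert(conj_impl(3.0f) == 3.0f);
  assert(conj_impl(std::complex<float>(1.0f, 2.0f)) == std::complex<float>(1.0f, -2.0f));
  return 0;
}
```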

aten/src/ATen/native/cuda/UnaryComplexKernels.cu (+22, -2)

```diff
@@ -2,6 +2,7 @@
 #include <limits>
 #include <ATen/native/UnaryOps.h>
 #include <ATen/native/cuda/Loops.cuh>
+#include <ATen/native/cuda/JitLoops.cuh>
 #include <ATen/Dispatch.h>
 #include <ATen/NumericUtils.h>
 #include <ATen/native/DispatchStub.h>
@@ -81,13 +82,32 @@ __host__ __device__ static inline c10::complex<T> conj_wrapper(c10::complex<T> v
 }
 
 // NB: Ignores the negative bit on tensors
+const char conj_name[] = "conj_kernel";
 void conj_kernel_cuda(TensorIteratorBase& iter) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
+  auto common_dtype = iter.common_dtype();
+  if (common_dtype == kComplexHalf) {
+    using scalar_t = c10::complex<at::Half>;
+#if AT_USE_JITERATOR()
+    static const auto conj_string = jiterator_stringify(
+        template <typename T>
+        T conj_kernel(T z) {
+          return std::conj(z);
+        }
+    );
+    jitted_gpu_kernel<conj_name, scalar_t, scalar_t, 1>(iter, conj_string);
+#else
+    gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
+      return conj_wrapper(a);
+    });
+#endif
+  } else {
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
       kBool, kBFloat16, kHalf, iter.common_dtype(), "conj_cuda", [&]() {
         gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
           return conj_wrapper(a);
         });
-      });
+      });
+  }
 }
 
 REGISTER_DISPATCH(angle_stub, &angle_kernel_cuda);
```
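
The CUDA kernel takes one of two paths for the new dtype: when the jiterator is enabled, the kernel body is carried as a string (`jiterator_stringify`) and JIT-compiled on first use via `jitted_gpu_kernel`, avoiding another ahead-of-time `gpu_kernel` instantiation for `complex32`; otherwise it falls back to the same `conj_wrapper` lambda used for the other dtypes. A quick way to exercise this path end to end (a sketch; assumes a CUDA build containing this commit):

```cpp
#include <ATen/ATen.h>

int main() {
  // complex32 tensor on the GPU; conj_physical dispatches into
  // conj_kernel_cuda, which now takes the jiterator (or fallback)
  // branch for kComplexHalf.
  auto t = at::empty({8}, at::device(at::kCUDA).dtype(at::kComplexHalf));
  auto out = at::conj_physical(t);
  (void)out;
  return 0;
}
```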

torch/testing/_internal/common_methods_invocations.py (+26, -5)

```diff
@@ -9418,16 +9418,21 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
            ),
     UnaryUfuncInfo('conj',
                    ref=np.conj,
-                   dtypes=all_types_and_complex_and(torch.bool,
-                                                    torch.bfloat16, torch.half),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16,
+                                                    torch.half, torch.chalf),
                    supports_sparse=True,
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
-                   supports_out=False),
+                   supports_out=False,
+                   skips=(
+                       # numpy() raises TypeError: Got unsupported ScalarType ComplexHalf
+                       DecorateInfo(unittest.expectedFailure, "TestUnaryUfuncs", "test_reference_numerics_normal",
+                                    dtypes=(torch.complex32,)),
+                   )),
     UnaryUfuncInfo('conj_physical',
                    ref=np.conj,
-                   dtypes=all_types_and_complex_and(torch.bool,
-                                                    torch.bfloat16, torch.half),
+                   dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16,
+                                                    torch.half, torch.chalf),
                    supports_forward_ad=True,
                    supports_fwgrad_bwgrad=True,
                    supports_sparse=True,
@@ -9439,6 +9444,22 @@ def generate_std_var_kwargs(t: torch.Tensor, **kwargs):
                        DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32, )),
                        DecorateInfo(unittest.skip("Skipped! conj_physical_ not implemented for sparse"),
                                     'TestSparseUnaryUfuncs', 'test_inplace'),
+                       # numpy() raises TypeError: Got unsupported ScalarType ComplexHalf
+                       DecorateInfo(unittest.expectedFailure, "TestUnaryUfuncs", "test_reference_numerics_normal",
+                                    dtypes=(torch.complex32,)),
+                       # RuntimeError: "nonzero_count_cpu" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, "TestSparseCSR", "test_sparse_csr_consistency",
+                                    dtypes=(torch.complex32,)),
+                       # RuntimeError: "nonzero_count_cpu" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, "TestSparseCSR", "test_sparse_csr_unary_inplace",
+                                    dtypes=(torch.complex32,)),
+                       # RuntimeError: "nonzero_count_cpu" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, "TestSparseCSR", "test_sparse_csr_unary_out",
+                                    dtypes=(torch.complex32,)),
+                       # RuntimeError: "add_out_op2_sparse_csr" not implemented for 'ComplexHalf'
+                       DecorateInfo(unittest.expectedFailure, "TestSparseCSR",
+                                    "test_zero_to_zero_correspondence_unary",
+                                    dtypes=(torch.complex32,)),
                    )),
     OpInfo('resolve_conj',
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
```
