Commit 3ca1a25

malfet authored and pytorchmergebot committed
[BE][MPS] Use copysign for imaginary part of sqrt (pytorch#148286)
Also, it is tempting to replace `a*a + b*b` with `dot(input[index])`, but for some reason that results in slightly different output.

Pull Request resolved: pytorch#148286
Approved by: https://github.com/dcci
ghstack dependencies: pytorch#148285
1 parent 84502ba commit 3ca1a25

1 file changed: +5 -11 lines changed

aten/src/ATen/native/mps/kernels/UnaryKernel.metal

@@ -56,17 +56,11 @@ kernel void sqrt_complex_kernel(
   T0 b = input[index].y;
 
   // modulus
-  T0 r = T0(precise::sqrt(a * a + b * b));
-  // real part: sqrt((r + a)/2)
-  T0 real_part = T0(precise::sqrt((r + a) / 2.0));
-
-  // imaginary part: sign(b) * sqrt((r - a)/2)
-  T0 imag_part;
-  if (b >= 0) {
-    imag_part = T0(precise::sqrt((r - a) / 2.0));
-  } else {
-    imag_part = T0(-precise::sqrt((r - a) / 2.0));
-  }
+  auto m = precise::sqrt(a * a + b * b);
+  // real part: sqrt((m + a)/2)
+  auto real_part = precise::sqrt((m + a) * .5);
+  // imaginary part: sign(b) * sqrt((m - a)/2)
+  auto imag_part = copysign(static_cast<T0>(precise::sqrt((m - a) * .5)), b);
   output[index] = vec2type_t<T0>(real_part, imag_part);
 }
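
For readers who want to try the two formulations off-device, here is a minimal host-side C++ sketch of the change above. It is not the Metal kernel: `sqrt_complex_branch` and `sqrt_complex_copysign` are hypothetical names introduced for illustration, and `std::sqrt` / `std::copysign` stand in for Metal's `precise::sqrt` and `copysign`. The sketch assumes default floating-point settings (no fast-math style flags).

```cpp
#include <cmath>
#include <utility>

// Hypothetical host-side stand-in for the removed lines:
// the sign of the imaginary part is chosen with a branch on b.
template <typename T>
std::pair<T, T> sqrt_complex_branch(T a, T b) {
  T r = std::sqrt(a * a + b * b);  // modulus
  T real_part = std::sqrt((r + a) / T(2));
  T imag_part = (b >= 0) ? std::sqrt((r - a) / T(2))
                         : -std::sqrt((r - a) / T(2));
  return {real_part, imag_part};
}

// Hypothetical host-side stand-in for the added lines:
// copysign transfers the sign bit of b onto the magnitude.
template <typename T>
std::pair<T, T> sqrt_complex_copysign(T a, T b) {
  T m = std::sqrt(a * a + b * b);  // modulus
  T real_part = std::sqrt((m + a) * T(0.5));
  T imag_part = std::copysign(std::sqrt((m - a) * T(0.5)), b);
  return {real_part, imag_part};
}
```

In this sketch the two versions agree for ordinary inputs, for example `(3, 4)` gives `(2, 1)` and `(3, -4)` gives `(2, -1)` in both. They can differ when `b` is negative zero: `b >= 0` is true for `-0.0`, so the branch picks the positive root, while `copysign` reads the sign bit and returns the negative one. Whether that edge case matters for the MPS kernel is not stated in the commit.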
