Skip to content

Commit 84502ba

Browse files
malfetpytorchmergebot
authored andcommitted
[MPS] Fix sqrt and other for torch.chalf (pytorch#148285)
Those kernels, instead of being instantiated for half2 (which corresponds to ComplexHalf) were instnatiated for short2, which resuled in the following test ``` % python3 -c "import torch; print(torch.rand(6, device='mps', dtype=torch.chalf).sqrt())" ``` Fail with ``` RuntimeError: Failed to create function state object for: sqrt_complex_half_half ``` As sqrt is not implemented for CPU, add explicit test to `test_sqrt` Pull Request resolved: pytorch#148285 Approved by: https://github.com/dcci
1 parent d57f617 commit 84502ba

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

aten/src/ATen/native/mps/kernels/UnaryKernel.metal

+1-1
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ INSTANTIATE_UNARY_KERNELS2(float, long);
146146
constant vec2type_t<DTYPE0> * input [[buffer(1)]], \
147147
uint did [[thread_position_in_grid]]);
148148

149-
INSTANTIATE_UNARY_KERNELS_VEC2(short, short);
149+
INSTANTIATE_UNARY_KERNELS_VEC2(half, half);
150150
INSTANTIATE_UNARY_KERNELS_VEC2(float, float);
151151

152152
template <typename T0, typename T1>

test/test_mps.py

+6
Original file line numberDiff line numberDiff line change
@@ -6854,6 +6854,12 @@ def helper(shape):
68546854

68556855
helper((2, 8, 4, 5))
68566856

6857+
# Test complex half
6858+
x = torch.rand(8, device='mps', dtype=torch.chalf)
6859+
rc_h = x.sqrt()
6860+
rc_f = x.cfloat().sqrt().chalf()
6861+
self.assertEqual(rc_h, rc_f)
6862+
68576863
# Test selu, elu, celu
68586864
def test_elu(self):
68596865
def helper(shape, alpha=1.0, memory_format=torch.contiguous_format):

0 commit comments

Comments
 (0)