Commit 84502ba
[MPS] Fix sqrt and other for
Those kernels, instead of being instantiated for half2 (which corresponds to ComplexHalf) were instnatiated for short2, which resuled in the following test
```
% python3 -c "import torch; print(torch.rand(6, device='mps', dtype=torch.chalf).sqrt())"
```
Fail with
```
RuntimeError: Failed to create function state object for: sqrt_complex_half_half
```
As sqrt is not implemented for CPU, add explicit test to `test_sqrt`
Pull Request resolved: pytorch#148285
Approved by: https://github.com/dccitorch.chalf (pytorch#148285)1 parent d57f617 commit 84502ba
File tree
2 files changed
+7
-1
lines changed- aten/src/ATen/native/mps/kernels
- test
2 files changed
+7
-1
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
146 | 146 | | |
147 | 147 | | |
148 | 148 | | |
149 | | - | |
| 149 | + | |
150 | 150 | | |
151 | 151 | | |
152 | 152 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
6854 | 6854 | | |
6855 | 6855 | | |
6856 | 6856 | | |
| 6857 | + | |
| 6858 | + | |
| 6859 | + | |
| 6860 | + | |
| 6861 | + | |
| 6862 | + | |
6857 | 6863 | | |
6858 | 6864 | | |
6859 | 6865 | | |
| |||
0 commit comments