Skip to content

Commit c7f12a4

Browse files
malfetpytorchmergebot
authored and committed
[MPSInductor] Speed up maximum/minimum ops (pytorch#144581)
This speeds up both ops by relying on the fact that if either `a` or `b` is NaN (or both), then `a + b` is also NaN. (Note: the converse does not hold for every input — `a + b` is also NaN when `a` and `b` are infinities of opposite sign — so that case should be verified separately.)

I.e. it replaces

```metal
auto tmp2 = metal::any(metal::isnan(static_cast<decltype(tmp0+tmp1)>(tmp0))) | metal::any(metal::isnan(static_cast<decltype(tmp0+tmp1)>(tmp1))) ? static_cast<decltype(tmp0+tmp1)>(NAN) : metal::max(static_cast<decltype(tmp0+tmp1)>(tmp0), static_cast<decltype(tmp0+tmp1)>(tmp1));
```

with

```metal
auto tmp2 = metal::isnan(tmp0 + tmp1) ? tmp0 + tmp1 : metal::max(static_cast<decltype(tmp0+tmp1)>(tmp0), static_cast<decltype(tmp0+tmp1)>(tmp1));
```

which, according to MetalProfiler, takes fewer instructions:

<img width="520" alt="image" src="https://github.com/user-attachments/assets/54659392-012b-453e-9c02-c3c5f332074a" />

vs

<img width="1031" alt="image" src="https://github.com/user-attachments/assets/55fcfa78-1ea5-4b0a-8154-d79b3e3cc400" />

Pull Request resolved: pytorch#144581
Approved by: https://github.com/dcci, https://github.com/jhavukainen
1 parent a94ec0a commit c7f12a4

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

torch/_inductor/codegen/mps.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -114,19 +114,15 @@ def remainder(a: CSEVariable, b: CSEVariable) -> str:
114114
def maximum(a: CSEVariable, b: CSEVariable) -> str:
115115
typecast_a = f"static_cast<decltype({a}+{b})>({a})"
116116
typecast_b = f"static_cast<decltype({a}+{b})>({b})"
117-
nan_value = f"static_cast<decltype({a}+{b})>(NAN)"
118-
nan_check = f"metal::any(metal::isnan({typecast_a})) | metal::any(metal::isnan({typecast_b}))"
119117
max_res = f"metal::max({typecast_a}, {typecast_b})"
120-
return f"{nan_check} ? {nan_value} : {max_res}"
118+
return f"metal::isnan({a} + {b}) ? {a} + {b} : {max_res}"
121119

122120
@staticmethod
123121
def minimum(a: CSEVariable, b: CSEVariable) -> str:
124122
typecast_a = f"static_cast<decltype({a}+{b})>({a})"
125123
typecast_b = f"static_cast<decltype({a}+{b})>({b})"
126-
nan_value = f"static_cast<decltype({a}+{b})>(NAN)"
127-
nan_check = f"metal::any(metal::isnan({typecast_a})) | metal::any(metal::isnan({typecast_b}))"
128124
min_res = f"metal::min({typecast_a}, {typecast_b})"
129-
return f"{nan_check} ? {nan_value} : {min_res}"
125+
return f"metal::isnan({a} + {b}) ? {a} + {b} : {min_res}"
130126

131127
@staticmethod
132128
def logical_or(a: CSEVariable, b: CSEVariable) -> str:

0 commit comments

Comments
 (0)