Skip to content

Commit ceb664a

Browse files
bobrenjc93pytorchmergebot
authored andcommitted
add float_args benchmark (pytorch#143143)
71% improvement with automatic dynamic float arguments with specialize_float=False ``` float_args,compile_time_instruction_count,346293869 ``` with specialize_float=True ``` float_args,compile_time_instruction_count,1198546486 ``` Pull Request resolved: pytorch#143143 Approved by: https://github.com/laithsakka ghstack dependencies: pytorch#141517
1 parent ab04f3a commit ceb664a

File tree

1 file changed

+47
-0
lines changed

1 file changed

+47
-0
lines changed
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import sys
2+
3+
from benchmark_base import BenchmarkBase
4+
5+
import torch
6+
from torch._inductor.utils import fresh_inductor_cache
7+
8+
9+
class Benchmark(BenchmarkBase):
10+
def __init__(self):
11+
super().__init__(
12+
category="float_args",
13+
backend="inductor",
14+
device="cpu",
15+
)
16+
17+
def name(self):
18+
return f"{self.category()}"
19+
20+
def description(self):
21+
return "Benchmark to measure recompilations with float arguments."
22+
23+
def _prepare_once(self):
24+
torch.manual_seed(0)
25+
26+
def _prepare(self):
27+
torch._dynamo.reset()
28+
29+
def _work(self):
30+
@torch.compile(backend="inductor")
31+
def f(x, y):
32+
return x + y
33+
34+
with fresh_inductor_cache():
35+
for i in range(8):
36+
f(torch.arange(3), i * 2.5)
37+
38+
39+
def main():
40+
result_path = sys.argv[1]
41+
Benchmark().enable_compile_time_instruction_count().collect_all().append_results(
42+
result_path
43+
)
44+
45+
46+
if __name__ == "__main__":
47+
main()

0 commit comments

Comments
 (0)