Skip to content

Commit bf27dd6

Browse files
Skylion007pytorchmergebot
authored andcommitted
Add dynamo support for operator.abs (pytorch#117442)
A test case for operator.abs and allows for constant folding with it. Partially applies to pytorch#116396 Pull Request resolved: pytorch#117442 Approved by: https://github.com/jansel, https://github.com/malfet
1 parent 1a790f5 commit bf27dd6

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

test/dynamo/test_functions.py

+11
Original file line numberDiff line numberDiff line change
@@ -2136,6 +2136,17 @@ def fn(x):
21362136
x = torch.randn(10)
21372137
self.assertEqual(opt_fn(x), fn(x))
21382138

2139+
def test_unary_fold_op(self):
2140+
for op in (operator.abs, abs, operator.pos, operator.neg):
2141+
with self.subTest(op=op):
2142+
2143+
def fn():
2144+
a = range(-10, 10)
2145+
return list(map(op, a))
2146+
2147+
opt_fn = torch._dynamo.optimize(nopython=True)(fn)
2148+
self.assertEqual(opt_fn(), fn())
2149+
21392150

21402151
instantiate_parametrized_tests(FunctionTests)
21412152

torch/_dynamo/variables/builtin.py

+2
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def _constant_fold_functions():
112112
str.format,
113113
sum,
114114
type,
115+
operator.abs,
115116
operator.pos,
116117
operator.neg,
117118
operator.not_,
@@ -155,6 +156,7 @@ def can_constant_fold_through(self):
155156
@functools.lru_cache(None)
156157
def _fx_graph_functions():
157158
fns = {
159+
operator.abs,
158160
operator.pos,
159161
operator.neg,
160162
operator.not_,

0 commit comments

Comments
 (0)