forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_proxy_tensor.py
2156 lines (1783 loc) · 81.5 KB
/
test_proxy_tensor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Owner(s): ["module: ProxyTensor"]
# ruff: noqa: F841
from torch.testing._internal.common_utils import TestCase, run_tests
import torch
import torch._dynamo
import unittest
import warnings
import operator
from collections.abc import Iterable
from torch.nn.utils import stateless
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_methods_invocations import op_db, skip, xfail, skipOps
from torch._subclasses.fake_tensor import DynamicOutputShapeException, DataDependentOutputException, FakeTensorMode
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
from torch._decomp import decomposition_table
from torch.fx.experimental.symbolic_shapes import (
eval_guards, bind_symbols, fx_placeholder_vals, fx_placeholder_targets,
guard_int, GuardOnDataDependentSymNode
)
from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.hop_db import hop_db
from torch.testing._internal.common_device_type import ops
import torch.testing._internal.optests as optests
from torch._C import _disabled_torch_function_impl
from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule
from torch.utils._pytree import tree_map
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
from torch import nn
import torch._functorch.config
import re
import functools
import itertools
# Shorthand for the ATen operator namespace used throughout these tests.
aten = torch.ops.aten
# True when a CUDA device is available; used to skip CUDA-only tests.
HAS_CUDA = torch.cuda.is_available()
def strip_end(s, suffix):
    """Return *s* without a trailing *suffix*.

    *s* is returned unchanged when *suffix* is empty/falsy or is not a suffix of *s*.
    """
    if not suffix or not s.endswith(suffix):
        return s
    keep = len(s) - len(suffix)
    return s[:keep]
def show_guards(gm):
    """Return the simplified shape guards of traced module *gm*, one per line.

    Guard names are the graph's placeholder targets with any trailing ``"_1"``
    removed. Assumes *gm* carries a ``shape_env`` (e.g. from symbolic make_fx)
    — TODO confirm for other tracing modes.
    """
    placeholder_names = []
    for target in fx_placeholder_targets(gm):
        placeholder_names.append(strip_end(target, "_1"))
    guards = gm.shape_env.produce_guards(
        fx_placeholder_vals(gm),
        placeholder_names,
        _simplified=True,
        input_contexts=None,
    )
    return "\n".join(guards)
def process_failures():
    """Print an OpInfo ``xfail`` set built from the failures listed in ``pytest_failures``.

    Reads the file ``pytest_failures`` (relative to the current working
    directory), where each line looks like::

        FAILED test/test_proxy_tensor.py::TestProxyTensorOpInfoCPU::test_make_fx_symbolic_exhaustive___getitem___cpu_float32 - RuntimeError: aten.size.default - couldn't find symbolic meta function/decomposition  # noqa: B950

    and prints a ``symbolic_tensor_failures = {...}`` literal containing one
    ``xfail(op_name, variant_test_name)`` entry per failing test, each followed
    by the failure reason as a comment. Blank lines in the file are ignored.

    Raises:
        AttributeError: if a non-blank line does not match the expected format.
        KeyError: if a failing test name does not correspond to any OpInfo in ``op_db``.
    """
    # Context manager so the file handle is closed as soon as it has been read.
    with open('pytest_failures') as f:
        failures = [line.strip() for line in f]
    # Blank lines (e.g. a trailing newline) would not match the regex below.
    failures = [line for line in failures if line]

    def process_failure_string(s, matcher):
        out = re.search(matcher, s)
        return out.groups()

    # Captures (normalized op name, failure reason) from each failure line.
    SYMBOLIC_TRACE_MATCH = r'exhaustive_(.*)_cpu.*: (.*)'
    failures = [process_failure_string(s, SYMBOLIC_TRACE_MATCH) for s in failures]

    def create_normalized_name(op):
        # "name" or "name.variant" with dots replaced by underscores; presumably
        # the same form the regex above captures from test names.
        if op.variant_test_name == '':
            s = op.name
        else:
            s = f"{op.name}.{op.variant_test_name}"
        return s.replace('.', '_')

    remap_opinfo = {create_normalized_name(op): (op.name, op.variant_test_name) for op in op_db}
    print("symbolic_tensor_failures = {")
    for failure, reason in failures:
        # Indented entry with an inline comment so the output can be pasted into a set literal.
        print(f"    xfail{remap_opinfo[failure]},  # {reason}")
    print("}")
# torchvision is optional: tests that need it are skipped when it is missing.
try:
    import torchvision
except ImportError:
    USE_TORCHVISION = False
    warnings.warn("Couldn't import torchvision. Some of our tests use it, try "
                  "to install it with commands from pytorch.org, post-fixed with "
                  "`--no-deps` to avoid overwriting the pytorch installation",
                  UserWarning)
else:
    USE_TORCHVISION = True
def _create_new_input(x):
if not isinstance(x, torch.Tensor):
return x
if x.dtype != torch.float:
return x + 1
if x.is_leaf:
return torch.rand_like(x, requires_grad=x.requires_grad)
else:
return torch.rand_like(x)
"""
Delays a cos being executed on the unwraptensor until its used. Simulates a CommTensor used
"""
class UnwrapTensor(torch.Tensor):
    """Wrapper tensor subclass that defers a ``cos`` on its inner tensor until use.

    The wrapper mirrors the inner tensor's size, dtype, device, layout and
    ``requires_grad``. Whenever an op is dispatched with an ``UnwrapTensor``
    argument, that argument is replaced by ``inner.cos()`` and the op runs on
    the resulting plain tensors. Simulates how a CommTensor is used.
    """

    @staticmethod
    def __new__(cls, tensor: torch.Tensor):
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            tensor.size(),
            dtype=tensor.dtype,
            device=tensor.device,
            layout=tensor.layout,
            requires_grad=tensor.requires_grad,
        )
        r._tensor = tensor
        return r

    def __repr__(self):
        # TODO: consider all_gather the local tensors for better debugging
        return f"UnwrapTensor({self._tensor})"

    __torch_function__ = _disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(e):
            # Only UnwrapTensor leaves are rewritten; everything else passes through.
            if isinstance(e, UnwrapTensor):
                return e._tensor.cos()
            return e

        args = tree_map(unwrap, args)
        # kwargs defaults to None; normalize so ``**kwargs`` below is always a mapping.
        kwargs = tree_map(unwrap, kwargs or {})
        return func(*args, **kwargs)
class TestGenericProxyTensor(TestCase):
    """make_fx tests shared across tracing modes.

    This class does not define ``tracing_mode`` itself; the subclasses defined
    right after it in this file each set ``tracing_mode``, which the tests read
    via ``self.tracing_mode``.
    """

    # WARNING: if any of your inputs are index tensors, DO NOT use this
    # function (_create_new_input shifts non-float tensors by +1, which can
    # presumably push index values out of range).
    def _test(self, f, inps):
        # Trace f on inps, then check the traced graph and eager f agree on
        # fresh inputs produced by _create_new_input.
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(*inps)
        new_inps = tree_map(_create_new_input, inps)
        r1 = fx_f(*new_inps)
        r2 = f(*new_inps)
        self.assertEqual(r1, r2)

    def test_pre_dispatch_mode_stack(self):
        def f(a):
            b = torch.ones(4, 4)
            return torch.matmul(a, b)
        # We expect to see matmul in the trace - it should NOT be decomposed into mm.
        # Also, torch.ones() doesn't show up in the trace.
        # This is annoying but expected: ones() never dispatches to the Autograd dispatch key,
        # so our mode never sees it - it goes directly to the BackendSelect key.
        inp = torch.ones(4, 4)
        # Test that make_fx(pre_dispatch=True) clears caches properly.
        from torch._dispatch.python import enable_python_dispatcher
        with enable_python_dispatcher():
            out1 = f(inp)
        fx_g = make_fx(f, pre_dispatch=True)(inp)
        self.assertExpectedInline(fx_g.code.strip(), """\
def forward(self, a_1):
    ones = torch.ops.aten.ones.default([4, 4], device = device(type='cpu'), pin_memory = False)
    matmul = torch.ops.aten.matmul.default(a_1, ones);  a_1 = ones = None
    return matmul""")

    def test_pre_dispatch_linear(self):
        # Pre-dispatch trace of F.linear must match eager results.
        def f(a, b, c):
            return torch.nn.functional.linear(a, b, c)
        a = torch.ones(4, 4)
        b = torch.ones(4, 4)
        c = torch.ones(4)
        fx_g = make_fx(f, pre_dispatch=True)(a, b, c)
        out1 = f(a, b, c)
        out2 = fx_g(a, b, c)
        self.assertEqual(out1, out2)

    def test_pre_dispatch_no_grad(self):
        # Grad-mode toggles inside f must be respected by the pre-dispatch trace:
        # forward outputs and input gradients must match eager.
        def f(a):
            b = a.sin()
            torch.set_grad_enabled(False)
            c = b.cos()
            torch.set_grad_enabled(True)
            return b + c.sin()
        a1 = torch.randn(4, requires_grad=True)
        a2 = a1.detach().clone().requires_grad_(True)
        a_tmp = a1.detach().clone().requires_grad_(True)
        fx_g = make_fx(f, pre_dispatch=True)(a_tmp)
        out1 = f(a1)
        out2 = fx_g(a2)
        self.assertEqual(out1, out2)
        out1.sum().backward()
        out2.sum().backward()
        self.assertEqual(a1.grad, a2.grad)

    def test_make_fx_simple(self):
        def f(x):
            return torch.sin(x)
        self._test(f, (torch.randn(3),))

    def test_scalar_device(self, device='cpu'):
        # Mixes a 1-D tensor with a 0-dim tensor created on the default device.
        def f(a, b):
            return a + b
        self._test(f, [torch.randn(3, device=device), torch.tensor(5)])

    def test_isolated_graphmodule(self):
        # Helpers: does the graph contain a call to the given aten op?
        def is_any_sum(gm):
            return any(node.target == torch.ops.aten.sum.default for node in gm.graph.nodes)

        def is_any_digamma(gm):
            return any(node.target == torch.ops.aten.digamma.default for node in gm.graph.nodes)

        def is_any_sigmoid(gm):
            return any(node.target == torch.ops.aten.sigmoid.default for node in gm.graph.nodes)

        def inner(x):
            return torch.sum(x)

        def f(x):
            gm = get_isolated_graphmodule(inner, (x,), {})
            self.assertTrue(is_any_sum(gm))
            return x + torch.randn(x.shape)

        # get_isolated_graphmodule uses make_fx internally that shouldn't be traced
        # by the outer make_fx call
        traced = make_fx(f)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))

        # When factory functions are used, they should not be traced
        # by the outer make_fx call
        def inner_with_factory():
            val = torch.tensor(float(1))
            val.add_(2)
            return torch.full((10, 10), val).sum()

        def f1(x):
            gm = get_isolated_graphmodule(inner_with_factory, (), {})
            self.assertTrue(is_any_sum(gm))
            return torch.sigmoid(x)

        def f2(x):
            gm = get_isolated_graphmodule(f1, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify nested make_fx calls don't cause factory functions to leak
        # into the outer graph. Verify that `make_fx` itself does not leak its execution.
        def f2(x):
            gm = make_fx(f1)(x)
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify that the `forward` function of a graph module produced as a
        # side effect of an interior `make_fx` is still traced
        def f3(x):
            gm = make_fx(f1)(x)
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            # `gm.forward` is still traced
            return torch.digamma(gm(x))

        traced = make_fx(f3)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertTrue(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify interaction with non-ProxyTensor modes
        from torch.testing._internal.logging_tensor import LoggingTensorMode

        def f1_logging(x):
            with LoggingTensorMode():
                gm = get_isolated_graphmodule(inner_with_factory, (), {})
            self.assertTrue(is_any_sum(gm))
            return torch.sigmoid(x)

        def f2_logging(x):
            with LoggingTensorMode(), LoggingTensorMode():
                gm = get_isolated_graphmodule(f1_logging, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2_logging)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify interaction with another tensor subclass
        # This case currently doesn't work and should raise an error
        # See: https://github.com/pytorch/pytorch/pull/81764#issuecomment-1200472068
        from torch.testing._internal.logging_tensor import LoggingTensor

        def f1_logging_tensor(x):
            gm = get_isolated_graphmodule(inner_with_factory, (), {})
            self.assertTrue(is_any_sum(gm))
            return torch.sigmoid(x)

        def f2_logging_tensor(x):
            x = LoggingTensor(x)
            gm = get_isolated_graphmodule(f1_logging_tensor, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2_logging_tensor)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))  # this fails, sigmoid is traced with LoggingTensor
        self.assertTrue(is_any_digamma(traced))

    # See https://github.com/pytorch/pytorch/issues/97541
    def test_empty_like_doesnt_burn_in_defaults(self):
        # The traced empty_like call must not record default dtype/layout/device kwargs.
        def f(x):
            return torch.empty_like(x)
        out = make_fx(f)(torch.randn(3))
        self.assertExpectedInline(out.code.strip(), """\
def forward(self, x_1):
    empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False);  x_1 = None
    return empty_like""")

    def test_proxy_tensor_mode_with_decomp_table_preserves_proxy(self):
        def f(x):
            y = x.new_zeros(x.size())
            y.copy_(x)
            return y

        def _new_zeros_decomp(inp, size, dtype=None, layout=None, device=None, pin_memory=None):
            return torch.zeros(size, dtype=inp.dtype, device=inp.device)

        factory_func_decomp = {torch.ops.aten.new_zeros.default: _new_zeros_decomp}

        # When new_zeros() decomposes into torch.zero(), we expect ProxyTensorMode
        # to still be (re-entrantly) enabled, so that the `torch.zero()` call
        # returns a ProxyTensor.
        out = make_fx(f, decomposition_table=factory_func_decomp)(torch.ones(2))
        # out.code is compared unstripped, so the expectation keeps its leading
        # blank lines and trailing indentation.
        self.assertExpectedInline(out.code, """\



def forward(self, x_1):
    zeros = torch.ops.aten.zeros.default([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
    copy_ = torch.ops.aten.copy_.default(zeros, x_1);  zeros = x_1 = None
    return copy_
    """)

    def test_make_fx_reentrant_dispatch(self):
        # A decomposition that itself calls further ops (square/sum/sqrt) must be
        # traced through: neither "norm" nor "square" may remain as node targets.
        def f(x):
            return torch.ops.aten.norm.Scalar(x, 2.0)

        def norm_decomp(x, p=2.0):
            if p != 2.0:
                raise RuntimeError("can't handle with p != 2")
            return torch.sqrt(torch.sum(torch.square(x)))

        decomp = {torch.ops.aten.norm.Scalar: norm_decomp}

        traced = make_fx(f, decomposition_table=decomp, tracing_mode=self.tracing_mode)(torch.rand(3))

        for n in traced.graph.nodes:
            self.assertTrue("square" not in str(n.target))
            self.assertTrue("norm" not in str(n.target))

    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
    def test_resnet18_backward_trace(self):
        mod = torchvision.models.resnet18()

        # An old version of this test called the module directly.  This works
        # for tracing_mode == "real", but for fake tensors, we also have to
        # ensure that the parameters and buffers get wrapped in fake tensors
        # because free fake tensors are not supported.  Fortunately functional_call
        # does precisely this for us.
        def f(x, params, buffers):
            for p in params.values():
                p.grad = None
            loss = torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum()
            # I could have done this with the functional API, but there is
            # plenty of exercising this; I want to show mutating API still
            # works
            loss.backward()
            return [p.grad for p in params.values()]

        inp = torch.randn(3, 3, 250, 250)
        self._test(f, [inp, dict(mod.named_parameters()), dict(mod.named_buffers())])

    def test_varargs(self):
        def f(*args):
            return sum(args)

        self._test(f, [torch.randn(2), torch.randn(2)])

    def test_proxy_tensor(self):
        # Both autograd.grad and .backward() must be traceable.
        def f_grad(x):
            val = x.cos().cos().sum()
            return torch.autograd.grad(val, x)

        def f_backward(x):
            val = x.cos().cos().sum()
            val.backward()
            return x.grad

        for f in [f_grad, f_backward]:
            self._test(f, [torch.randn(3, requires_grad=True)])

    def test_pickle_issue89626(self):
        # A tensor used as a make_fx input must still be picklable afterwards.
        import pickle
        x = torch.randn(2)
        make_fx(lambda x: x * 2, tracing_mode=self.tracing_mode)(x)
        pickle.dumps(x)

    def test_inplace_metadata(self):
        # In-place metadata changes (unsqueeze_) must be visible during tracing.
        def f(x):
            x = x.clone()
            x.unsqueeze_(-1)
            assert x.shape[-1] == 1
            return x

        self._test(f, [torch.randn(5)])

    def test_mode_tracing_factory_function(self):
        def f(x):
            return x + torch.randn(x.shape)

        # default behavior should trace factory functions
        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))
        self.assertTrue(
            any(
                node.target == aten.randn.default
                for node in traced.graph.nodes
            )
        )

    def test_pre_dispatch_functionalization(self):
        # Functionalization under pre-dispatch tracing: the in-place mul_ must
        # appear as a functional mul in the graph.
        def f(x):
            a = FunctionalTensorMode(pre_dispatch=True, export=True)
            with a:
                x_unwrapped = FunctionalTensor.to_functional(x)
                y = torch.matmul(x_unwrapped, x_unwrapped)
                y = y + x_unwrapped
                y.mul_(5)
            y_unwrapped = torch._from_functional_tensor(y.elem)
            return y_unwrapped

        from torch._dispatch.python import enable_python_dispatcher

        with enable_python_dispatcher():
            inp = torch.randn(4, 4)
            gm = make_fx(f, pre_dispatch=True)(inp)

        # TODO actually not decompose
        self.assertExpectedInline(gm.code.strip(), """\
def forward(self, x_1):
    matmul = torch.ops.aten.matmul.default(x_1, x_1)
    add = torch.ops.aten.add.Tensor(matmul, x_1);  matmul = x_1 = None
    mul = torch.ops.aten.mul.Tensor(add, 5);  add = None
    return mul""")

    def test_pre_dispatch_functionalization_view_op(self):
        # Same as above, but with view ops (transpose, view) under functionalization.
        def f(x):
            a = FunctionalTensorMode(pre_dispatch=True, export=True)
            with a:
                x_unwrapped = FunctionalTensor.to_functional(x)
                y = torch.matmul(x_unwrapped, x_unwrapped)
                x_unwrapped = x_unwrapped.transpose(1, 0)
                y = y + x_unwrapped
                y = y.view(2, 8)
            y_unwrapped = torch._from_functional_tensor(y.elem)
            return y_unwrapped

        from torch._dispatch.python import enable_python_dispatcher

        with enable_python_dispatcher():
            inp = torch.randn(4, 4)
            gm = make_fx(f, pre_dispatch=True)(inp)

        # TODO actually not decompose
        self.assertExpectedInline(gm.code.strip(), """\
def forward(self, x_1):
    matmul = torch.ops.aten.matmul.default(x_1, x_1)
    transpose = torch.ops.aten.transpose.int(x_1, 1, 0);  x_1 = None
    add = torch.ops.aten.add.Tensor(matmul, transpose);  matmul = transpose = None
    view = torch.ops.aten.view.default(add, [2, 8]);  add = None
    return view""")

    def test_val_metadata_mutation(self):
        # node.meta['val'] must record the shape at each point, including after unsqueeze_.
        def f(x):
            y = x.clone()
            y.unsqueeze_(0)
            return y

        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3, requires_grad=True))
        self.assertEqual([
            tuple(node.meta['val'].shape)
            for node in traced.graph.nodes
            if 'val' in node.meta
        ], [(3,), (3,), (1, 3)])

    def test_make_fx_overloads(self):
        # Every call_function target must be a specific OpOverload, not a packet.
        def f(x):
            return x.cos() + torch.randn(x.shape)

        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))

        self.assertTrue(all(isinstance(node.target, torch._ops.OpOverload)
                            for node in traced.graph.nodes if node.op == 'call_function'))

    def test_tensor_constants(self):
        def f():
            val = torch.tensor(float('inf'))
            return torch.full((100, 100), val)

        self._test(f, [])

    def test_allclose(self):
        # allclose returns a Python bool, so tracing must raise a data-dependence error.
        def f(a, b):
            return torch.allclose(a, b)

        def test_f():
            make_fx(f, tracing_mode=self.tracing_mode)(
                torch.zeros(3), torch.zeros(3)
            )

        if self.tracing_mode != "real":
            self.assertRaises(DataDependentOutputException, test_f)
        else:
            self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)

    def test_constant_proxy_tensor_mut(self):
        def f():
            val = torch.tensor(float(1))
            val.add_(2)
            return torch.full((100, 100), val)

        g = make_fx(f, tracing_mode=self.tracing_mode)()
        self.assertEqual(g(), f())
        # In case we mutated shared state in the g graph!
        self.assertEqual(g(), f())

    def test_constant_unbind(self):
        def f():
            val = torch.tensor([2])
            r, = torch.unbind(val, 0)
            return r.item()

        g = make_fx(f, tracing_mode=self.tracing_mode)()
        self.assertEqual(g(), f())

    def test_constant_blowup(self):
        def f():
            val = torch.tensor([2])
            blowup = val.repeat(1000)
            return bool(blowup.sum().item() == 2)

        def test_f():
            make_fx(f, tracing_mode=self.tracing_mode)()

        self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)

    def test_constant_random(self):
        def f():
            val = torch.tensor([2.0])
            val.normal_()
            return bool(val.item() == 2.1)

        def test_f():
            make_fx(f, tracing_mode=self.tracing_mode)()

        self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)

    def test_decomposition_interpreter(self):
        # Trace silu without decompositions, then re-run the graph through
        # DecompositionInterpreter with only silu's decomposition and check
        # silu is gone while results are unchanged.
        def fn(x):
            return torch.nn.functional.silu(x)

        x = torch.rand((4, 4))
        fx_module = make_fx(fn, tracing_mode=self.tracing_mode, decomposition_table=None)(x)

        found_silu = False
        for n in fx_module.graph.nodes:
            if n.target == torch.ops.aten.silu or n.target == torch.ops.aten.silu.default:
                found_silu = True

        self.assertTrue(found_silu)

        new_graph = torch.fx.Graph()
        silu_decomp_table = {torch.ops.aten.silu.default: decomposition_table[torch.ops.aten.silu.default]}
        DecompositionInterpreter(
            fx_module,
            new_graph=new_graph,
            decomposition_table=silu_decomp_table,
        ).run(x)

        decomposed_module = torch.fx.GraphModule(fx_module, new_graph)

        for n in decomposed_module.graph.nodes:
            self.assertTrue(n.target != torch.ops.aten.silu)
            self.assertTrue(n.target != torch.ops.aten.silu.default)

        self.assertEqual(fx_module(x), decomposed_module(x))

    def test_make_fx_model_fwd_bwd(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x).relu()

        model = Foo()

        def f(x, params):
            out = torch.func.functional_call(model, params, x).sum()
            out.backward()
            return list(params.values())
        input = torch.randn(3, 5, requires_grad=True)
        params = dict(model.named_parameters())
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params)
        # fx may change the order of parameters in the returned list, so each
        # traced output is compared against both eager outputs
        self.assertTrue(
            torch.allclose(fx_f(input, params)[0], f(input, params)[0])
            or
            torch.allclose(fx_f(input, params)[0], f(input, params)[1])
        )
        self.assertTrue(
            torch.allclose(fx_f(input, params)[1], f(input, params)[0])
            or
            torch.allclose(fx_f(input, params)[1], f(input, params)[1])
        )

    def test_make_fx_model_double_param(self):
        # The same LayerNorm is applied twice; the traced graph must contain
        # exactly two distinct call_function targets.
        class Emformer(torch.nn.Module):
            def __init__(
                self,
                input_dim: int = 256,
            ) -> None:
                super().__init__()

                self.layer_norm = torch.nn.LayerNorm(input_dim)

            def forward(mod_self, x):  # noqa: B902
                self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor))
                y = mod_self.layer_norm(x)
                self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor))
                z = mod_self.layer_norm(y)
                return z

        gm = make_fx(Emformer())(torch.randn(16, 1, 256))
        ops = {n.target for n in gm.graph.nodes if n.op == 'call_function'}
        self.assertEqual(len(ops), 2)

    def test_make_fx_model_fwd_bwd_wgtupdate(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x).relu()

        model = Foo()

        def f(args, params, buffers):
            for p in params.values():
                p.grad = None
            if not isinstance(args, Iterable):
                args = [args]
            params_and_buffers = {**params, **buffers}
            out = torch.func.functional_call(model, params_and_buffers, args)
            out.sum().backward()
            return [p - 1e-4 * p.grad for p in params.values()]

        input = torch.randn(3, 5, requires_grad=True)
        params = dict(model.named_parameters())
        buffers = dict(model.named_buffers())
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params, buffers)
        # fx may change the order of parameters in the returned list, so each
        # traced output is compared against both eager outputs
        # also there is a numerical difference in results so changing atol from 1e-08 to 1e-03
        self.assertTrue(
            torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[0], atol=1e-03)
            or
            torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[1], atol=1e-03)
        )
        self.assertTrue(
            torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[0], atol=1e-03)
            or
            torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[1], atol=1e-03)
        )

    def test_trace_subclasses(self):
        # Tracing through the UnwrapTensor subclass, both alone and mixed with a plain tensor.
        def f1(x):
            x = UnwrapTensor(x)
            y = x * 2
            return y

        def f2(x):
            wrapped = UnwrapTensor(x)
            y = x * wrapped
            return y

        inp = [torch.randn(5)]
        self._test(f1, inp)
        self._test(f2, inp)

    def test_partial_decomp(self):
        # A decomposition returning NotImplemented must leave that call undecomposed:
        # only the addmm with non-default beta is decomposed.
        def f(a, b, c):
            x = torch.addmm(a, b, c)
            y = torch.addmm(a, b, c, beta=2, alpha=1)
            return x + y
        inps = [torch.randn(5, 5), torch.randn(5, 5), torch.randn(5, 5)]
        fx_g = make_fx(f)(*inps)

        def addmm(a, b, c, beta=1, alpha=1):
            if beta == 1 and alpha == 1:
                return NotImplemented
            return beta * a + alpha * (b @ c)

        decomposed_fx = make_fx(f, decomposition_table={aten.addmm.default: addmm})(*inps)

        self.assertEqual(fx_g(*inps), decomposed_fx(*inps))
        self.assertEqual(len([n for n in fx_g.graph.nodes if n.target == aten.addmm.default]), 2)
        self.assertEqual(len([n for n in decomposed_fx.graph.nodes if n.target == aten.addmm.default]), 1)

    def test_decomp_of_capture(self):
        # Decompositions must also apply to ops on captured (closed-over) tensors.
        val = torch.randn(5)

        def f(x):
            return x.t() + val.t()

        def nop(x):
            return x.cos()

        traced = make_fx(f, decomposition_table={torch.ops.aten.t.default: nop})(torch.randn(5))
        self.assertEqual(len([n for n in traced.graph.nodes if n.target == torch.ops.aten.t.default]), 0)

    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
    def test_amp_cache(self):
        # Two traces under the same autocast region must produce graphs with
        # the same node ops.
        layer = torch.nn.Conv2d(3, 3, 3).cuda()

        def f(x, w):
            return torch.nn.functional.conv2d(x, w, stride=layer.stride)

        inp = torch.randn(4, 3, 10, 10, device='cuda')
        with torch.autocast('cuda'):
            out_graph = make_fx(f)(inp, layer.weight).graph
            out_graph2 = make_fx(f)(inp, layer.weight).graph

        self.assertEqual(len(out_graph.nodes), len(out_graph2.nodes))
        for a, b in zip(out_graph.nodes, out_graph2.nodes):
            self.assertEqual(a.op, b.op)

    def test_strides(self):
        # Contiguity queries during tracing must reflect permute/slicing.
        def f(x):
            self.assertTrue(x.is_contiguous())
            self.assertFalse(x.is_contiguous(memory_format=torch.channels_last))
            x = x.permute(0, 3, 1, 2)
            self.assertFalse(x.is_contiguous())
            self.assertTrue(x.is_contiguous(memory_format=torch.channels_last))
            return x
        make_fx(f)(torch.randn(2, 3, 4, 5))

        def f(x):
            self.assertTrue(x.is_contiguous())
            y = x[:, 1]
            self.assertFalse(y.is_contiguous())
            y = x[:, ::2]
            self.assertFalse(y.is_contiguous())
            return x.cos()

        make_fx(f)(torch.randn(2, 3, 4, 5))

    def test_pr_86917(self):
        # Tests the issue brought up here https://github.com/pytorch/pytorch/pull/86917#issuecomment-1283155344
        def f(a, b):
            return torch.ops.aten.nll_loss_forward(a, b, None, 1, 10)

        self._test(f, [torch.randn(1, 10), torch.zeros(1, dtype=torch.long)])
# Concrete variants of the generic tests: each subclass only sets the
# ``tracing_mode`` that the shared tests pass to make_fx.
class TestGenericProxyTensorReal(TestGenericProxyTensor):
    tracing_mode = "real"


class TestGenericProxyTensorFake(TestGenericProxyTensor):
    tracing_mode = "fake"


class TestGenericProxyTensorSymbolic(TestGenericProxyTensor):
    tracing_mode = "symbolic"


# Remove the mode-less base class from the module namespace; presumably so the
# test runner does not collect it directly (it defines no ``tracing_mode``).
del TestGenericProxyTensor
class TestRealProxyTensor(TestCase):
    """make_fx tests that trace without an explicit ``tracing_mode``."""

    def test_error_on_data_dependent_ops(self):
        # f inspects concrete tensor values (allclose, float()); with
        # _error_on_data_dependent_ops=False both calls below are expected to
        # complete without raising.
        def f():
            x = torch.randn([])
            y = torch.randn([])
            assert torch.allclose(x * y, y * x)
            z = float(x)
            z2 = float(y)

        # Smoke tests
        make_fx(f, _error_on_data_dependent_ops=False)()
        make_fx(f, pre_dispatch=True, _error_on_data_dependent_ops=False)()
class TestFakeProxyTensor(TestCase):
    """make_fx tests that exercise fake-tensor tracing or an existing FakeTensorMode."""

    def test_issue82547(self):
        # Closing over a real Parameter in fake mode must raise; after rebinding
        # the closed-over ``x`` to a plain torch.Tensor subclass, a different
        # error (TypeError, "Multiple dispatch failed") is expected.
        x = nn.Parameter(torch.randn(3, 3))

        def f():
            return torch.ops.aten.t.default(x)
        self.assertRaisesRegex(Exception, "Please convert all Tensors", lambda: make_fx(f, tracing_mode="fake")())

        class A(torch.Tensor):
            pass

        x = A(torch.randn(3, 3))
        self.assertRaisesRegex(TypeError, "Multiple dispatch failed", lambda: make_fx(f, tracing_mode="fake")())

    def test_use_fake_and_tensor(self):
        # A tensor constant created inside f must be usable alongside fake inputs.
        def f(x, y):
            z = torch.tensor([2.0, 3.0])
            return x + y + z

        g = make_fx(f, tracing_mode="fake")(torch.randn(2), torch.randn(2))
        x, y = torch.randn(2), torch.randn(2)
        self.assertEqual(g(x, y), f(x, y))

    def test_free_fake(self):
        # f closes over ``y``, a tensor created while a FakeTensorMode is active;
        # the "real" trace below must complete without raising.
        def f(x):
            return torch.add(x, y)

        with FakeTensorMode() as fake_mode:
            y = torch.randn(2)
            make_fx(f, tracing_mode="real")(torch.randn(2))

    def test_fused_adam(self):
        # See https://github.com/pytorch/pytorch/issues/99356
        params = [torch.randn(10, 10) for _ in range(10)]
        grads = [torch.randn(10, 10) for _ in range(10)]
        exp_avgs = [torch.randn(10, 10) for _ in range(10)]
        exp_avg_sqs = [torch.randn(10, 10) for _ in range(10)]
        max_exp_avg_sqs = [torch.randn(10, 10) for _ in range(10)]
        state_steps = [torch.tensor(0) for _ in range(10)]

        def fused_adam(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps):
            (new_params, _, _, _, _) = aten._fused_adam.default(
                params,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                lr=0.1,
                beta1=0.9,
                beta2=0.999,
                weight_decay=0.01,
                eps=1e-8,
                amsgrad=False,
                maximize=False,
            )

            for p, new_p in zip(params, new_params):
                p.copy_(new_p)

            return params

        gm = make_fx(fused_adam, tracing_mode='fake')(
            params,
            grads,
            exp_avgs,
            exp_avg_sqs,
            max_exp_avg_sqs,
            state_steps,
        )
        # The multi-output op and the getitem nodes that unpack it must carry 'val' metadata.
        ensure_ops_have_val = [aten._fused_adam.default, operator.getitem]
        for n in gm.graph.nodes:
            if n.op == "call_function" and n.target in ensure_ops_have_val:
                self.assertIn('val', n.meta)

    def test_alias(self):
        def f(x):
            return torch.ops.aten.alias(x)

        r = str(make_fx(f, tracing_mode="fake")(torch.randn(2)).code).strip()
        # NB: this should not have a detach call
        self.assertExpectedInline(r, """\
def forward(self, x_1):
    alias = torch.ops.aten.alias.default(x_1);  x_1 = None
    return alias""")

    def test_meta(self):
        # Every non-output node of a fake trace must have 'val' in its meta.
        def f(x):
            a = x.cos()
            b = torch.var_mean(a, dim=0)
            c = b * 2
            return c

        out = make_fx(f, tracing_mode="fake")(torch.randn(5, 5))
        for n in out.graph.nodes:
            if n.op == 'output':
                continue
            self.assertTrue('val' in n.meta)

    def test_fake_tensor_mode(self):
        # When make_fx runs inside an existing FakeTensorMode, the node 'val'
        # metadata must belong to that same mode.
        def f(a):
            d = a.cos()
            return d

        from torch._guards import detect_fake_mode

        existing_fake_mode = FakeTensorMode()
        with existing_fake_mode:
            out = make_fx(f, tracing_mode="real")(torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]]))

        fake_mode = detect_fake_mode([node.meta.get('val', None) for node in out.graph.nodes])
        self.assertEqual(fake_mode, existing_fake_mode)
def _get_node(fx_g, cond):
for n in fx_g.graph.nodes:
if cond(n):
return n
raise AssertionError
def _get_free_symbols(shape_env):
vars = tuple(shape_env.var_to_val.keys())
return len([var for var in vars if var not in shape_env.replacements])
def _trace(f, *args):
inps = [torch.randn(arg) for arg in args]
return make_fx(f, tracing_mode="symbolic")(*inps)
# TODO: Need to test the guards themselves specifically as well
class TestSymbolicTracing(TestCase):
def _test_dynamic(self, fn, trace_inputs, test_inputs, assert_eq=True):
"""
Tests fn traced with trace_inputs against test_inputs
Also returns shape env
"""
trace_inputs = [torch.randn(shape) for shape in trace_inputs]
traced_f = make_fx(fn, tracing_mode="symbolic")(*trace_inputs)
for input in test_inputs:
input = [torch.randn(shape) for shape in input]
rx, ry = traced_f(*input), fn(*input)
if assert_eq:
self.assertEqual(rx, ry)
return traced_f
def test_debug_interpreter(self):
import torch.library
from torch.library import Library
foo = Library("foo", "DEF") # noqa: TOR901
foo.define("foo(Tensor self) -> Tensor")
# Operator where meta and cpu disagree on strides
@torch.library.impl(foo, "foo", "CPU")
def foo_cpu(x):
return x.clone().T
@torch.library.impl(foo, "foo", "Meta")
def foo_meta(x):
return x.clone()
def f(x):
return torch.ops.foo.foo.default(x)
gm = make_fx(f, tracing_mode="symbolic")(torch.randn(2, 2))
from torch._functorch.compilers import DebugInterpreter
interp = DebugInterpreter(gm)
# input mismatch is caught (indicates guard problem)
self.assertRaisesRegex(
AssertionError, r"3 != 1",
lambda: interp.run(torch.randn(3, 3).T),
)