forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_python_dispatch.py
2605 lines (2167 loc) · 98.5 KB
/
test_python_dispatch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Owner(s): ["module: __torch_dispatch__"]
# ruff: noqa: F841
import logging
import sys
import tempfile
import unittest
from copy import deepcopy
import torch
import torch._dynamo
from torch import SymInt
from torch._C import DispatchKey, DispatchKeySet
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.cuda.jiterator import _create_jit_fn
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.library import _scoped_library, fallthrough_kernel, impl, Library
from torch.multiprocessing.reductions import StorageWeakRef
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
ops,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import (
first_sample,
IS_WINDOWS,
run_tests,
TEST_WITH_ROCM,
TestCase,
)
from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.logging_tensor import (
capture_logs,
capture_logs_with_logging_tensor_mode,
log_input,
LoggingTensor,
LoggingTensorMode,
LoggingTensorReentrant,
)
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils import _pytree as pytree
from torch.utils._mode_utils import all_same_mode, no_dispatch
from torch.utils._python_dispatch import (
_get_current_dispatch_mode,
_get_current_dispatch_mode_stack,
is_in_torch_dispatch_mode,
TorchDispatchMode,
)
from torch.utils._pytree import tree_map, tree_map_only
# used as DataLoader collate_fn below; named here to avoid trying to pickle a lambda
def _identity(x):
return x
class TestDispatcherPythonBindings(TestCase):
def test_call_boxed(self) -> None:
sin = torch._C._dispatch_find_schema_or_throw("aten::sin", "")
x = torch.randn(3)
y = torch._C._dispatch_call_boxed(sin, x)
self.assertEqual(y, x.sin())
class TestPythonRegistration(TestCase):
test_ns = "_test_python_registration"
def tearDown(self):
if hasattr(torch.ops, self.test_ns):
del torch.ops._test_python_registration
def test_fallback(self) -> None:
test_key = "TESTING_ONLY_GenericMode"
test_keyset = torch._C.DispatchKeySet(test_key)
include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset
exclude_to_set = torch._C._dispatch_tls_local_exclude_set()
with _scoped_library("_", "IMPL") as my_lib:
expected_op = None
expected_args = None
expected_kwargs = None
# Use this out shape to make sure the result from our fallback
# is what is returned to the user
out_shape = None
def my_fallback(op, *args, **kwargs):
# Disable our handler during checks and generating the output
with torch._C._ForceDispatchKeyGuard(
include_to_set, exclude_to_set | test_keyset
):
self.assertIs(op, expected_op)
self.assertEqual(args, expected_args)
self.assertEqual(kwargs, expected_kwargs)
# Return something specific
return torch.empty(out_shape)
my_lib.fallback(my_fallback, test_key)
a, b = torch.rand(2), torch.rand(2)
with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
# Check a factory function
expected_op = torch.ops.aten.empty.memory_format
expected_args = ((2, 2),)
# Extra kwargs to bypass issues with default args in factory functions
expected_kwargs = {
"dtype": torch.float64,
"pin_memory": False,
"device": torch.device("cpu"),
}
out_shape = (3,)
out = torch.empty(*expected_args, **expected_kwargs)
self.assertEqual(out.size(), out_shape)
# Check a regular function
expected_op = torch.ops.aten.add.Tensor
expected_args = (a, b)
expected_kwargs = {}
out_shape = (4,)
out = a + b
self.assertEqual(out.size(), out_shape)
def test_fallback_keyset(self) -> None:
test_key_first = "TESTING_ONLY_GenericMode"
test_key_second = "TESTING_ONLY_GenericWrapper"
test_keyset = torch._C.DispatchKeySet(test_key_first) | torch._C.DispatchKeySet(
test_key_second
)
include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset
exclude_to_set = torch._C._dispatch_tls_local_exclude_set()
with _scoped_library("_", "IMPL") as my_lib:
first_called = False
second_called = False
def first_fallback(keyset, op, *args, **kwargs):
nonlocal first_called
if second_called:
# Recursive call
first_called = True
with torch._C._ForceDispatchKeyGuard(
include_to_set, exclude_to_set | test_keyset
):
return op(*args, **kwargs)
else:
# Redispatch down
keyset = keyset.remove(test_key_first)
return op.redispatch(keyset, *args, **kwargs)
def second_fallback(op, *args, **kwargs):
nonlocal second_called
# Set to avoid infinite recursion
second_called = True
# New dispatcher call should hit the first callback again
self.assertFalse(first_called)
a, b = args
# Make a substraction here instead of add !
c = a - b
self.assertTrue(first_called)
return c
my_lib.fallback(first_fallback, test_key_first, with_keyset=True)
my_lib.fallback(second_fallback, test_key_second)
a, b = torch.rand(2), torch.rand(2)
with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
c = a + b
self.assertEqual(c, a - b)
self.assertTrue(first_called)
self.assertTrue(second_called)
def test_fallback_fallthrough(self) -> None:
test_key_first = "TESTING_ONLY_GenericMode"
test_key_second = "TESTING_ONLY_GenericWrapper"
test_keyset = torch._C.DispatchKeySet(test_key_first) | torch._C.DispatchKeySet(
test_key_second
)
include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset
exclude_to_set = torch._C._dispatch_tls_local_exclude_set()
with _scoped_library("_", "IMPL") as my_lib:
is_called = False
def my_fallback(op, *args, **kwargs):
nonlocal is_called
is_called = True
with torch._C._ForceDispatchKeyGuard(
include_to_set, exclude_to_set | test_keyset
):
return op(*args, **kwargs)
my_lib.fallback(torch.library.fallthrough_kernel, test_key_first)
my_lib.fallback(my_fallback, test_key_second)
a, b = torch.rand(2), torch.rand(2)
with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
c = a + b
self.assertEqual(c, a + b)
self.assertTrue(is_called)
@unittest.skip(
"Causing flakiness, see https://github.com/pytorch/pytorch/issues/145108"
)
def test_fallthrough_for_dense_key_with_meta_in_tls(self) -> None:
# This tests that if meta is included in TlS dispatch key set,
# then a meta kernel should be called regardless if a dense
# backend has a fallthrough kernel
a = torch.randn((3, 3))
with _scoped_library("custom", "DEF") as my_lib:
my_lib.define("sum(Tensor self) -> Tensor")
meta_is_called = False
def sum_meta(*args, **kwargs):
nonlocal meta_is_called
meta_is_called = True
return args[0]
my_lib.impl("sum", fallthrough_kernel, "CPU")
my_lib.impl("sum", sum_meta, "Meta")
with torch._C._IncludeDispatchKeyGuard(torch.DispatchKey.Meta):
torch.ops.custom.sum.default(a)
self.assertTrue(meta_is_called)
def test_override_aten_ops_with_multiple_libraries(self) -> None:
x = torch.tensor([1, 2])
with _scoped_library("aten", "IMPL") as my_lib2:
with _scoped_library("aten", "IMPL") as my_lib1:
# Example 1
def my_neg(*args, **kwargs):
return args[0]._neg_view()
# Now we are secretly making the operator a view op so autograd needs to know how
# to handle it
my_lib1.impl("neg", my_neg, "AutogradCPU")
self.assertTrue(torch.neg(x).is_neg())
# RuntimeError: impl("aten::neg", ...):
# Explicitly provided namespace (aten) in operator name does not match ...
with self.assertRaisesRegex(
RuntimeError, "operator name does not match namespace"
):
with _scoped_library("foo", "DEF") as my_lib3:
my_lib3.define("neg(Tensor self) -> Tensor")
my_lib3.impl(torch.ops.aten.neg.default, my_neg, "AutogradCPU")
# Example 2
def my_mul(*args, **kwargs):
return torch.zeros_like(args[0])
# torch.ops.aten.mul.Tensor
my_lib2.impl("aten::mul.Tensor", my_mul, "ZeroTensor")
y = torch._efficientzerotensor(2)
self.assertFalse(torch.mul(x, y)._is_zerotensor())
# Assert that a user can't override the behavior of a (ns, op, dispatch_key)
# combination if someone overridden the behavior for the same before them
with self.assertRaisesRegex(
RuntimeError, "already a kernel registered from python"
):
my_lib2.impl(torch.ops.aten.mul.Tensor, my_mul, "ZeroTensor")
# Validate that lib2 is not affected by removing lib1
self.assertFalse(torch.mul(x, y)._is_zerotensor())
# Validate that the old behavior is restored for neg and mul
self.assertFalse(torch.neg(x).is_neg())
self.assertTrue(torch.mul(x, y)._is_zerotensor())
def test_error_if_fn_not_callable(self):
with self.assertRaisesRegex(
TypeError, "Input function is required to be a callable"
):
with _scoped_library("aten", "IMPL") as my_lib:
my_lib.impl(torch.ops.aten.neg.default, [], "AutogradCPU")
def test_finalizer(self):
impls_refcnt = sys.getrefcount(torch.library._impls)
lib = Library(self.test_ns, "FRAGMENT") # noqa: TOR901
lib.define("foo123(Tensor x) -> Tensor")
# 1 for `lib`, 1 for sys.getrefcount
self.assertEqual(sys.getrefcount(lib), 2)
# We gained an additional reference that gets cleared when the finalizer runs
self.assertEqual(sys.getrefcount(torch.library._impls), impls_refcnt + 1)
# 1 for `lib`
# 1 for the finalizer
# 1 for sys.getrefcount
self.assertEqual(sys.getrefcount(lib._op_impls), 3)
def foo123(x):
pass
lib.impl(f"{self.test_ns}::foo123", foo123, "CPU")
key = f"{self.test_ns}/foo123/CPU"
self.assertTrue(key in torch.library._impls)
saved_op_impls = lib._op_impls
# del will definitely work if the following passes
self.assertEqual(sys.getrefcount(lib), 2)
del lib
# 1 for saved_op_impls
# 1 for sys.getrefcount
# This function should be the last user of lib._op_impls:
# - lib should not have a reference anymore (it was del'ed)
# - lib's finalizer should not have a reference anymore
self.assertEqual(sys.getrefcount(saved_op_impls), 2)
self.assertTrue(key not in torch.library._impls)
# lib's finalizer should not have a reference anymore
self.assertEqual(sys.getrefcount(torch.library._impls), impls_refcnt)
def test_override_cpu_sum(self) -> None:
# Example 1
run = [False]
def my_sum(*args, **kwargs):
run[0] = True
return args[0].clone()
with _scoped_library("aten", "IMPL") as my_lib1:
my_lib1.impl("aten::sum", my_sum, "CPU")
x = torch.tensor([1, 2])
self.assertEqual(torch.sum(x), x)
self.assertTrue(run[0])
# Validate that the old behavior is restored for sum
self.assertEqual(torch.sum(x), torch.tensor(3))
def test_override_cuda_with_jiterator(self) -> None:
def override_where_cuda() -> None:
# Example 1: Invert the behavior of where's condition input
not_where_code_string = """
template <typename T> T inverted_where(bool cond, T a, T b){
return !cond ? a : b;
}
"""
jitted_where = _create_jit_fn(not_where_code_string)
CALLED = [False]
def inverted_where(*args, **kwargs):
CALLED[0] = True
return jitted_where(*args, **kwargs)
# overriding where's cuda kernel with Jiterator generated kernel
with _scoped_library("aten", "IMPL") as my_lib:
my_lib.impl("aten::where.self", inverted_where, "CUDA")
device = "cuda"
cond = torch.tensor(
[True, True, False], device=device, dtype=torch.bool
)
x = torch.tensor([1, 2, 3], device=device)
y = torch.tensor([-1, -2, -3], device=device)
self.assertEqual(torch.where(cond, x, y), torch.tensor([-1, -2, 3]))
self.assertTrue(CALLED[0])
# behavior restored after deregistration
self.assertEqual(torch.where(cond, x, y), torch.tensor([1, 2, -3]))
def override_gelu_cuda() -> None:
# Example 2: Use relu to approximate gelu for faster compute
fastest_gelu_code_string = """
template <typename T> T fast_gelu(T a){
return a > 0 ? a : 0;
}
"""
jitted_gelu = _create_jit_fn(fastest_gelu_code_string)
CALLED = [False]
def fast_gelu(*args, **kwargs):
CALLED[0] = True
return jitted_gelu(*args, **kwargs)
# overriding gelu's cuda kernel with Jiterator generated relu kernel
with _scoped_library("aten", "IMPL") as my_lib:
my_lib.impl("aten::gelu", fast_gelu, "CUDA")
x = torch.rand([3, 3], device="cuda", dtype=torch.float)
self.assertEqual(
torch.nn.functional.gelu(x), torch.nn.functional.relu(x)
)
self.assertTrue(CALLED[0])
# behavior restored after deregistration
self.assertNotEqual(
torch.nn.functional.gelu(x), torch.nn.functional.relu(x)
)
def override_exp_cuda() -> None:
# Example 3: Preventing exp from exploding for float16
clipped_exp_code_string = """
template <typename T> T clipped_exp(T a){
return a > T(10.0) ? T(22026.4657948) : exp(a);
}
"""
jitted_exp = _create_jit_fn(clipped_exp_code_string)
CALLED = [False]
def clipped_exp(*args, **kwargs):
CALLED[0] = True
return jitted_exp(*args, **kwargs)
# overriding exp's cuda kernel with clipped_exp kernel
with _scoped_library("aten", "IMPL") as my_lib:
my_lib.impl("aten::exp", clipped_exp, "CUDA")
x = torch.tensor([0.0, 100.0], device="cuda", dtype=torch.float16)
self.assertEqual(
torch.exp(x),
torch.tensor([1.0, 22026.4657948], dtype=torch.float16),
)
self.assertTrue(CALLED[0])
# behavior restored after deregistration
self.assertEqual(
torch.exp(x), torch.tensor([1.0, torch.inf], dtype=torch.float16)
)
def override_add_cuda() -> None:
# Example 4: simulate a hardware bug, where the adder is always off by 1
buggy_add_code_string = """
template <typename T> T buggy_add(T a, T b){
return a + b + T(1);
}
"""
jitted_add = _create_jit_fn(buggy_add_code_string)
CALLED = [False]
def buggy_add(*args, **kwargs):
CALLED[0] = True
return jitted_add(*args, **kwargs)
with _scoped_library("aten", "IMPL") as my_lib:
my_lib.impl("aten::add.Tensor", buggy_add, "CUDA")
x_cpu = torch.rand([3, 3], device="cpu")
y_cpu = torch.rand([3], device="cpu")
x_cuda = x_cpu.cuda()
y_cuda = y_cpu.cuda()
self.assertEqual(x_cuda + y_cuda, x_cpu + y_cpu + 1)
self.assertTrue(CALLED[0])
# behavior restored after deregistration
self.assertEqual(x_cuda + y_cuda, x_cpu + y_cpu)
if torch.cuda.is_available() and not TEST_WITH_ROCM:
override_where_cuda()
override_gelu_cuda()
override_exp_cuda()
override_add_cuda()
def test_extend_library_with_dispatch_key_arg(self):
def my_sum(*args, **kwargs):
return args[0].clone()
with _scoped_library("aten", "IMPL", dispatch_key="CPU") as my_lib1:
# RuntimeError: Explicitly provided dispatch key (Conjugate) is
# inconsistent with the dispatch key of the enclosing TORCH_LIBRARY_IMPL block
with self.assertRaisesRegex(
RuntimeError, "inconsistent with the dispatch key"
):
my_lib1.impl("sum", my_sum, "Conjugate")
my_lib1.impl("aten::sum", my_sum)
x = torch.tensor([1, 2])
self.assertEqual(torch.sum(x), x)
def test_create_new_library(self) -> None:
with _scoped_library(self.test_ns, "DEF") as my_lib1:
my_lib1.define("sum(Tensor self) -> Tensor")
# Example 1
@torch.library.impl(my_lib1, "sum", "CPU")
def my_sum(*args, **kwargs):
return args[0].clone()
x = torch.tensor([1, 2])
op = getattr(torch.ops, self.test_ns).sum
self.assertEqual(op(x), x)
with _scoped_library(self.test_ns, "IMPL") as my_lib2:
# Example 2
@torch.library.impl(my_lib2, op.default, "ZeroTensor")
def my_sum_zt(*args, **kwargs):
if args[0]._is_zerotensor():
return torch._efficientzerotensor(args[0].shape)
else:
return args[0].clone()
y = torch._efficientzerotensor(3)
self.assertTrue(op(y)._is_zerotensor())
self.assertEqual(op(x), x)
def test_create_new_library_fragment_no_existing(self):
with _scoped_library(self.test_ns, "FRAGMENT") as my_lib:
my_lib.define("sum2(Tensor self) -> Tensor")
@torch.library.impl(my_lib, "sum2", "CPU")
def my_sum(*args, **kwargs):
return args[0]
x = torch.tensor([1, 2])
self.assertEqual(getattr(torch.ops, self.test_ns).sum2(x), x)
def test_create_new_library_fragment_with_existing(self):
with _scoped_library(self.test_ns, "DEF") as my_lib1:
# Create a fragment
with _scoped_library(self.test_ns, "FRAGMENT") as my_lib2:
my_lib2.define("sum4(Tensor self) -> Tensor")
@torch.library.impl(my_lib2, "sum4", "CPU")
def my_sum4(*args, **kwargs):
return args[0]
x = torch.tensor([1, 2])
self.assertEqual(getattr(torch.ops, self.test_ns).sum4(x), x)
# Create another fragment
with _scoped_library(self.test_ns, "FRAGMENT") as my_lib3:
my_lib3.define("sum3(Tensor self) -> Tensor")
@torch.library.impl(my_lib3, "sum3", "CPU")
def my_sum3(*args, **kwargs):
return args[0]
x = torch.tensor([1, 2])
self.assertEqual(getattr(torch.ops, self.test_ns).sum3(x), x)
@unittest.skipIf(IS_WINDOWS, "Skipped under Windows")
def test_alias_analysis(self):
def test_helper(alias_analysis=""):
my_lib1 = Library(self.test_ns, "DEF") # noqa: TOR901
called = [0]
@torch.library.define(
my_lib1, "_op() -> None", alias_analysis=alias_analysis
)
def _op(*args, **kwargs):
called[0] += 1
@torch.jit.script
def _test():
torch.ops._test_python_registration._op()
assert "_test_python_registration::_op" in str(_test.graph)
with self.assertRaises(AssertionError):
test_helper("") # alias_analysis="FROM_SCHEMA"
test_helper("CONSERVATIVE")
def test_error_for_unsupported_ns_or_kind(self) -> None:
with self.assertRaisesRegex(ValueError, "Unsupported kind"):
my_lib1 = Library("myns", "BLA") # noqa: TOR901
for kind in ("DEF", "FRAGMENT"):
with self.assertRaisesRegex(ValueError, "reserved namespace"):
my_lib1 = Library("prim", kind) # noqa: TOR901
def test_returning_symint(self) -> None:
shape_env = ShapeEnv()
fake_tensor_mode = FakeTensorMode(shape_env=shape_env)
ft = fake_tensor_mode.from_tensor(torch.rand(2, 3))
s0, s1 = ft.shape
with _scoped_library(self.test_ns, "DEF") as tlib:
tlib.define("sqsum(SymInt a, SymInt b) -> SymInt")
@impl(tlib, "sqsum", "CompositeExplicitAutograd")
def sqsum(a: SymInt, b: SymInt):
return a * a + b * b
out = getattr(torch.ops, self.test_ns).sqsum.default(s0, s1)
out_val = shape_env.evaluate_expr(out.node.expr)
self.assertEqual(out_val, 13)
def test_register_fallthrough(self):
with _scoped_library("aten", "IMPL") as my_lib:
my_lib.impl("mm", fallthrough_kernel, "AutocastCPU")
a = torch.randn(2, 3, device="cpu", dtype=torch.float32)
b = torch.randn(3, 2, device="cpu", dtype=torch.float32)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
# dtype for mm should be float32 since we registered a fallthrough
self.assertEqual(torch.mm(a, b).dtype, torch.float32)
# ops that don't have a fallthrough registered should not be affected
self.assertEqual(torch.matmul(a, b).dtype, torch.bfloat16)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
# default behavior should have been restored
self.assertEqual(torch.mm(a, b).dtype, torch.bfloat16)
class TestPythonDispatch(TestCase):
def test_basic(self) -> None:
with capture_logs() as logs:
x = LoggingTensor(torch.tensor([3.0]), requires_grad=True)
log_input("x", x)
y = x * x
saved_x = y.grad_fn._saved_self
grad_y = LoggingTensor(torch.tensor([1.0]))
log_input("grad_y", grad_y)
(g,) = torch.autograd.grad((y,), (x,), (grad_y,))
self.assertEqual(g.elem, torch.tensor([6.0]))
with torch.no_grad():
self.assertEqual(saved_x, x)
self.assertEqual(saved_x._version, x._version)
x.add_(2)
self.assertEqual(saved_x, x)
# TODO: figure out why broken
# self.assertEqual(saved_x._version, x._version)
self.assertExpectedInline(
"\n".join(logs),
"""\
$0: f32[1] = input('x')
$1: f32[1] = torch._ops.aten.mul.Tensor($0, $0)
$2: f32[1] = input('grad_y')
$3: f32[1] = torch._ops.aten.mul.Tensor($2, $0)
$4: f32[1] = torch._ops.aten.mul.Tensor($2, $0)
$5: f32[1] = torch._ops.aten.add.Tensor($4, $3)""",
)
def test_out(self) -> None:
with capture_logs() as logs:
x = LoggingTensor(torch.ones(1))
y = LoggingTensor(torch.zeros(1))
log_input("x", x)
log_input("y", y)
torch.abs(x, out=y)
self.assertEqual(y.elem, torch.ones(1))
# TODO: arguably this shouldn't pass and we should complain
# that out isn't a kwarg
self.assertExpectedInline(
"\n".join(logs),
"""\
$0: f32[1] = input('x')
$1: f32[1] = input('y')
$2: f32[1] = torch._ops.aten.abs.out($0, out=$1)""",
)
def test_kwarg_only(self) -> None:
with capture_logs() as logs:
x = LoggingTensor(torch.ones(1))
y = LoggingTensor(torch.ones(1, 1))
z = LoggingTensor(torch.ones(1))
log_input("x", x)
log_input("y", y)
log_input("z", z)
torch.addmv(x, y, z)
torch.addmv(x, y, z, beta=1)
torch.addmv(x, y, z, beta=2)
torch.addmv(x, y, z, alpha=2)
torch.addmv(x, y, z, beta=2, alpha=2)
# The expectation is that beta/alpha don't show up when they're
# defaulted. This is even if the user explicitly specified it.
self.assertExpectedInline(
"\n".join(logs),
"""\
$0: f32[1] = input('x')
$1: f32[1, 1] = input('y')
$2: f32[1] = input('z')
$3: f32[1] = torch._ops.aten.addmv.default($0, $1, $2)
$4: f32[1] = torch._ops.aten.addmv.default($0, $1, $2)
$5: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, beta=2)
$6: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, alpha=2)
$7: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, beta=2, alpha=2)""",
)
def test_kwarg_only_and_positional_default(self) -> None:
with capture_logs() as logs:
x = LoggingTensor(torch.ones(1))
log_input("x", x)
torch.ops.aten._foobar(x)
torch.ops.aten._foobar(x, False)
torch.ops.aten._foobar(x, arg3=False)
torch.ops.aten._foobar(x, False, arg3=False)
# What we are testing here is that we omit arg2
# if it is defaulted, even if a kwarg is set
self.assertExpectedInline(
"\n".join(logs),
"""\
$0: f32[1] = input('x')
$1: f32[1] = torch._ops.aten._foobar.default($0)
$2: f32[1] = torch._ops.aten._foobar.default($0, False)
$3: f32[1] = torch._ops.aten._foobar.default($0, arg3=False)
$4: f32[1] = torch._ops.aten._foobar.default($0, False, arg3=False)""",
)
def test_produce_real_type(self) -> None:
with capture_logs() as logs:
x = LoggingTensor(torch.ones(2, 2))
log_input("x", x)
x.to(dtype=torch.double) # non-optional dtype
torch.cumprod(x, 0, dtype=torch.double) # optional dtype
x[:, 1].contiguous(
memory_format=torch.contiguous_format
) # optional memory format
# There doesn't appear to be any layout signatures which are
# triggerable using tensor subclasses (need to use a mode)
self.assertExpectedInline(
"\n".join(logs),
"""\
$0: f32[2, 2] = input('x')
$1: f64[2, 2] = torch._ops.aten._to_copy.default($0, dtype=torch.float64)
$2: f64[2, 2] = torch._ops.aten.cumprod.default($0, 0, dtype=torch.float64)
$3: f32[2, 2] = torch._ops.aten.slice.Tensor($0, 0, 0, 9223372036854775807)
$4: f32[2] = torch._ops.aten.select.int($3, 1, 1)
$5: f32[2] = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)""",
)
def test_optional_tensor_list(self) -> None:
def weird(xs):
print("woof")
return torch.empty(())
with _scoped_library("my_lib", "DEF") as my_lib:
my_lib.define("weird(Tensor?[] self) -> Tensor")
my_lib.impl("weird", weird, "CPU")
with capture_logs() as logs:
x = LoggingTensor(torch.ones(2, 2))
log_input("x", x)
torch.ops.my_lib.weird.default([None, x])
self.assertExpectedInline(
"\n".join(logs),
"""\
$0: f32[2, 2] = input('x')
$1: f32[] = torch._ops.my_lib.weird.default(['None', '$0'])""",
)
def test_list_ret(self) -> None:
# test all sequence types are permissible returns
for list_type in (list, tuple):
class A(torch.Tensor):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
if func.overloadpacket == torch.ops.aten.split:
with no_dispatch():
return list_type(torch.split(*args))
else:
raise AssertionError(f"unrecognized func: {func}")
self.assertEqual(
torch.split(A(torch.tensor([0, 1])), 2),
torch.split(torch.tensor([0, 1]), 2),
)
def test_invalid_ret(self) -> None:
# test invalid return gets reasonable error message
class A(torch.Tensor):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
return "arf"
# Wobbles depending on NDEBUG mode of pybind11
self.assertRaisesRegex(
RuntimeError,
"Unable to cast",
lambda: A(torch.zeros(1)).neg(),
)
self.assertRaisesRegex(
RuntimeError,
"Unable to cast",
lambda: A(torch.zeros(1)).detach(),
)
def test_detach_appears_twice_when_called_once(self) -> None:
with capture_logs() as logs:
x = LoggingTensor(torch.tensor([3.0]), requires_grad=True)
log_input("x", x)
x.detach()
# FIXME: We actually want this to emit a single detach. However,
# it currently emits two, for reasons unclear to us. Leaving
# this test here to make sure we don't regress even further (it
# would be bad if calling .detach() once emits 3+ detaches).
self.assertExpectedInline(
"\n".join(logs),
"""\
$0: f32[1] = input('x')
$1: f32[1] = torch._ops.aten.detach.default($0)
$2: f32[1] = torch._ops.aten.detach.default($1)""",
)
def test_storage(self) -> None:
# For now, just make sure it doesn't crash. Ideally, we should
# return some virtual storage that is safe to work with
x = LoggingTensor(torch.ones(1))
storage = x.untyped_storage()
self.assertRaises(RuntimeError, lambda: storage.data_ptr())
def test_make_wrapper_subclass_noalloc(self) -> None:
# This is ludicrously big (8TB) and this should pass because wrapper
# subclasses don't allocate
torch.Tensor._make_wrapper_subclass(LoggingTensor, (1000000000000,))
def test_version(self) -> None:
x = LoggingTensor(torch.ones(1))
prev_vc = x._version
x.detach().add_(2)
cur_vc = x._version
self.assertNotEqual(prev_vc, cur_vc)
x.data.add_(2)
self.assertEqual(cur_vc, x._version)
def test_subclass_priority(self) -> None:
class ErrorA(RuntimeError):
pass
class ErrorB(RuntimeError):
pass
# The big tests for code coverage are test_precedence_semantics in
# test_overrides.py; this is just to make sure it is wired up at all
# correctly for __torch_dispatch__
class A(torch.Tensor):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
raise ErrorA
class B(A):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
raise ErrorB
self.assertRaises(
ErrorA, lambda: torch.add(A(torch.empty(1)), A(torch.empty(1)))
)
self.assertRaises(
ErrorB, lambda: torch.add(A(torch.empty(1)), B(torch.empty(1)))
)
self.assertRaises(
ErrorB, lambda: torch.add(B(torch.empty(1)), A(torch.empty(1)))
)
self.assertRaises(
ErrorB, lambda: torch.add(B(torch.empty(1)), B(torch.empty(1)))
)
def test_format(self) -> None:
x = LoggingTensor(torch.ones(1))
s1 = str(x)
s2 = repr(x)
s3 = f"{x}"
self.assertExpectedInline(s1, """LoggingTensor(tensor([1.]))""")
self.assertEqual(s1, s2)
self.assertEqual(s1, s3)
def test_custom_autograd(self) -> None:
escape = [None]
class Square(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
y = x**2
ctx.save_for_backward(x)
return y
@staticmethod
def backward(ctx, grad_output):
assert isinstance(grad_output, LoggingTensor)
(x,) = ctx.saved_tensors
assert isinstance(x, LoggingTensor)
escape[0] = x
return grad_output * 2 * x
with capture_logs() as logs:
x = LoggingTensor(torch.ones(1), requires_grad=True)
log_input("x", x)
x.grad = LoggingTensor(torch.zeros(1))
log_input("x.grad", x.grad)
y = Square.apply(x)
grad_output = LoggingTensor(torch.ones(1))
log_input("grad_output", grad_output)
y.backward(grad_output)
with torch.no_grad():
self.assertEqual(escape[0], x)
self.assertEqual(escape[0]._version, x._version)
# TODO: figure out why x.requires_grad = False doesn't
# trigger an error for LoggingTensor
x.add_(2)
self.assertEqual(escape[0], x)
# TODO: figure out why this is broken
# self.assertEqual(escape[0]._version, x._version)
self.assertExpectedInline(
"\n".join(logs),
"""\
$0: f32[1] = input('x')
$1: f32[1] = input('x.grad')
$2: f32[1] = torch._ops.aten.pow.Tensor_Scalar($0, 2)
$3: f32[1] = input('grad_output')
$4: f32[1] = torch._ops.aten.mul.Tensor($3, 2)
$5: f32[1] = torch._ops.aten.mul.Tensor($4, $0)
$6: f32[1] = torch._ops.aten.add_.Tensor($1, $5)""",
)
def test_subclass_creation(self):
# Make sure these statements runs without error
# In particular checking that when internal detach returns
# subclasses, these are cleanly overwritten.
class Foo(torch.Tensor):
pass
err_msg = "subclass Foo but.*already associated to a python object of type LoggingTensor"
with self.assertRaisesRegex(RuntimeError, err_msg):
a = torch.Tensor._make_subclass(Foo, LoggingTensor(torch.rand(2)))
with self.assertRaisesRegex(RuntimeError, err_msg):
b = LoggingTensor(torch.rand(2)).as_subclass(Foo)
with self.assertRaisesRegex(RuntimeError, err_msg):
Foo(LoggingTensor(torch.rand(2)))
with self.assertRaisesRegex(TypeError, "Foo must define __torch_dispatch__"):
torch.Tensor._make_wrapper_subclass(Foo, (2, 2))
def test_new_ones(self) -> None:
class MyTensor(torch.Tensor):
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
return MyTensor(3)
self.assertEqual(type(MyTensor(2).new_ones(3)), MyTensor)
def test_like(self) -> None:
class MyTensor(torch.Tensor):
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
return MyTensor(3)
for f in ["empty", "ones", "rand", "randn", "zeros"]:
f_name = f + "_like"
self.assertEqual(type(getattr(torch, f_name)(MyTensor(2))), MyTensor)
self.assertEqual(type(torch.full_like(MyTensor(2), 1.0)), MyTensor)
self.assertEqual(type(torch.randint_like(MyTensor(2), high=3)), MyTensor)
def test_make_fx_with_subclass(self) -> None:
def f(x, y):
# Returns (TwoTensor, Tensor)
return x * y, y + y
x_a = torch.zeros(4)
x_b = torch.zeros(4)
y = torch.ones(4)
# make_fx() is not responsible for unwrapping tensor subclass inputs,
# so we do it manually here.
# Why? In general, make_fx(f)(*args) promises that the graph returned has the same calling
# convention as f(*args). Unwrapping tensor subclass inputs can potentially change
# the number of input args to the graph, breaking that assumption
def f_to_trace(x_a, x_b, y):
x = TwoTensor(x_a, x_b)
out1, out2 = f(x, y)
out1_unwrapped_attrs, _ = out1.__tensor_flatten__()
return (*[getattr(out1, attr) for attr in out1_unwrapped_attrs], out2)
fx_g = make_fx(f_to_trace, tracing_mode="fake")(x_a, x_b, y)
self.assertExpectedInline(
fx_g.code,
"""\