# Owner(s): ["module: higher order operators"]
import importlib
import pkgutil

import torch
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
from torch.testing._internal.hop_db import (
    FIXME_hop_that_doesnt_have_opinfo_test_allowlist,
    hop_db,
)
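

# Each submodule of torch._higher_order_ops instantiates its HOP (a
# HigherOrderOperator), which registers the op by name in
# torch._ops._higher_order_ops. Importing every submodule up front therefore
# populates the registry that the tests below inspect.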
def do_imports():
    for mod in pkgutil.walk_packages(
        torch._higher_order_ops.__path__, "torch._higher_order_ops."
    ):
        modname = mod.name
        importlib.import_module(modname)


do_imports()


@skipIfTorchDynamo("not applicable")
class TestHOPInfra(TestCase):
    def test_all_hops_have_opinfo(self):
        """All HOPs should have an OpInfo in torch/testing/_internal/hop_db.py"""
        from torch._ops import _higher_order_ops

        hops_that_have_op_info = {k.name for k in hop_db}
        all_hops = _higher_order_ops.keys()

        missing_ops = set()
        for op in all_hops:
            if (
                op not in hops_that_have_op_info
                and op not in FIXME_hop_that_doesnt_have_opinfo_test_allowlist
            ):
                missing_ops.add(op)

        self.assertTrue(
            len(missing_ops) == 0,
            f"Missing hop_db OpInfo entries for {missing_ops}, please add them to torch/testing/_internal/hop_db.py",
        )

    def test_all_hops_are_imported(self):
        """All HOPs should be listed in torch._higher_order_ops.__all__

        Some constraints (see test_testing.py::TestImports)
        - Sympy must be lazily imported
        - Dynamo must be lazily imported
        """
        imported_hops = torch._higher_order_ops.__all__
        registered_hops = torch._ops._higher_order_ops.keys()

        # Please don't add anything here.
        # We want to ensure that all HOPs are imported at "import torch" time.
        # It is bad if someone tries to access torch.ops.higher_order.cond
        # and it doesn't exist (this may happen if your HOP isn't imported at
        # "import torch" time).
        FIXME_ALLOWLIST = {
            "autograd_function_apply",
            "run_with_rng_state",
            "graphsafe_run_with_rng_state",
            "map_impl",
            "_export_tracepoint",
            "run_and_save_rng_state",
            "map",
            "custom_function_call",
            "trace_wrapped",
            "triton_kernel_wrapper_functional",
            "triton_kernel_wrapper_mutation",
            "wrap",  # Really weird failure -- importing this causes Dynamo to choke on checkpoint
        }

        not_imported_hops = registered_hops - imported_hops
        not_imported_hops = not_imported_hops - FIXME_ALLOWLIST
        self.assertEqual(
            not_imported_hops,
            set(),
            msg="All HOPs must be listed under torch/_higher_order_ops/__init__.py's __all__.",
        )
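
    # getattr raises AttributeError for any name listed in __all__ that is
    # not actually importable from the module, which is the failure mode the
    # test below guards against.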
    def test_imports_from_all_work(self):
        """All APIs listed in torch._higher_order_ops.__all__ must be importable"""
        stuff = torch._higher_order_ops.__all__
        for attr in stuff:
            getattr(torch._higher_order_ops, attr)


if __name__ == "__main__":
    run_tests()