
Commit 78e53a9

masnesral authored and pytorchmergebot committed
Remove monkeypatch of has_frozen_params in test/inductor/test_codecache.py (pytorch#141898)
Summary: This particular test isn't really needed since the code path is already exercised in `test_freezing`. While I was here, I beefed up testing in that method to consider whether the frozen parameter is inlinable vs. not, since the caching behavior is different.

Pull Request resolved: pytorch#141898
Approved by: https://github.com/ezyang, https://github.com/jansel
1 parent 42547f8 commit 78e53a9
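
To illustrate the behavior the updated test now exercises, here is a minimal, hedged sketch (not part of the commit) of compiling a module under freezing with the local FX graph cache enabled and then reading the cache counters. The MM module and shapes mirror the test in the diff below; treating (4,) as inlinable and (8, 8) as non-inlinable is an assumption about where Inductor draws the inlining line.

import torch
from torch._dynamo.utils import counters
from torch._inductor import config


class MM(torch.nn.Module):
    def __init__(self, shape):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(shape))

    def forward(self, x):
        return x @ self.param


# Compile once under freezing with the local FX graph cache; a cold cache
# should record one miss. Recompiling a module that differs only in parameter
# *values* is where inlinable vs. non-inlinable frozen params diverge.
with config.patch(
    {"fx_graph_cache": True, "fx_graph_remote_cache": False, "freezing": True}
):
    shape = (8, 8)  # assumed non-inlinable; try (4,) for the inlinable path
    x = torch.rand(shape)
    with torch.no_grad():
        out = torch.compile(MM(shape))(x)
    print(counters["inductor"]["fxgraph_cache_miss"])  # expect 1 on a cold cache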

2 files changed: +31 -76 lines changed

test/inductor/test_codecache.py

Lines changed: 28 additions & 66 deletions

@@ -425,24 +425,6 @@ def fn2(x):
         self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
         self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

-        # Now pretend the constants are frozen params.
-        counters.clear()
-        self.reset()
-
-        with mock.patch(
-            "torch._inductor.output_code.has_frozen_params", return_value=True
-        ):
-            # A call to fn1 should miss in the cache since we do not consider
-            # the constant values.
-            self.assertEqual(fn1(a), compiled_fn1(a))
-            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
-            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
-
-            # A call to fn2 should hit for the same reason.
-            self.assertEqual(fn2(a), compiled_fn2(a))
-            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
-            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
-
     @requires_cuda
     @config.patch({"fx_graph_cache": True})
     @config.patch({"fx_graph_remote_cache": False})
@@ -806,14 +788,28 @@ def f(x, val):
     @config.patch({"fx_graph_remote_cache": False})
     @config.patch({"freezing": True})
     @parametrize("device", (GPU_TYPE, "cpu"))
-    def test_freezing(self, device):
+    @parametrize("inlinable", (True, False))
+    def test_freezing(self, device, inlinable):
         if device == GPU_TYPE and not HAS_GPU:
             raise unittest.SkipTest(f"requires {GPU_TYPE}")

+        # For machines with mkldnn_fp16 support, weight_pack in mkldnn_fusion.py causes
+        # the creation of a mkldnn format tensor which the current implementation does
+        # not support.
+        if (
+            device == "cpu"
+            and torch.backends.mkldnn.is_available()
+            and torch.ops.mkldnn._is_mkldnn_fp16_supported()
+        ):
+            raise unittest.SkipTest("mkldnn tensors unsupported")
+
+        # The shape of the frozen constant determines if it will be inlined.
+        shape = (4,) if inlinable else (8, 8)
+
         class MM(torch.nn.Module):
             def __init__(self) -> None:
                 super().__init__()
-                self.param = torch.nn.Parameter(torch.rand(8, 8))
+                self.param = torch.nn.Parameter(torch.rand(shape))

             def forward(self, x):
                 return x @ self.param
@@ -823,71 +819,37 @@ def forward(self, x):
         # Populate a cache entry.
         mod1 = MM().to(device=device, dtype=dtype)
         with torch.no_grad():
-            x = torch.rand(8, 8).to(device=device, dtype=dtype)
+            x = torch.rand(shape).to(device=device, dtype=dtype)
             out0 = mod1(x)
             out1 = torch.compile(mod1)(x)
             self.assertEqual(out0, out1)

-        # For mahcine that has mkldnn_fp16 support, the weight_pack in mkldnn_fusion.py
-        # wroks, which result in mkldnn format tensor, then the exception
-        # BypassFxGraphCache("mkldnn tensors unpickleable") is raised, and cause the
-        # fxgraph not cached.
-        def is_cpu_mkldnn_fp16_supported():
-            return (
-                device == "cpu"
-                and torch.backends.mkldnn.is_available()
-                and torch.ops.mkldnn._is_mkldnn_fp16_supported()
-            )
-
-        if is_cpu_mkldnn_fp16_supported():
-            fxgraph_cache_bypass_cnt = 1
-            fxgraph_cache_miss_cnt = 0
-            fxgraph_cache_hit_cnt = 0
-        else:
-            fxgraph_cache_bypass_cnt = 0
-            fxgraph_cache_miss_cnt = 1
-            fxgraph_cache_hit_cnt = 0
-
-        self.assertEqual(
-            counters["inductor"]["fxgraph_cache_bypass"], fxgraph_cache_bypass_cnt
-        )
-        self.assertEqual(
-            counters["inductor"]["fxgraph_cache_miss"], fxgraph_cache_miss_cnt
-        )
-        self.assertEqual(
-            counters["inductor"]["fxgraph_cache_hit"], fxgraph_cache_hit_cnt
-        )
+        self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
+        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

         counters.clear()
         self.reset()

-        # Same nn.Module, but with different parameters should cache hit.
+        # Same nn.Module, but with different parameters. In the case that the param can
+        # be inlined, we should consider the actual tensor value and we expect a cache
+        # miss (because the values are different here). If the param cannot be inlined,
+        # then we consider only the tensor metadata and we expect a cache hit.
         mod2 = MM().to(device=device, dtype=dtype)
         self.assertNotEqual(mod1.param, mod2.param)

         with torch.no_grad():
-            x = torch.rand(8, 8).to(device=device, dtype=dtype)
+            x = torch.rand(shape).to(device=device, dtype=dtype)
             out0 = mod2(x)
             out1 = torch.compile(mod2)(x)
             self.assertEqual(out0, out1)

-        if is_cpu_mkldnn_fp16_supported():
-            fxgraph_cache_bypass_cnt = 1
-            fxgraph_cache_miss_cnt = 0
-            fxgraph_cache_hit_cnt = 0
-        else:
-            fxgraph_cache_bypass_cnt = 0
-            fxgraph_cache_miss_cnt = 0
-            fxgraph_cache_hit_cnt = 1
-
-        self.assertEqual(
-            counters["inductor"]["fxgraph_cache_bypass"], fxgraph_cache_bypass_cnt
-        )
+        self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
         self.assertEqual(
-            counters["inductor"]["fxgraph_cache_miss"], fxgraph_cache_miss_cnt
+            counters["inductor"]["fxgraph_cache_miss"], 1 if inlinable else 0
         )
         self.assertEqual(
-            counters["inductor"]["fxgraph_cache_hit"], fxgraph_cache_hit_cnt
+            counters["inductor"]["fxgraph_cache_hit"], 0 if inlinable else 1
         )
torch/_inductor/codecache.py

Lines changed: 3 additions & 10 deletions

@@ -49,12 +49,6 @@
 )

 import torch
-
-# WARNING: Do not directly import has_frozen_params, it is monkeypatched in
-# python test/inductor/test_codecache.py
-# TestFxGraphCache.test_constant_handling_device_cpu
-# TODO: Why are we monkeypatching it......
-import torch._inductor.output_code as output_code
 import torch.distributed as dist
 from torch import SymInt, Tensor
 from torch._dynamo.utils import (
@@ -70,6 +64,7 @@
     rocm_compiler,
 )
 from torch._inductor.custom_graph_pass import CustomGraphPass, CustomGraphPassType
+from torch._inductor.output_code import has_frozen_params
 from torch._utils_internal import log_cache_bypass

 from .remote_cache import create_cache
@@ -897,7 +892,7 @@ def compiled_fx_graph_hash(
     # To support caching when the graph has frozen params, we ignore the tensor values
     # of non-inlined constants since they won't be included in the cache entry. Without
     # freezing, we want to include the values of any constant attribute.
-    include_non_inlined = not output_code.has_frozen_params(gm)
+    include_non_inlined = not has_frozen_params(gm)

     details = FxGraphHashDetails(gm, example_inputs, fx_kwargs, inputs_to_check)
     has_user_defined_triton_kernels = len(details.user_defined_triton_source) != 0
@@ -1400,9 +1395,7 @@ def _check_can_cache(gm: torch.fx.GraphModule) -> None:
         raise BypassFxGraphCache("Unsupported post grad custom pass")

     # Freezing can embed constants that wouldn't be static across runs.
-    if output_code.has_frozen_params(
-        gm
-    ) and not torch._utils_internal.justknobs_check(
+    if has_frozen_params(gm) and not torch._utils_internal.justknobs_check(
         "pytorch/inductor:allow_freezing_with_caching"
     ):
         raise BypassFxGraphCache("Skipping graph with frozen constants")

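The comment preserved in compiled_fx_graph_hash above is the heart of the behavior under test: when the graph has frozen params, the values of non-inlined constants are excluded from the cache key and only their metadata is hashed. A rough conceptual sketch of that rule follows; the helper name and the exact metadata fields are illustrative assumptions, not the real implementation in codecache.py.

from typing import Any, Tuple

import torch


def constant_hash_fields(t: torch.Tensor, include_non_inlined: bool) -> Tuple[Any, ...]:
    # Tensor metadata always participates in the cache key.
    meta = (tuple(t.shape), t.dtype, str(t.device))
    if include_non_inlined:
        # Without freezing, the constant's actual values also feed the key,
        # so modules that differ only in constant values miss the cache.
        return meta + (tuple(t.detach().flatten().tolist()),)
    # With frozen (non-inlined) params, values are dropped, so modules that
    # differ only in those values can share a cache entry.
    return meta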