@@ -425,24 +425,6 @@ def fn2(x):
         self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 2)
         self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
 
-        # Now pretend the constants are frozen params.
-        counters.clear()
-        self.reset()
-
-        with mock.patch(
-            "torch._inductor.output_code.has_frozen_params", return_value=True
-        ):
-            # A call to fn1 should miss in the cache since we do not consider
-            # the constant values.
-            self.assertEqual(fn1(a), compiled_fn1(a))
-            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
-            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
-
-            # A call to fn2 should hit for the same reason.
-            self.assertEqual(fn2(a), compiled_fn2(a))
-            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
-            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
-
     @requires_cuda
     @config.patch({"fx_graph_cache": True})
     @config.patch({"fx_graph_remote_cache": False})
@@ -806,14 +788,28 @@ def f(x, val):
     @config.patch({"fx_graph_remote_cache": False})
     @config.patch({"freezing": True})
     @parametrize("device", (GPU_TYPE, "cpu"))
-    def test_freezing(self, device):
+    @parametrize("inlinable", (True, False))
+    def test_freezing(self, device, inlinable):
         if device == GPU_TYPE and not HAS_GPU:
             raise unittest.SkipTest(f"requires {GPU_TYPE}")
 
+        # For machines with mkldnn_fp16 support, weight_pack in mkldnn_fusion.py causes
+        # the creation of a mkldnn format tensor which the current implementation does
+        # not support.
+        if (
+            device == "cpu"
+            and torch.backends.mkldnn.is_available()
+            and torch.ops.mkldnn._is_mkldnn_fp16_supported()
+        ):
+            raise unittest.SkipTest("mkldnn tensors unsupported")
+
+        # The shape of the frozen constant determines if it will be inlined.
+        shape = (4,) if inlinable else (8, 8)
+
         class MM(torch.nn.Module):
             def __init__(self) -> None:
                 super().__init__()
-                self.param = torch.nn.Parameter(torch.rand(8, 8))
+                self.param = torch.nn.Parameter(torch.rand(shape))
 
             def forward(self, x):
                 return x @ self.param
@@ -823,71 +819,37 @@ def forward(self, x):
         # Populate a cache entry.
         mod1 = MM().to(device=device, dtype=dtype)
         with torch.no_grad():
-            x = torch.rand(8, 8).to(device=device, dtype=dtype)
+            x = torch.rand(shape).to(device=device, dtype=dtype)
             out0 = mod1(x)
             out1 = torch.compile(mod1)(x)
             self.assertEqual(out0, out1)
 
-        # For mahcine that has mkldnn_fp16 support, the weight_pack in mkldnn_fusion.py
-        # wroks, which result in mkldnn format tensor, then the exception
-        # BypassFxGraphCache("mkldnn tensors unpickleable") is raised, and cause the
-        # fxgraph not cached.
-        def is_cpu_mkldnn_fp16_supported():
-            return (
-                device == "cpu"
-                and torch.backends.mkldnn.is_available()
-                and torch.ops.mkldnn._is_mkldnn_fp16_supported()
-            )
-
-        if is_cpu_mkldnn_fp16_supported():
-            fxgraph_cache_bypass_cnt = 1
-            fxgraph_cache_miss_cnt = 0
-            fxgraph_cache_hit_cnt = 0
-        else:
-            fxgraph_cache_bypass_cnt = 0
-            fxgraph_cache_miss_cnt = 1
-            fxgraph_cache_hit_cnt = 0
-
-        self.assertEqual(
-            counters["inductor"]["fxgraph_cache_bypass"], fxgraph_cache_bypass_cnt
-        )
-        self.assertEqual(
-            counters["inductor"]["fxgraph_cache_miss"], fxgraph_cache_miss_cnt
-        )
-        self.assertEqual(
-            counters["inductor"]["fxgraph_cache_hit"], fxgraph_cache_hit_cnt
-        )
+        self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
+        self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+        self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
 
         counters.clear()
         self.reset()
 
-        # Same nn.Module, but with different parameters should cache hit.
+        # Same nn.Module, but with different parameters. In the case that the param can
+        # be inlined, we should consider the actual tensor value and we expect a cache
+        # miss (because the values are different here). If the param cannot be inlined,
+        # then we consider only the tensor metadata and we expect a cache hit.
         mod2 = MM().to(device=device, dtype=dtype)
         self.assertNotEqual(mod1.param, mod2.param)
 
         with torch.no_grad():
-            x = torch.rand(8, 8).to(device=device, dtype=dtype)
+            x = torch.rand(shape).to(device=device, dtype=dtype)
             out0 = mod2(x)
             out1 = torch.compile(mod2)(x)
             self.assertEqual(out0, out1)
 
-        if is_cpu_mkldnn_fp16_supported():
-            fxgraph_cache_bypass_cnt = 1
-            fxgraph_cache_miss_cnt = 0
-            fxgraph_cache_hit_cnt = 0
-        else:
-            fxgraph_cache_bypass_cnt = 0
-            fxgraph_cache_miss_cnt = 0
-            fxgraph_cache_hit_cnt = 1
-
-        self.assertEqual(
-            counters["inductor"]["fxgraph_cache_bypass"], fxgraph_cache_bypass_cnt
-        )
+        self.assertEqual(counters["inductor"]["fxgraph_cache_bypass"], 0)
         self.assertEqual(
-            counters["inductor"]["fxgraph_cache_miss"], fxgraph_cache_miss_cnt
+            counters["inductor"]["fxgraph_cache_miss"], 1 if inlinable else 0
         )
         self.assertEqual(
-            counters["inductor"]["fxgraph_cache_hit"], fxgraph_cache_hit_cnt
+            counters["inductor"]["fxgraph_cache_hit"], 0 if inlinable else 1
         )
 
 
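
For reference (not part of the diff): a rough, standalone sketch of the behavior the new `inlinable` parametrization exercises, using the same config knobs (`fx_graph_cache`, `fx_graph_remote_cache`, `freezing`) and counters that appear in the test above. The helper name `miss_hit_counts` and the use of `torch._dynamo.reset()` in place of the test's `self.reset()` are assumptions for illustration only; exact counter values can vary with the build and backend (e.g. the mkldnn fp16 case the test skips).

```python
import torch
import torch._dynamo
from torch._dynamo.utils import counters
from torch._inductor import config


class MM(torch.nn.Module):
    def __init__(self, shape):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(shape))

    def forward(self, x):
        return x @ self.param


def miss_hit_counts(shape):
    """Compile two freshly-initialized MM modules and report (misses, hits)."""
    counters.clear()
    torch._dynamo.reset()
    with config.patch(
        {"fx_graph_cache": True, "fx_graph_remote_cache": False, "freezing": True}
    ):
        for _ in range(2):
            mod = MM(shape)  # different random parameter values each iteration
            with torch.no_grad():
                torch.compile(mod)(torch.rand(shape))
            torch._dynamo.reset()  # force recompilation so the cache is consulted again
    return (
        counters["inductor"]["fxgraph_cache_miss"],
        counters["inductor"]["fxgraph_cache_hit"],
    )


# Per the test's expectations: a (4,) param is small enough to be inlined, so a
# different value misses the cache -> (2, 0); an (8, 8) param is not inlined, so
# only its metadata is hashed and the second compile hits -> (1, 1).
print(miss_hit_counts((4,)))
print(miss_hit_counts((8, 8)))
```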