1818from torch ._C import FileCheck
1919from torch ._dynamo .testing import rand_strided
2020from torch ._dynamo .utils import same
21- from torch ._inductor import codecache , config , metrics , test_operators
21+ from torch ._inductor import config , cpu_vec_isa , metrics , test_operators
2222from torch ._inductor .codegen .common import OptimizationContext
2323from torch ._inductor .codegen .cpp import (
2424 CppOverrides ,
6767check_model = test_torchinductor .check_model
6868
# Test decorator: skips the decorated test, with the message below, unless
# valid_vec_isa_list() returns a truthy (non-empty) value.
6969requires_vectorization = unittest .skipUnless (
70- codecache .valid_vec_isa_list (), "Does not support vectorization"
70+ cpu_vec_isa .valid_vec_isa_list (), "Does not support vectorization"
7171)
7272
7373
7474def check_metrics_vec_kernel_count (num_expected_vec_kernels ):
 # Asserts that metrics.generated_cpp_vec_kernel_count equals
 # num_expected_vec_kernels. The check only runs when valid_vec_isa_list()
 # is truthy; otherwise this helper does nothing.
75- if codecache .valid_vec_isa_list ():
75+ if cpu_vec_isa .valid_vec_isa_list ():
7676 assert metrics .generated_cpp_vec_kernel_count == num_expected_vec_kernels
7777
7878
@@ -1583,14 +1583,14 @@ def fn(x):
15831583 self .common (fn , (value ,))
15841584
15851585 @unittest .skipIf (
1586- platform .machine () != "x86_64" or not codecache .valid_vec_isa_list (),
1586+ platform .machine () != "x86_64" or not cpu_vec_isa .valid_vec_isa_list (),
15871587 "Does not support vectorization or not x86_64 machine" ,
15881588 )
15891589 @patch ("torch.cuda.is_available" , lambda : False )
15901590 def test_auto_simd (self ):
1591- vec_amx = codecache .supported_vec_isa_list [0 ]
1592- vec_avx512 = codecache .supported_vec_isa_list [1 ]
1593- vec_avx2 = codecache .supported_vec_isa_list [2 ]
1591+ vec_amx = cpu_vec_isa .supported_vec_isa_list [0 ]
1592+ vec_avx512 = cpu_vec_isa .supported_vec_isa_list [1 ]
1593+ vec_avx2 = cpu_vec_isa .supported_vec_isa_list [2 ]
15941594 self .assertTrue (vec_amx .bit_width () == 512 )
15951595 self .assertTrue (vec_amx .nelements () == 16 )
15961596 self .assertTrue (vec_amx .nelements (torch .bfloat16 ) == 32 )
@@ -1602,43 +1602,43 @@ def test_auto_simd(self):
16021602 self .assertTrue (vec_avx2 .nelements (torch .bfloat16 ) == 16 )
16031603
16041604 with config .patch ({"cpp.simdlen" : None }):
1605- isa = codecache .pick_vec_isa ()
1606- if vec_amx in codecache .valid_vec_isa_list ():
1605+ isa = cpu_vec_isa .pick_vec_isa ()
1606+ if vec_amx in cpu_vec_isa .valid_vec_isa_list ():
16071607 self .assertTrue (isa == vec_amx )
1608- elif vec_avx512 in codecache .valid_vec_isa_list ():
1608+ elif vec_avx512 in cpu_vec_isa .valid_vec_isa_list ():
16091609 self .assertTrue (isa == vec_avx512 )
16101610 else :
16111611 self .assertTrue (isa == vec_avx2 )
16121612
16131613 with config .patch ({"cpp.simdlen" : 0 }):
1614- isa = codecache .pick_vec_isa ()
1614+ isa = cpu_vec_isa .pick_vec_isa ()
16151615 self .assertFalse (isa )
16161616
16171617 with config .patch ({"cpp.simdlen" : 1 }):
1618- isa = codecache .pick_vec_isa ()
1618+ isa = cpu_vec_isa .pick_vec_isa ()
16191619 self .assertFalse (isa )
16201620
16211621 with config .patch ({"cpp.simdlen" : 257 }):
1622- isa = codecache .pick_vec_isa ()
1622+ isa = cpu_vec_isa .pick_vec_isa ()
16231623 self .assertFalse (isa )
16241624
16251625 with config .patch ({"cpp.simdlen" : 513 }):
1626- isa_list = codecache .valid_vec_isa_list ()
1626+ isa_list = cpu_vec_isa .valid_vec_isa_list ()
16271627 if vec_avx512 in isa_list :
16281628 self .assertFalse (isa )
16291629
16301630 with config .patch ({"cpp.simdlen" : 512 }):
1631- isa_list = codecache .valid_vec_isa_list ()
1632- isa = codecache .pick_vec_isa ()
1631+ isa_list = cpu_vec_isa .valid_vec_isa_list ()
1632+ isa = cpu_vec_isa .pick_vec_isa ()
16331633 if vec_amx in isa_list :
16341634 self .assertTrue (isa == vec_amx )
16351635 elif vec_avx512 in isa_list :
16361636 self .assertTrue (isa == vec_avx512 )
16371637
16381638 with config .patch ({"cpp.simdlen" : 256 }):
1639- isa_list = codecache .valid_vec_isa_list ()
1639+ isa_list = cpu_vec_isa .valid_vec_isa_list ()
16401640 if vec_avx2 in isa_list :
1641- isa = codecache .pick_vec_isa ()
1641+ isa = cpu_vec_isa .pick_vec_isa ()
16421642 self .assertTrue (isa == vec_avx2 )
16431643
16441644 @requires_vectorization
@@ -1989,7 +1989,9 @@ def fn(x):
19891989 x [0 , 0 ] = torch .nan
19901990 x [1 , - 1 ] = torch .nan
19911991
1992- bit_widths = [isa ._bit_width for isa in codecache .valid_vec_isa_list ()] + [None ]
1992+ bit_widths = [isa ._bit_width for isa in cpu_vec_isa .valid_vec_isa_list ()] + [
1993+ None
1994+ ]
19931995 for item in bit_widths :
19941996 with config .patch ({"cpp.simdlen" : item }):
19951997 torch ._dynamo .reset ()
@@ -2007,7 +2009,7 @@ def fn(x):
20072009
20082010 return fn
20092011
2010- bit_widths = [isa ._bit_width for isa in codecache .valid_vec_isa_list ()]
2012+ bit_widths = [isa ._bit_width for isa in cpu_vec_isa .valid_vec_isa_list ()]
20112013 ih = [16 , 65 ]
20122014 iw = ih
20132015 oh = ih
@@ -2266,7 +2268,7 @@ def set_opt_dtype(graph):
22662268 graph_lowering
22672269 ):
22682270 # The most inner loop variable is used in the index_expr
2269- tiling_factor = codecache .pick_vec_isa ().nelements (dtype = torch .float )
2271+ tiling_factor = cpu_vec_isa .pick_vec_isa ().nelements (dtype = torch .float )
22702272 with CppVecKernelChecker (
22712273 args = None , num_threads = 1 , tiling_factor = tiling_factor
22722274 ) as vec_checker :
@@ -2366,7 +2368,7 @@ def get_index():
23662368 ):
23672369 itervars = [sympy .Symbol ("i" ), sympy .Symbol ("j" ), sympy .Symbol ("k" )]
23682370
2369- tiling_factor = codecache .pick_vec_isa ().nelements (dtype = torch .float )
2371+ tiling_factor = cpu_vec_isa .pick_vec_isa ().nelements (dtype = torch .float )
23702372 # The most inner loop variable is used in the index_expr
23712373 with CppVecKernelChecker (
23722374 args = None , num_threads = 1 , tiling_factor = tiling_factor
0 commit comments