18
18
from torch ._C import FileCheck
19
19
from torch ._dynamo .testing import rand_strided
20
20
from torch ._dynamo .utils import same
21
- from torch ._inductor import codecache , config , metrics , test_operators
21
+ from torch ._inductor import config , cpu_vec_isa , metrics , test_operators
22
22
from torch ._inductor .codegen .common import OptimizationContext
23
23
from torch ._inductor .codegen .cpp import (
24
24
CppOverrides ,
67
67
check_model = test_torchinductor .check_model
68
68
69
69
requires_vectorization = unittest .skipUnless (
70
- codecache .valid_vec_isa_list (), "Does not support vectorization"
70
+ cpu_vec_isa .valid_vec_isa_list (), "Does not support vectorization"
71
71
)
72
72
73
73
74
74
def check_metrics_vec_kernel_count (num_expected_vec_kernels ):
75
- if codecache .valid_vec_isa_list ():
75
+ if cpu_vec_isa .valid_vec_isa_list ():
76
76
assert metrics .generated_cpp_vec_kernel_count == num_expected_vec_kernels
77
77
78
78
@@ -1583,14 +1583,14 @@ def fn(x):
1583
1583
self .common (fn , (value ,))
1584
1584
1585
1585
@unittest .skipIf (
1586
- platform .machine () != "x86_64" or not codecache .valid_vec_isa_list (),
1586
+ platform .machine () != "x86_64" or not cpu_vec_isa .valid_vec_isa_list (),
1587
1587
"Does not support vectorization or not x86_64 machine" ,
1588
1588
)
1589
1589
@patch ("torch.cuda.is_available" , lambda : False )
1590
1590
def test_auto_simd (self ):
1591
- vec_amx = codecache .supported_vec_isa_list [0 ]
1592
- vec_avx512 = codecache .supported_vec_isa_list [1 ]
1593
- vec_avx2 = codecache .supported_vec_isa_list [2 ]
1591
+ vec_amx = cpu_vec_isa .supported_vec_isa_list [0 ]
1592
+ vec_avx512 = cpu_vec_isa .supported_vec_isa_list [1 ]
1593
+ vec_avx2 = cpu_vec_isa .supported_vec_isa_list [2 ]
1594
1594
self .assertTrue (vec_amx .bit_width () == 512 )
1595
1595
self .assertTrue (vec_amx .nelements () == 16 )
1596
1596
self .assertTrue (vec_amx .nelements (torch .bfloat16 ) == 32 )
@@ -1602,43 +1602,43 @@ def test_auto_simd(self):
1602
1602
self .assertTrue (vec_avx2 .nelements (torch .bfloat16 ) == 16 )
1603
1603
1604
1604
with config .patch ({"cpp.simdlen" : None }):
1605
- isa = codecache .pick_vec_isa ()
1606
- if vec_amx in codecache .valid_vec_isa_list ():
1605
+ isa = cpu_vec_isa .pick_vec_isa ()
1606
+ if vec_amx in cpu_vec_isa .valid_vec_isa_list ():
1607
1607
self .assertTrue (isa == vec_amx )
1608
- elif vec_avx512 in codecache .valid_vec_isa_list ():
1608
+ elif vec_avx512 in cpu_vec_isa .valid_vec_isa_list ():
1609
1609
self .assertTrue (isa == vec_avx512 )
1610
1610
else :
1611
1611
self .assertTrue (isa == vec_avx2 )
1612
1612
1613
1613
with config .patch ({"cpp.simdlen" : 0 }):
1614
- isa = codecache .pick_vec_isa ()
1614
+ isa = cpu_vec_isa .pick_vec_isa ()
1615
1615
self .assertFalse (isa )
1616
1616
1617
1617
with config .patch ({"cpp.simdlen" : 1 }):
1618
- isa = codecache .pick_vec_isa ()
1618
+ isa = cpu_vec_isa .pick_vec_isa ()
1619
1619
self .assertFalse (isa )
1620
1620
1621
1621
with config .patch ({"cpp.simdlen" : 257 }):
1622
- isa = codecache .pick_vec_isa ()
1622
+ isa = cpu_vec_isa .pick_vec_isa ()
1623
1623
self .assertFalse (isa )
1624
1624
1625
1625
with config .patch ({"cpp.simdlen" : 513 }):
1626
- isa_list = codecache .valid_vec_isa_list ()
1626
+ isa_list = cpu_vec_isa .valid_vec_isa_list ()
1627
1627
if vec_avx512 in isa_list :
1628
1628
self .assertFalse (isa )
1629
1629
1630
1630
with config .patch ({"cpp.simdlen" : 512 }):
1631
- isa_list = codecache .valid_vec_isa_list ()
1632
- isa = codecache .pick_vec_isa ()
1631
+ isa_list = cpu_vec_isa .valid_vec_isa_list ()
1632
+ isa = cpu_vec_isa .pick_vec_isa ()
1633
1633
if vec_amx in isa_list :
1634
1634
self .assertTrue (isa == vec_amx )
1635
1635
elif vec_avx512 in isa_list :
1636
1636
self .assertTrue (isa == vec_avx512 )
1637
1637
1638
1638
with config .patch ({"cpp.simdlen" : 256 }):
1639
- isa_list = codecache .valid_vec_isa_list ()
1639
+ isa_list = cpu_vec_isa .valid_vec_isa_list ()
1640
1640
if vec_avx2 in isa_list :
1641
- isa = codecache .pick_vec_isa ()
1641
+ isa = cpu_vec_isa .pick_vec_isa ()
1642
1642
self .assertTrue (isa == vec_avx2 )
1643
1643
1644
1644
@requires_vectorization
@@ -1989,7 +1989,9 @@ def fn(x):
1989
1989
x [0 , 0 ] = torch .nan
1990
1990
x [1 , - 1 ] = torch .nan
1991
1991
1992
- bit_widths = [isa ._bit_width for isa in codecache .valid_vec_isa_list ()] + [None ]
1992
+ bit_widths = [isa ._bit_width for isa in cpu_vec_isa .valid_vec_isa_list ()] + [
1993
+ None
1994
+ ]
1993
1995
for item in bit_widths :
1994
1996
with config .patch ({"cpp.simdlen" : item }):
1995
1997
torch ._dynamo .reset ()
@@ -2007,7 +2009,7 @@ def fn(x):
2007
2009
2008
2010
return fn
2009
2011
2010
- bit_widths = [isa ._bit_width for isa in codecache .valid_vec_isa_list ()]
2012
+ bit_widths = [isa ._bit_width for isa in cpu_vec_isa .valid_vec_isa_list ()]
2011
2013
ih = [16 , 65 ]
2012
2014
iw = ih
2013
2015
oh = ih
@@ -2266,7 +2268,7 @@ def set_opt_dtype(graph):
2266
2268
graph_lowering
2267
2269
):
2268
2270
# The moset inner loop variable is used in the index_expr
2269
- tiling_factor = codecache .pick_vec_isa ().nelements (dtype = torch .float )
2271
+ tiling_factor = cpu_vec_isa .pick_vec_isa ().nelements (dtype = torch .float )
2270
2272
with CppVecKernelChecker (
2271
2273
args = None , num_threads = 1 , tiling_factor = tiling_factor
2272
2274
) as vec_checker :
@@ -2366,7 +2368,7 @@ def get_index():
2366
2368
):
2367
2369
itervars = [sympy .Symbol ("i" ), sympy .Symbol ("j" ), sympy .Symbol ("k" )]
2368
2370
2369
- tiling_factor = codecache .pick_vec_isa ().nelements (dtype = torch .float )
2371
+ tiling_factor = cpu_vec_isa .pick_vec_isa ().nelements (dtype = torch .float )
2370
2372
# The most inner loop variable is used in the index_expr
2371
2373
with CppVecKernelChecker (
2372
2374
args = None , num_threads = 1 , tiling_factor = tiling_factor
0 commit comments