Skip to content

Commit 58f346c

Browse files
xuhancn authored and pytorchmergebot committed
[inductor] split cpu vec isa to dedicate file (keep git history) (pytorch#129789)
This PR is the implementation of pytorch#124245 (comment) plan 1. Changes: 1. Duplicate `codecache.py` to `cpu_vec_isa.py`, preserving its `git history`. <img width="745" alt="image" src="https://github.com/pytorch/pytorch/assets/8433590/106533da-ce80-4825-8271-35ffb3141f92"> 2. Make `cpu_vec_isa.py` a dedicated file for the CPU vec ISA. This also makes it easy to extend to more architectures and vec ISAs. 3. Update code for the above changes. Pull Request resolved: pytorch#129789 Approved by: https://github.com/jgong5, https://github.com/jansel
1 parent a676b7c commit 58f346c

File tree

8 files changed

+420
-389
lines changed

8 files changed

+420
-389
lines changed

test/inductor/test_cpu_repro.py

+24-22
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from torch._C import FileCheck
1919
from torch._dynamo.testing import rand_strided
2020
from torch._dynamo.utils import same
21-
from torch._inductor import codecache, config, metrics, test_operators
21+
from torch._inductor import config, cpu_vec_isa, metrics, test_operators
2222
from torch._inductor.codegen.common import OptimizationContext
2323
from torch._inductor.codegen.cpp import (
2424
CppOverrides,
@@ -67,12 +67,12 @@
6767
check_model = test_torchinductor.check_model
6868

6969
requires_vectorization = unittest.skipUnless(
70-
codecache.valid_vec_isa_list(), "Does not support vectorization"
70+
cpu_vec_isa.valid_vec_isa_list(), "Does not support vectorization"
7171
)
7272

7373

7474
def check_metrics_vec_kernel_count(num_expected_vec_kernels):
75-
if codecache.valid_vec_isa_list():
75+
if cpu_vec_isa.valid_vec_isa_list():
7676
assert metrics.generated_cpp_vec_kernel_count == num_expected_vec_kernels
7777

7878

@@ -1583,14 +1583,14 @@ def fn(x):
15831583
self.common(fn, (value,))
15841584

15851585
@unittest.skipIf(
1586-
platform.machine() != "x86_64" or not codecache.valid_vec_isa_list(),
1586+
platform.machine() != "x86_64" or not cpu_vec_isa.valid_vec_isa_list(),
15871587
"Does not support vectorization or not x86_64 machine",
15881588
)
15891589
@patch("torch.cuda.is_available", lambda: False)
15901590
def test_auto_simd(self):
1591-
vec_amx = codecache.supported_vec_isa_list[0]
1592-
vec_avx512 = codecache.supported_vec_isa_list[1]
1593-
vec_avx2 = codecache.supported_vec_isa_list[2]
1591+
vec_amx = cpu_vec_isa.supported_vec_isa_list[0]
1592+
vec_avx512 = cpu_vec_isa.supported_vec_isa_list[1]
1593+
vec_avx2 = cpu_vec_isa.supported_vec_isa_list[2]
15941594
self.assertTrue(vec_amx.bit_width() == 512)
15951595
self.assertTrue(vec_amx.nelements() == 16)
15961596
self.assertTrue(vec_amx.nelements(torch.bfloat16) == 32)
@@ -1602,43 +1602,43 @@ def test_auto_simd(self):
16021602
self.assertTrue(vec_avx2.nelements(torch.bfloat16) == 16)
16031603

16041604
with config.patch({"cpp.simdlen": None}):
1605-
isa = codecache.pick_vec_isa()
1606-
if vec_amx in codecache.valid_vec_isa_list():
1605+
isa = cpu_vec_isa.pick_vec_isa()
1606+
if vec_amx in cpu_vec_isa.valid_vec_isa_list():
16071607
self.assertTrue(isa == vec_amx)
1608-
elif vec_avx512 in codecache.valid_vec_isa_list():
1608+
elif vec_avx512 in cpu_vec_isa.valid_vec_isa_list():
16091609
self.assertTrue(isa == vec_avx512)
16101610
else:
16111611
self.assertTrue(isa == vec_avx2)
16121612

16131613
with config.patch({"cpp.simdlen": 0}):
1614-
isa = codecache.pick_vec_isa()
1614+
isa = cpu_vec_isa.pick_vec_isa()
16151615
self.assertFalse(isa)
16161616

16171617
with config.patch({"cpp.simdlen": 1}):
1618-
isa = codecache.pick_vec_isa()
1618+
isa = cpu_vec_isa.pick_vec_isa()
16191619
self.assertFalse(isa)
16201620

16211621
with config.patch({"cpp.simdlen": 257}):
1622-
isa = codecache.pick_vec_isa()
1622+
isa = cpu_vec_isa.pick_vec_isa()
16231623
self.assertFalse(isa)
16241624

16251625
with config.patch({"cpp.simdlen": 513}):
1626-
isa_list = codecache.valid_vec_isa_list()
1626+
isa_list = cpu_vec_isa.valid_vec_isa_list()
16271627
if vec_avx512 in isa_list:
16281628
self.assertFalse(isa)
16291629

16301630
with config.patch({"cpp.simdlen": 512}):
1631-
isa_list = codecache.valid_vec_isa_list()
1632-
isa = codecache.pick_vec_isa()
1631+
isa_list = cpu_vec_isa.valid_vec_isa_list()
1632+
isa = cpu_vec_isa.pick_vec_isa()
16331633
if vec_amx in isa_list:
16341634
self.assertTrue(isa == vec_amx)
16351635
elif vec_avx512 in isa_list:
16361636
self.assertTrue(isa == vec_avx512)
16371637

16381638
with config.patch({"cpp.simdlen": 256}):
1639-
isa_list = codecache.valid_vec_isa_list()
1639+
isa_list = cpu_vec_isa.valid_vec_isa_list()
16401640
if vec_avx2 in isa_list:
1641-
isa = codecache.pick_vec_isa()
1641+
isa = cpu_vec_isa.pick_vec_isa()
16421642
self.assertTrue(isa == vec_avx2)
16431643

16441644
@requires_vectorization
@@ -1989,7 +1989,9 @@ def fn(x):
19891989
x[0, 0] = torch.nan
19901990
x[1, -1] = torch.nan
19911991

1992-
bit_widths = [isa._bit_width for isa in codecache.valid_vec_isa_list()] + [None]
1992+
bit_widths = [isa._bit_width for isa in cpu_vec_isa.valid_vec_isa_list()] + [
1993+
None
1994+
]
19931995
for item in bit_widths:
19941996
with config.patch({"cpp.simdlen": item}):
19951997
torch._dynamo.reset()
@@ -2007,7 +2009,7 @@ def fn(x):
20072009

20082010
return fn
20092011

2010-
bit_widths = [isa._bit_width for isa in codecache.valid_vec_isa_list()]
2012+
bit_widths = [isa._bit_width for isa in cpu_vec_isa.valid_vec_isa_list()]
20112013
ih = [16, 65]
20122014
iw = ih
20132015
oh = ih
@@ -2266,7 +2268,7 @@ def set_opt_dtype(graph):
22662268
graph_lowering
22672269
):
22682270
# The moset inner loop variable is used in the index_expr
2269-
tiling_factor = codecache.pick_vec_isa().nelements(dtype=torch.float)
2271+
tiling_factor = cpu_vec_isa.pick_vec_isa().nelements(dtype=torch.float)
22702272
with CppVecKernelChecker(
22712273
args=None, num_threads=1, tiling_factor=tiling_factor
22722274
) as vec_checker:
@@ -2366,7 +2368,7 @@ def get_index():
23662368
):
23672369
itervars = [sympy.Symbol("i"), sympy.Symbol("j"), sympy.Symbol("k")]
23682370

2369-
tiling_factor = codecache.pick_vec_isa().nelements(dtype=torch.float)
2371+
tiling_factor = cpu_vec_isa.pick_vec_isa().nelements(dtype=torch.float)
23702372
# The most inner loop variable is used in the index_expr
23712373
with CppVecKernelChecker(
23722374
args=None, num_threads=1, tiling_factor=tiling_factor

test/inductor/test_cpu_select_algorithm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import torch._inductor.config as inductor_config
1414
import torch._inductor.select_algorithm as select_algorithm
1515
from torch._dynamo.utils import counters
16-
from torch._inductor.codecache import VecAMX
16+
from torch._inductor.cpu_vec_isa import VecAMX
1717
from torch._inductor.test_case import run_tests, TestCase
1818
from torch.testing._internal.common_device_type import (
1919
dtypes,

test/inductor/test_extension_backend.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
)
2424

2525
import torch._inductor.config as config
26-
from torch._inductor import codecache, metrics
26+
from torch._inductor import cpu_vec_isa, metrics
2727
from torch._inductor.codegen import cpp_utils
2828
from torch._inductor.codegen.common import (
2929
get_scheduling_for_device,
@@ -146,7 +146,7 @@ def fn(a, b, c):
146146
metrics.reset()
147147
opt_fn = torch.compile()(fn)
148148
_, code = run_and_get_cpp_code(opt_fn, x, y, z)
149-
if codecache.valid_vec_isa_list():
149+
if cpu_vec_isa.valid_vec_isa_list():
150150
load_expr = "loadu"
151151
else:
152152
load_expr = " = in_ptr0[static_cast<long>(i0)];"

0 commit comments

Comments
 (0)