Commit 5abaa35

[ROCm] preshuffled weight mm (#1702)
* [ROCm][experimental] pre-shuffle weights
* add custom gemm op
* pass through swizzled
* copy paste bug causing extra matmul to execute
* correct transpose and permute logic
* swizzle.cpp is rocm-only, remove #ifndef USE_ROCM
* transpose is shallow, don't unswizzle/swizzle
* add fp8 swizzle
* remove print statement
* setup.py missing check for vec ext
* remove merge mistake
1 parent 04d1186 commit 5abaa35

File tree: 8 files changed (+1289, -36 lines)


setup.py

Lines changed: 65 additions & 23 deletions
@@ -71,6 +71,7 @@ def use_debug_mode():
 import torch
 from torch.utils.cpp_extension import (
     CUDA_HOME,
+    ROCM_HOME,
     IS_WINDOWS,
     ROCM_HOME,
     BuildExtension,
@@ -79,7 +80,6 @@ def use_debug_mode():
     _get_cuda_arch_flags,
 )
 
-IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)
 
 
 class BuildOptions:
@@ -255,28 +255,37 @@ def get_extensions():
         print(
             "PyTorch GPU support is not available. Skipping compilation of CUDA extensions"
         )
-    if (CUDA_HOME is None and ROCM_HOME is None) and torch.cuda.is_available():
+    if CUDA_HOME is None and torch.cuda.is_available() and torch.version.cuda:
+        print("CUDA toolkit is not available. Skipping compilation of CUDA extensions")
         print(
-            "CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions"
+            "If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit"
         )
+    if ROCM_HOME is None and torch.cuda.is_available() and torch.version.hip:
+        print("ROCm is not available. Skipping compilation of ROCm extensions")
         print(
-            "If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit"
+            "If you'd like to compile ROCm extensions locally please install ROCm"
         )
 
-    use_cuda = torch.cuda.is_available() and (
-        CUDA_HOME is not None or ROCM_HOME is not None
-    )
-    extension = CUDAExtension if use_cuda else CppExtension
+    use_cuda = torch.cuda.is_available() and torch.version.cuda and CUDA_HOME is not None
+    use_hip = torch.cuda.is_available() and torch.version.hip and ROCM_HOME is not None
+    extension = CUDAExtension if (use_cuda or use_hip) else CppExtension
 
-    extra_link_args = []
-    extra_compile_args = {
-        "cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"],
-        "nvcc": [
+    nvcc_args = [
         "-DNDEBUG" if not debug_mode else "-DDEBUG",
         "-O3" if not debug_mode else "-O0",
         "-t=0",
         "-std=c++17",
-        ],
+    ]
+    hip_args = [
+        "-DNDEBUG" if not debug_mode else "-DDEBUG",
+        "-O3" if not debug_mode else "-O0",
+        "-std=c++17",
+    ]
+
+    extra_link_args = []
+    extra_compile_args = {
+        "cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"],
+        "nvcc": nvcc_args if use_cuda else hip_args
     }
 
     if not IS_WINDOWS:
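
The detection change above is the crux of the setup.py rework: instead of folding ROCm into a single `use_cuda` flag, CUDA and HIP builds are now told apart via `torch.version.cuda` / `torch.version.hip`, which are mutually exclusive even though `torch.cuda.is_available()` returns True on both stacks. A minimal sketch of the same check, using only stock PyTorch:

    import torch
    from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME

    # On NVIDIA builds torch.version.cuda is a version string and torch.version.hip
    # is None; on ROCm builds the reverse holds.
    use_cuda = torch.cuda.is_available() and torch.version.cuda and CUDA_HOME is not None
    use_hip = torch.cuda.is_available() and torch.version.hip and ROCM_HOME is not None
    print(f"CUDA build: {bool(use_cuda)}, HIP build: {bool(use_hip)}")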
@@ -299,18 +308,45 @@ def get_extensions():
             extra_compile_args["nvcc"].append("-g")
             extra_link_args.append("/DEBUG")
 
+    if use_hip:
+        # naive search for hipblaslt.h; check if any found contain HIPBLASLT_ORDER_COL16 and VEC_EXT
+        found_col16 = False
+        found_vec_ext = False
+        print("ROCM_HOME", ROCM_HOME)
+        hipblaslt_headers = list(glob.glob(os.path.join(ROCM_HOME, "include", "hipblaslt", "hipblaslt.h")))
+        print("hipblaslt_headers", hipblaslt_headers)
+        for header in hipblaslt_headers:
+            with open(header) as f:
+                text = f.read()
+                if "HIPBLASLT_ORDER_COL16" in text:
+                    found_col16 = True
+                if "HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT" in text:
+                    found_vec_ext = True
+        if found_col16:
+            extra_compile_args["cxx"].append("-DHIPBLASLT_HAS_ORDER_COL16")
+            print("hipblaslt found extended col order enums")
+        else:
+            print("hipblaslt does not have extended col order enums")
+        if found_vec_ext:
+            extra_compile_args["cxx"].append("-DHIPBLASLT_VEC_EXT")
+            print("hipblaslt found vec ext")
+        else:
+            print("hipblaslt does not have vec ext")
+
     # Get base directory and source paths
     curdir = os.path.dirname(os.path.curdir)
     extensions_dir = os.path.join(curdir, "torchao", "csrc")
 
     # Collect C++ source files
     sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))
 
+    # Collect CUDA source files
     extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
     cuda_sources = list(
         glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
     )
 
+    # Collect HIP source files
     extensions_hip_dir = os.path.join(
         extensions_dir, "cuda", "tensor_core_tiled_layout"
     )
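
The hunk above is a build-time feature probe: rather than gating on a hipBLASLt version number, setup.py greps the installed hipblaslt.h for the enums the kernels need and passes matching -D defines to the C++ compiler. A generalized sketch of that pattern (the helper name `header_declares` is illustrative, not part of this commit):

    import glob

    def header_declares(header_glob: str, symbol: str) -> bool:
        """Return True if any header matching the glob mentions the symbol."""
        for path in glob.glob(header_glob):
            with open(path) as f:
                if symbol in f.read():
                    return True
        return False

    # e.g. if header_declares("/opt/rocm/include/hipblaslt/hipblaslt.h",
    #                         "HIPBLASLT_ORDER_COL16"): append -DHIPBLASLT_HAS_ORDER_COL16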
@@ -321,26 +357,32 @@ def get_extensions():
     hip_sources += list(
         glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
     )
+    extensions_hip_dir = os.path.join(extensions_dir, "rocm")
+    hip_sources += list(
+        glob.glob(os.path.join(extensions_hip_dir, "**/*.hip"), recursive=True)
+    )
+    hip_sources += list(
+        glob.glob(os.path.join(extensions_hip_dir, "**/*.cpp"), recursive=True)
+    )
 
-    # Collect CUDA source files if needed
-    if not IS_ROCM and use_cuda:
+    # Add CUDA source files if needed
+    if use_cuda:
         sources += cuda_sources
 
-    # TOOD: Remove this and use what CUDA has once we fix all the builds.
-    if IS_ROCM and use_cuda:
+    # TODO: Remove this and use what CUDA has once we fix all the builds.
+    # Add HIP source files if needed
+    if use_hip:
         # Add ROCm GPU architecture check
         gpu_arch = torch.cuda.get_device_properties(0).name
         if gpu_arch != "gfx942":
             print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
-            print(
-                "Currently only gfx942 is supported. Skipping compilation of ROCm extensions"
-            )
-        else:
-            sources += hip_sources
+            print("Currently only gfx942 is supported. Compiling only for gfx942.")
+        extra_compile_args["nvcc"].append("--offload-arch=gfx942")
+        sources += hip_sources
 
     use_cutlass = False
     cutlass_90a_sources = None
-    if use_cuda and not IS_ROCM and not IS_WINDOWS:
+    if use_cuda and not IS_WINDOWS:
         use_cutlass = True
         cutlass_dir = os.path.join(third_party_path, "cutlass")
         cutlass_include_dir = os.path.join(cutlass_dir, "include")

test/test_ops.py

Lines changed: 33 additions & 11 deletions
@@ -25,8 +25,8 @@
 from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, compute_max_diff
 
-if torch.version.hip is not None:
-    pytest.skip("Skipping the test in ROCm", allow_module_level=True)
+IS_CUDA = torch.cuda.is_available() and torch.version.cuda
+IS_ROCM = torch.cuda.is_available() and torch.version.hip
 
 try:
     import torchao.ops
@@ -52,7 +52,7 @@ def _create_floatx_inputs(
         fp16_act = torch.rand(BS, IC).to(dtype) + 0.5
         return floatx_weight.to(device), scale.to(device), fp16_act.to(device)
 
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
     @parametrize("ebits,mbits", [(3, 2), (2, 2)])
     @parametrize("dtype", [torch.half, torch.bfloat16])
     def test_quant_llm_linear(self, ebits, mbits, dtype):
@@ -82,7 +82,7 @@ def test_quant_llm_linear(self, ebits, mbits, dtype):
             test_utils=test_utils,
         )
 
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
     @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
     @parametrize("ebits,mbits", [(3, 2), (2, 2)])
     @parametrize("dtype", [torch.half, torch.bfloat16])
@@ -139,7 +139,7 @@ def make_test_id(param):
     return f"tiles_{param}"
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
 # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id)
 def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
@@ -157,7 +157,7 @@ def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
 
 
 # TODO: Fix "test_aot_dispatch_dynamic" test failure
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
 # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id)
 def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
@@ -203,7 +203,7 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
     return dq.reshape(n, k)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
 # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize(
     "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str
@@ -271,7 +271,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(
 
 
 # This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
 # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize(
     "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str
@@ -337,7 +337,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(
     assert diff_op_ao < 1e-1
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
 # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize(
     "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str
@@ -448,7 +448,7 @@ def reshape_w(w):
     )
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
 @pytest.mark.parametrize(
     "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors",
     MARLIN_TEST_PARAMS,
@@ -538,7 +538,7 @@ def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_facto
     )
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
 @pytest.mark.parametrize(
     "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors",
     MARLIN_TEST_PARAMS,
@@ -617,5 +617,27 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_fact
     )
 
 
+@pytest.mark.skipif(not IS_ROCM, reason="ROCm not available")
+def test_swizzle_mm():
+    test_utils = [
+        "test_schema",
+        "test_autograd_registration",
+        "test_faketensor",
+    ]
+
+    # TODO: Figure out why test fails unless torch >= 2.5
+    if TORCH_VERSION_AT_LEAST_2_5:
+        test_utils.append("test_aot_dispatch_dynamic")
+
+    mat1 = torch.randint(0, 16, dtype=torch.float, size=(16,32), device="cuda")
+    mat2 = torch.randint(0, 16, dtype=torch.float, size=(32,16), device="cuda")
+
+    opcheck(
+        torch.ops.torchao.swizzle_mm,
+        (mat1, mat2, False, False),
+        test_utils=test_utils,
+    )
+
+
 if __name__ == "__main__":
     pytest.main(sys.argv)
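
The test above is the only documentation of the new op's calling convention in this diff: two matrices plus two booleans. A hedged usage sketch on a supported ROCm (gfx942) build; the semantics of the boolean arguments are not spelled out here, so `False, False` simply mirrors the test:

    import torch
    import torchao  # importing torchao loads the torch.ops.torchao.* custom ops

    mat1 = torch.randn(16, 32, device="cuda")
    mat2 = torch.randn(32, 16, device="cuda")

    # For these plain (unswizzled) inputs the output has mm-like shape (16, 16);
    # numerical behavior beyond the test's opcheck is not documented in this diff.
    out = torch.ops.torchao.swizzle_mm(mat1, mat2, False, False)
    print(out.shape)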

torchao/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -43,13 +43,14 @@
     quantize_,
 )
 
-from . import dtypes, optim, testing
+from . import dtypes, optim, swizzle, testing
 
 __all__ = [
     "dtypes",
     "autoquant",
     "optim",
     "quantize_",
+    "swizzle",
     "testing",
     "ops",
 ]
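
With the re-export in place, the new module is reachable from the package root. A minimal sketch (the functions inside torchao.swizzle live in files not shown in this excerpt, so only the import surface is illustrated):

    import torchao
    from torchao import swizzle  # available after this commit

    print("swizzle" in torchao.__all__)  # True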
