Skip to content

Commit 2266451

Browse files
authored
[reland][ROCm] preshuffled weight mm (#2044)
* [ROCm][experimental] pre-shuffle weights
* add custom gemm op
* pass through swizzled
* copy paste bug causing extra matmul to execute
* correct transpose and permute logic
* swizzle.cpp is rocm-only, remove #ifndef USE_ROCM
* transpose is shallow, don't unswizzle/swizzle
* add fp8 swizzle
* remove print statement
* setup.py missing check for vec ext
* remove merge mistake
* conditionalize building sparse marlin for hip
* ruff format
* ruff check --fix
* protect swizzle.cpp inside USE_ROCM
* patch from @mxz297
1 parent 7eb6125 commit 2266451

File tree

8 files changed

+1356
-40
lines changed

8 files changed

+1356
-40
lines changed

setup.py

+75-27
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,6 @@ def use_debug_mode():
7979
_get_cuda_arch_flags,
8080
)
8181

82-
IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)
83-
8482

8583
class BuildOptions:
8684
def __init__(self):
@@ -255,28 +253,37 @@ def get_extensions():
255253
print(
256254
"PyTorch GPU support is not available. Skipping compilation of CUDA extensions"
257255
)
258-
if (CUDA_HOME is None and ROCM_HOME is None) and torch.cuda.is_available():
259-
print(
260-
"CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions"
261-
)
256+
if CUDA_HOME is None and torch.cuda.is_available() and torch.version.cuda:
257+
print("CUDA toolkit is not available. Skipping compilation of CUDA extensions")
262258
print(
263259
"If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit"
264260
)
261+
if ROCM_HOME is None and torch.cuda.is_available() and torch.version.hip:
262+
print("ROCm is not available. Skipping compilation of ROCm extensions")
263+
print("If you'd like to compile ROCm extensions locally please install ROCm")
265264

266-
use_cuda = torch.cuda.is_available() and (
267-
CUDA_HOME is not None or ROCM_HOME is not None
265+
use_cuda = (
266+
torch.cuda.is_available() and torch.version.cuda and CUDA_HOME is not None
268267
)
269-
extension = CUDAExtension if use_cuda else CppExtension
268+
use_hip = torch.cuda.is_available() and torch.version.hip and ROCM_HOME is not None
269+
extension = CUDAExtension if (use_cuda or use_hip) else CppExtension
270+
271+
nvcc_args = [
272+
"-DNDEBUG" if not debug_mode else "-DDEBUG",
273+
"-O3" if not debug_mode else "-O0",
274+
"-t=0",
275+
"-std=c++17",
276+
]
277+
hip_args = [
278+
"-DNDEBUG" if not debug_mode else "-DDEBUG",
279+
"-O3" if not debug_mode else "-O0",
280+
"-std=c++17",
281+
]
270282

271283
extra_link_args = []
272284
extra_compile_args = {
273285
"cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"],
274-
"nvcc": [
275-
"-DNDEBUG" if not debug_mode else "-DDEBUG",
276-
"-O3" if not debug_mode else "-O0",
277-
"-t=0",
278-
"-std=c++17",
279-
],
286+
"nvcc": nvcc_args if use_cuda else hip_args,
280287
}
281288

282289
if not IS_WINDOWS:
@@ -299,48 +306,89 @@ def get_extensions():
299306
extra_compile_args["nvcc"].append("-g")
300307
extra_link_args.append("/DEBUG")
301308

309+
hip_sparse_marlin_supported = True
310+
if use_hip:
311+
# naive search for hipblaslt.h; check whether any header found contains HIPBLASLT_ORDER_COL16 and VEC_EXT
312+
found_col16 = False
313+
found_vec_ext = False
314+
print("ROCM_HOME", ROCM_HOME)
315+
hipblaslt_headers = list(
316+
glob.glob(os.path.join(ROCM_HOME, "include", "hipblaslt", "hipblaslt.h"))
317+
)
318+
print("hipblaslt_headers", hipblaslt_headers)
319+
for header in hipblaslt_headers:
320+
with open(header) as f:
321+
text = f.read()
322+
if "HIPBLASLT_ORDER_COL16" in text:
323+
found_col16 = True
324+
if "HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT" in text:
325+
found_vec_ext = True
326+
if found_col16:
327+
extra_compile_args["cxx"].append("-DHIPBLASLT_HAS_ORDER_COL16")
328+
print("hipblaslt found extended col order enums")
329+
else:
330+
print("hipblaslt does not have extended col order enums")
331+
if found_vec_ext:
332+
extra_compile_args["cxx"].append("-DHIPBLASLT_VEC_EXT")
333+
print("hipblaslt found vec ext")
334+
else:
335+
print("hipblaslt does not have vec ext")
336+
337+
# sparse_marlin depends on features in ROCm 6.4, __builtin_amdgcn_global_load_lds
338+
ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split(".")[:2])
339+
hip_sparse_marlin_supported = ROCM_VERSION >= (6, 4)
340+
302341
# Get base directory and source paths
303342
curdir = os.path.dirname(os.path.curdir)
304343
extensions_dir = os.path.join(curdir, "torchao", "csrc")
305344

306345
# Collect C++ source files
307346
sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))
308347

348+
# Collect CUDA source files
309349
extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
310350
cuda_sources = list(
311351
glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
312352
)
313353

354+
# Collect HIP source files
314355
extensions_hip_dir = os.path.join(
315356
extensions_dir, "cuda", "tensor_core_tiled_layout"
316357
)
317358
hip_sources = list(
318359
glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
319360
)
320-
extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin")
361+
if hip_sparse_marlin_supported:
362+
extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin")
363+
hip_sources += list(
364+
glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
365+
)
366+
extensions_hip_dir = os.path.join(extensions_dir, "rocm")
321367
hip_sources += list(
322-
glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
368+
glob.glob(os.path.join(extensions_hip_dir, "**/*.hip"), recursive=True)
369+
)
370+
hip_sources += list(
371+
glob.glob(os.path.join(extensions_hip_dir, "**/*.cpp"), recursive=True)
323372
)
324373

325-
# Collect CUDA source files if needed
326-
if not IS_ROCM and use_cuda:
374+
# Add CUDA source files if needed
375+
if use_cuda:
327376
sources += cuda_sources
328377

329-
# TOOD: Remove this and use what CUDA has once we fix all the builds.
330-
if IS_ROCM and use_cuda:
378+
# TODO: Remove this and use what CUDA has once we fix all the builds.
379+
# Add HIP source files if needed
380+
if use_hip:
331381
# Add ROCm GPU architecture check
332382
gpu_arch = torch.cuda.get_device_properties(0).name
333383
if gpu_arch != "gfx942":
334384
print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
335-
print(
336-
"Currently only gfx942 is supported. Skipping compilation of ROCm extensions"
337-
)
338-
else:
339-
sources += hip_sources
385+
print("Currently only gfx942 is supported. Compiling only for gfx942.")
386+
extra_compile_args["nvcc"].append("--offload-arch=gfx942")
387+
sources += hip_sources
340388

341389
use_cutlass = False
342390
cutlass_90a_sources = None
343-
if use_cuda and not IS_ROCM and not IS_WINDOWS:
391+
if use_cuda and not IS_WINDOWS:
344392
use_cutlass = True
345393
cutlass_dir = os.path.join(third_party_path, "cutlass")
346394
cutlass_include_dir = os.path.join(cutlass_dir, "include")

test/test_ops.py

+33-11
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24
2626
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, compute_max_diff
2727

28-
if torch.version.hip is not None:
29-
pytest.skip("Skipping the test in ROCm", allow_module_level=True)
28+
IS_CUDA = torch.cuda.is_available() and torch.version.cuda
29+
IS_ROCM = torch.cuda.is_available() and torch.version.hip
3030

3131
try:
3232
import torchao.ops
@@ -52,7 +52,7 @@ def _create_floatx_inputs(
5252
fp16_act = torch.rand(BS, IC).to(dtype) + 0.5
5353
return floatx_weight.to(device), scale.to(device), fp16_act.to(device)
5454

55-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
55+
@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
5656
@parametrize("ebits,mbits", [(3, 2), (2, 2)])
5757
@parametrize("dtype", [torch.half, torch.bfloat16])
5858
def test_quant_llm_linear(self, ebits, mbits, dtype):
@@ -82,7 +82,7 @@ def test_quant_llm_linear(self, ebits, mbits, dtype):
8282
test_utils=test_utils,
8383
)
8484

85-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
85+
@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
8686
@parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
8787
@parametrize("ebits,mbits", [(3, 2), (2, 2)])
8888
@parametrize("dtype", [torch.half, torch.bfloat16])
@@ -139,7 +139,7 @@ def make_test_id(param):
139139
return f"tiles_{param}"
140140

141141

142-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
142+
@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
143143
# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
144144
@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id)
145145
def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
@@ -157,7 +157,7 @@ def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
157157

158158

159159
# TODO: Fix "test_aot_dispatch_dynamic" test failure
160-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
160+
@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
161161
# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
162162
@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id)
163163
def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
@@ -203,7 +203,7 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
203203
return dq.reshape(n, k)
204204

205205

206-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
206+
@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
207207
# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
208208
@pytest.mark.parametrize(
209209
"shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str
@@ -271,7 +271,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(
271271

272272

273273
# This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize
274-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
274+
@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
275275
# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
276276
@pytest.mark.parametrize(
277277
"shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str
@@ -337,7 +337,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(
337337
assert diff_op_ao < 1e-1
338338

339339

340-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
340+
@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
341341
# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
342342
@pytest.mark.parametrize(
343343
"shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str
@@ -448,7 +448,7 @@ def reshape_w(w):
448448
)
449449

450450

451-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
451+
@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
452452
@pytest.mark.parametrize(
453453
"batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors",
454454
MARLIN_TEST_PARAMS,
@@ -538,7 +538,7 @@ def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_facto
538538
)
539539

540540

541-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
541+
@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
542542
@pytest.mark.parametrize(
543543
"batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors",
544544
MARLIN_TEST_PARAMS,
@@ -617,5 +617,27 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_fact
617617
)
618618

619619

620+
@pytest.mark.skipif(not IS_ROCM, reason="ROCm not available")
621+
def test_swizzle_mm():
622+
test_utils = [
623+
"test_schema",
624+
"test_autograd_registration",
625+
"test_faketensor",
626+
]
627+
628+
# TODO: Figure out why test fails unless torch >= 2.5
629+
if TORCH_VERSION_AT_LEAST_2_5:
630+
test_utils.append("test_aot_dispatch_dynamic")
631+
632+
mat1 = torch.randint(0, 16, dtype=torch.float, size=(16, 32), device="cuda")
633+
mat2 = torch.randint(0, 16, dtype=torch.float, size=(32, 16), device="cuda")
634+
635+
opcheck(
636+
torch.ops.torchao.swizzle_mm,
637+
(mat1, mat2, False, False),
638+
test_utils=test_utils,
639+
)
640+
641+
620642
if __name__ == "__main__":
621643
pytest.main(sys.argv)

torchao/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,14 @@
4343
quantize_,
4444
)
4545

46-
from . import dtypes, optim, testing
46+
from . import dtypes, optim, swizzle, testing
4747

4848
__all__ = [
4949
"dtypes",
5050
"autoquant",
5151
"optim",
5252
"quantize_",
53+
"swizzle",
5354
"testing",
5455
"ops",
5556
]

0 commit comments

Comments
 (0)