Skip to content

[reland][ROCm] preshuffled weight mm #2044

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 75 additions & 27 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,6 @@ def use_debug_mode():
_get_cuda_arch_flags,
)

IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)


class BuildOptions:
def __init__(self):
Expand Down Expand Up @@ -255,28 +253,37 @@ def get_extensions():
print(
"PyTorch GPU support is not available. Skipping compilation of CUDA extensions"
)
if (CUDA_HOME is None and ROCM_HOME is None) and torch.cuda.is_available():
print(
"CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions"
)
if CUDA_HOME is None and torch.cuda.is_available() and torch.version.cuda:
print("CUDA toolkit is not available. Skipping compilation of CUDA extensions")
print(
"If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit"
)
if ROCM_HOME is None and torch.cuda.is_available() and torch.version.hip:
print("ROCm is not available. Skipping compilation of ROCm extensions")
print("If you'd like to compile ROCm extensions locally please install ROCm")

use_cuda = torch.cuda.is_available() and (
CUDA_HOME is not None or ROCM_HOME is not None
use_cuda = (
torch.cuda.is_available() and torch.version.cuda and CUDA_HOME is not None
)
extension = CUDAExtension if use_cuda else CppExtension
use_hip = torch.cuda.is_available() and torch.version.hip and ROCM_HOME is not None
extension = CUDAExtension if (use_cuda or use_hip) else CppExtension

nvcc_args = [
"-DNDEBUG" if not debug_mode else "-DDEBUG",
"-O3" if not debug_mode else "-O0",
"-t=0",
"-std=c++17",
]
hip_args = [
"-DNDEBUG" if not debug_mode else "-DDEBUG",
"-O3" if not debug_mode else "-O0",
"-std=c++17",
]

extra_link_args = []
extra_compile_args = {
"cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"],
"nvcc": [
"-DNDEBUG" if not debug_mode else "-DDEBUG",
"-O3" if not debug_mode else "-O0",
"-t=0",
"-std=c++17",
],
"nvcc": nvcc_args if use_cuda else hip_args,
}

if not IS_WINDOWS:
Expand All @@ -299,48 +306,89 @@ def get_extensions():
extra_compile_args["nvcc"].append("-g")
extra_link_args.append("/DEBUG")

hip_sparse_marlin_supported = True
if use_hip:
# naive search for hipblaslt.h; check whether any found header contains HIPBLASLT_ORDER_COL16 and VEC_EXT
found_col16 = False
found_vec_ext = False
print("ROCM_HOME", ROCM_HOME)
hipblaslt_headers = list(
glob.glob(os.path.join(ROCM_HOME, "include", "hipblaslt", "hipblaslt.h"))
)
print("hipblaslt_headers", hipblaslt_headers)
for header in hipblaslt_headers:
with open(header) as f:
text = f.read()
if "HIPBLASLT_ORDER_COL16" in text:
found_col16 = True
if "HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT" in text:
found_vec_ext = True
if found_col16:
extra_compile_args["cxx"].append("-DHIPBLASLT_HAS_ORDER_COL16")
print("hipblaslt found extended col order enums")
else:
print("hipblaslt does not have extended col order enums")
if found_vec_ext:
extra_compile_args["cxx"].append("-DHIPBLASLT_VEC_EXT")
print("hipblaslt found vec ext")
else:
print("hipblaslt does not have vec ext")

# sparse_marlin depends on features in ROCm 6.4, __builtin_amdgcn_global_load_lds
ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split(".")[:2])
hip_sparse_marlin_supported = ROCM_VERSION >= (6, 4)

# Get base directory and source paths
curdir = os.path.dirname(os.path.curdir)
extensions_dir = os.path.join(curdir, "torchao", "csrc")

# Collect C++ source files
sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))

# Collect CUDA source files
extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
cuda_sources = list(
glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
)

# Collect HIP source files
extensions_hip_dir = os.path.join(
extensions_dir, "cuda", "tensor_core_tiled_layout"
)
hip_sources = list(
glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
)
extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin")
if hip_sparse_marlin_supported:
extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin")
hip_sources += list(
glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
)
extensions_hip_dir = os.path.join(extensions_dir, "rocm")
hip_sources += list(
glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
glob.glob(os.path.join(extensions_hip_dir, "**/*.hip"), recursive=True)
)
hip_sources += list(
glob.glob(os.path.join(extensions_hip_dir, "**/*.cpp"), recursive=True)
)

# Collect CUDA source files if needed
if not IS_ROCM and use_cuda:
# Add CUDA source files if needed
if use_cuda:
sources += cuda_sources

# TOOD: Remove this and use what CUDA has once we fix all the builds.
if IS_ROCM and use_cuda:
# TODO: Remove this and use what CUDA has once we fix all the builds.
# Add HIP source files if needed
if use_hip:
# Add ROCm GPU architecture check
gpu_arch = torch.cuda.get_device_properties(0).name
if gpu_arch != "gfx942":
print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
print(
"Currently only gfx942 is supported. Skipping compilation of ROCm extensions"
)
else:
sources += hip_sources
print("Currently only gfx942 is supported. Compiling only for gfx942.")
extra_compile_args["nvcc"].append("--offload-arch=gfx942")
sources += hip_sources

use_cutlass = False
cutlass_90a_sources = None
if use_cuda and not IS_ROCM and not IS_WINDOWS:
if use_cuda and not IS_WINDOWS:
use_cutlass = True
cutlass_dir = os.path.join(third_party_path, "cutlass")
cutlass_include_dir = os.path.join(cutlass_dir, "include")
Expand Down
44 changes: 33 additions & 11 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, compute_max_diff

if torch.version.hip is not None:
pytest.skip("Skipping the test in ROCm", allow_module_level=True)
IS_CUDA = torch.cuda.is_available() and torch.version.cuda
IS_ROCM = torch.cuda.is_available() and torch.version.hip

try:
import torchao.ops
Expand All @@ -52,7 +52,7 @@ def _create_floatx_inputs(
fp16_act = torch.rand(BS, IC).to(dtype) + 0.5
return floatx_weight.to(device), scale.to(device), fp16_act.to(device)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
@parametrize("ebits,mbits", [(3, 2), (2, 2)])
@parametrize("dtype", [torch.half, torch.bfloat16])
def test_quant_llm_linear(self, ebits, mbits, dtype):
Expand Down Expand Up @@ -82,7 +82,7 @@ def test_quant_llm_linear(self, ebits, mbits, dtype):
test_utils=test_utils,
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
@parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
@parametrize("ebits,mbits", [(3, 2), (2, 2)])
@parametrize("dtype", [torch.half, torch.bfloat16])
Expand Down Expand Up @@ -139,7 +139,7 @@ def make_test_id(param):
return f"tiles_{param}"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id)
def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
Expand All @@ -157,7 +157,7 @@ def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):


# TODO: Fix "test_aot_dispatch_dynamic" test failure
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id)
def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
Expand Down Expand Up @@ -203,7 +203,7 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
return dq.reshape(n, k)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize(
"shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str
Expand Down Expand Up @@ -271,7 +271,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(


# This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize(
"shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str
Expand Down Expand Up @@ -337,7 +337,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(
assert diff_op_ao < 1e-1


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
# @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize(
"shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str
Expand Down Expand Up @@ -448,7 +448,7 @@ def reshape_w(w):
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
@pytest.mark.parametrize(
"batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors",
MARLIN_TEST_PARAMS,
Expand Down Expand Up @@ -538,7 +538,7 @@ def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_facto
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
@pytest.mark.parametrize(
"batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors",
MARLIN_TEST_PARAMS,
Expand Down Expand Up @@ -617,5 +617,27 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_fact
)


@pytest.mark.skipif(not IS_ROCM, reason="ROCm not available")
def test_swizzle_mm():
    """Validate the torchao `swizzle_mm` custom op via torch.library opcheck."""
    checks = ["test_schema", "test_autograd_registration", "test_faketensor"]

    # TODO: Figure out why test fails unless torch >= 2.5
    if TORCH_VERSION_AT_LEAST_2_5:
        checks = checks + ["test_aot_dispatch_dynamic"]

    # Small integer-valued float matrices so the matmul result is exact.
    lhs = torch.randint(0, 16, dtype=torch.float, size=(16, 32), device="cuda")
    rhs = torch.randint(0, 16, dtype=torch.float, size=(32, 16), device="cuda")

    opcheck(
        torch.ops.torchao.swizzle_mm,
        (lhs, rhs, False, False),
        test_utils=checks,
    )


# Allow running this test file directly (e.g. `python test_ops.py`) in
# addition to normal pytest collection; forwards CLI args to pytest.
if __name__ == "__main__":
    pytest.main(sys.argv)
3 changes: 2 additions & 1 deletion torchao/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,14 @@
quantize_,
)

from . import dtypes, optim, testing
from . import dtypes, optim, swizzle, testing

__all__ = [
"dtypes",
"autoquant",
"optim",
"quantize_",
"swizzle",
"testing",
"ops",
]
Loading