Commit 6726b0b

Revert "[ROCm] preshuffled weight mm" (#2031)
Revert "[ROCm] preshuffled weight mm (#1702)" This reverts commit 5abaa35.
1 parent 5abaa35 commit 6726b0b

File tree

8 files changed: +36, -1289 lines


Diff for: setup.py

+23-65
@@ -71,7 +71,6 @@ def use_debug_mode():
 import torch
 from torch.utils.cpp_extension import (
     CUDA_HOME,
-    ROCM_HOME,
     IS_WINDOWS,
     ROCM_HOME,
     BuildExtension,
@@ -80,6 +79,7 @@ def use_debug_mode():
     _get_cuda_arch_flags,
 )

+IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)


 class BuildOptions:
@@ -255,37 +255,28 @@ def get_extensions():
         print(
             "PyTorch GPU support is not available. Skipping compilation of CUDA extensions"
         )
-    if CUDA_HOME is None and torch.cuda.is_available() and torch.version.cuda:
-        print("CUDA toolkit is not available. Skipping compilation of CUDA extensions")
+    if (CUDA_HOME is None and ROCM_HOME is None) and torch.cuda.is_available():
         print(
-            "If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit"
+            "CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions"
         )
-    if ROCM_HOME is None and torch.cuda.is_available() and torch.version.hip:
-        print("ROCm is not available. Skipping compilation of ROCm extensions")
         print(
-            "If you'd like to compile ROCm extensions locally please install ROCm"
+            "If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit"
         )

-    use_cuda = torch.cuda.is_available() and torch.version.cuda and CUDA_HOME is not None
-    use_hip = torch.cuda.is_available() and torch.version.hip and ROCM_HOME is not None
-    extension = CUDAExtension if (use_cuda or use_hip) else CppExtension
+    use_cuda = torch.cuda.is_available() and (
+        CUDA_HOME is not None or ROCM_HOME is not None
+    )
+    extension = CUDAExtension if use_cuda else CppExtension

-    nvcc_args = [
+    extra_link_args = []
+    extra_compile_args = {
+        "cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"],
+        "nvcc": [
             "-DNDEBUG" if not debug_mode else "-DDEBUG",
             "-O3" if not debug_mode else "-O0",
             "-t=0",
             "-std=c++17",
-    ]
-    hip_args = [
-        "-DNDEBUG" if not debug_mode else "-DDEBUG",
-        "-O3" if not debug_mode else "-O0",
-        "-std=c++17",
-    ]
-
-    extra_link_args = []
-    extra_compile_args = {
-        "cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"],
-        "nvcc": nvcc_args if use_cuda else hip_args
+        ],
     }

     if not IS_WINDOWS:
@@ -308,45 +299,18 @@ def get_extensions():
             extra_compile_args["nvcc"].append("-g")
             extra_link_args.append("/DEBUG")

-    if use_hip:
-        # naive search for hipblalst.h, if any found contain HIPBLASLT_ORDER_COL16 and VEC_EXT
-        found_col16 = False
-        found_vec_ext = False
-        print("ROCM_HOME", ROCM_HOME)
-        hipblaslt_headers = list(glob.glob(os.path.join(ROCM_HOME, "include", "hipblaslt", "hipblaslt.h")))
-        print("hipblaslt_headers", hipblaslt_headers)
-        for header in hipblaslt_headers:
-            with open(header) as f:
-                text = f.read()
-                if "HIPBLASLT_ORDER_COL16" in text:
-                    found_col16 = True
-                if "HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT" in text:
-                    found_vec_ext = True
-        if found_col16:
-            extra_compile_args["cxx"].append("-DHIPBLASLT_HAS_ORDER_COL16")
-            print("hipblaslt found extended col order enums")
-        else:
-            print("hipblaslt does not have extended col order enums")
-        if found_vec_ext:
-            extra_compile_args["cxx"].append("-DHIPBLASLT_VEC_EXT")
-            print("hipblaslt found vec ext")
-        else:
-            print("hipblaslt does not have vec ext")
-
     # Get base directory and source paths
     curdir = os.path.dirname(os.path.curdir)
     extensions_dir = os.path.join(curdir, "torchao", "csrc")

     # Collect C++ source files
     sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))

-    # Collect CUDA source files
     extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
     cuda_sources = list(
         glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
     )

-    # Collect HIP source files
     extensions_hip_dir = os.path.join(
         extensions_dir, "cuda", "tensor_core_tiled_layout"
     )
@@ -357,32 +321,26 @@ def get_extensions():
     hip_sources += list(
         glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
     )
-    extensions_hip_dir = os.path.join(extensions_dir, "rocm")
-    hip_sources += list(
-        glob.glob(os.path.join(extensions_hip_dir, "**/*.hip"), recursive=True)
-    )
-    hip_sources += list(
-        glob.glob(os.path.join(extensions_hip_dir, "**/*.cpp"), recursive=True)
-    )

-    # Add CUDA source files if needed
-    if use_cuda:
+    # Collect CUDA source files if needed
+    if not IS_ROCM and use_cuda:
         sources += cuda_sources

-    # TODO: Remove this and use what CUDA has once we fix all the builds.
-    # Add HIP source files if needed
-    if use_hip:
+    # TOOD: Remove this and use what CUDA has once we fix all the builds.
+    if IS_ROCM and use_cuda:
         # Add ROCm GPU architecture check
         gpu_arch = torch.cuda.get_device_properties(0).name
         if gpu_arch != "gfx942":
             print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
-            print("Currently only gfx942 is supported. Compiling only for gfx942.")
-        extra_compile_args["nvcc"].append("--offload-arch=gfx942")
-        sources += hip_sources
+            print(
+                "Currently only gfx942 is supported. Skipping compilation of ROCm extensions"
+            )
+        else:
+            sources += hip_sources

     use_cutlass = False
     cutlass_90a_sources = None
-    if use_cuda and not IS_WINDOWS:
+    if use_cuda and not IS_ROCM and not IS_WINDOWS:
         use_cutlass = True
         cutlass_dir = os.path.join(third_party_path, "cutlass")
         cutlass_include_dir = os.path.join(cutlass_dir, "include")
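
The setup.py changes above hinge on how PyTorch reports its GPU backend: ROCm builds still return True from torch.cuda.is_available(), but set torch.version.hip rather than torch.version.cuda. Below is a minimal standalone sketch of the detection logic this revert restores; the IS_ROCM and use_cuda expressions are taken from the diff, while the branch messages are illustrative only.

import torch
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME

# ROCm wheels report torch.cuda.is_available() == True and set torch.version.hip,
# so a separate flag is needed to tell the two backends apart.
IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)
use_cuda = torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None)

if use_cuda and IS_ROCM:
    print("GPU build targeting ROCm/HIP")
elif use_cuda:
    print("GPU build targeting CUDA (nvcc)")
else:
    print("CPU-only build (CppExtension)")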

Diff for: test/test_ops.py

+11-33
@@ -25,8 +25,8 @@
 from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, compute_max_diff

-IS_CUDA = torch.cuda.is_available() and torch.version.cuda
-IS_ROCM = torch.cuda.is_available() and torch.version.hip
+if torch.version.hip is not None:
+    pytest.skip("Skipping the test in ROCm", allow_module_level=True)

 try:
     import torchao.ops
@@ -52,7 +52,7 @@ def _create_floatx_inputs(
         fp16_act = torch.rand(BS, IC).to(dtype) + 0.5
         return floatx_weight.to(device), scale.to(device), fp16_act.to(device)

-    @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     @parametrize("ebits,mbits", [(3, 2), (2, 2)])
     @parametrize("dtype", [torch.half, torch.bfloat16])
     def test_quant_llm_linear(self, ebits, mbits, dtype):
@@ -82,7 +82,7 @@ def test_quant_llm_linear(self, ebits, mbits, dtype):
             test_utils=test_utils,
         )

-    @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
     @parametrize("ebits,mbits", [(3, 2), (2, 2)])
     @parametrize("dtype", [torch.half, torch.bfloat16])
@@ -139,7 +139,7 @@ def make_test_id(param):
     return f"tiles_{param}"


-@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id)
 def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
@@ -157,7 +157,7 @@ def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):


 # TODO: Fix "test_aot_dispatch_dynamic" test failure
-@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id)
 def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
@@ -203,7 +203,7 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
     return dq.reshape(n, k)


-@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize(
     "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str
@@ -271,7 +271,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(


 # This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize
-@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize(
     "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str
@@ -337,7 +337,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(
     assert diff_op_ao < 1e-1


-@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+")
 @pytest.mark.parametrize(
     "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str
@@ -448,7 +448,7 @@ def reshape_w(w):
     )


-@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize(
     "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors",
     MARLIN_TEST_PARAMS,
@@ -538,7 +538,7 @@ def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_facto
     )


-@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize(
     "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors",
     MARLIN_TEST_PARAMS,
@@ -617,27 +617,5 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_fact
     )


-@pytest.mark.skipif(not IS_ROCM, reason="ROCm not available")
-def test_swizzle_mm():
-    test_utils = [
-        "test_schema",
-        "test_autograd_registration",
-        "test_faketensor",
-    ]
-
-    # TODO: Figure out why test fails unless torch >= 2.5
-    if TORCH_VERSION_AT_LEAST_2_5:
-        test_utils.append("test_aot_dispatch_dynamic")
-
-    mat1 = torch.randint(0, 16, dtype=torch.float, size=(16,32), device="cuda")
-    mat2 = torch.randint(0, 16, dtype=torch.float, size=(32,16), device="cuda")
-
-    opcheck(
-        torch.ops.torchao.swizzle_mm,
-        (mat1, mat2, False, False),
-        test_utils=test_utils,
-    )
-
-
 if __name__ == "__main__":
     pytest.main(sys.argv)
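
The test/test_ops.py change replaces the per-test IS_CUDA / IS_ROCM markers with a module-level skip: calling pytest.skip(..., allow_module_level=True) at import time skips every test in the file on a ROCm build. A small self-contained sketch of that pattern follows; the test function is illustrative, not from the torchao suite.

import pytest
import torch

# Abort collection of this module on ROCm builds of PyTorch.
if torch.version.hip is not None:
    pytest.skip("Skipping the test in ROCm", allow_module_level=True)


def test_runs_only_on_cuda_or_cpu_builds():
    # Reaching this point implies we are not on a ROCm build.
    assert torch.version.hip is None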

Diff for: torchao/__init__.py

+1-2
@@ -43,14 +43,13 @@
     quantize_,
 )

-from . import dtypes, optim, swizzle, testing
+from . import dtypes, optim, testing

 __all__ = [
     "dtypes",
     "autoquant",
     "optim",
     "quantize_",
-    "swizzle",
     "testing",
     "ops",
 ]
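
The torchao/__init__.py change drops swizzle from both the eager submodule import and __all__, the list of names re-exported by "from torchao import *". A tiny generic sketch of that pairing, with made-up package and module names:

# mypkg/__init__.py -- illustrative only, not torchao's layout.
from . import alpha, beta  # submodules imported eagerly at package import time

__all__ = [
    "alpha",
    "beta",
]  # names exposed by "from mypkg import *"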
