From 780aa6044671c6e74c08704a7b94e5d2422ef94c Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Mon, 10 Feb 2025 21:45:32 +0000 Subject: [PATCH 01/11] [ROCm][experimental] pre-shuffle weights --- setup.py | 28 +++++++-- test/test_ops.py | 42 ++++++++++--- torchao/__init__.py | 3 +- torchao/csrc/rocm/swizzle/swizzle.hip | 23 +++++++ torchao/ops.py | 18 ++++++ torchao/swizzle/__init__.py | 11 ++++ torchao/swizzle/swizzle_ops.py | 58 ++++++++++++++++++ torchao/swizzle/swizzle_tensor.py | 88 +++++++++++++++++++++++++++ 8 files changed, 255 insertions(+), 16 deletions(-) create mode 100644 torchao/csrc/rocm/swizzle/swizzle.hip create mode 100644 torchao/swizzle/__init__.py create mode 100644 torchao/swizzle/swizzle_ops.py create mode 100644 torchao/swizzle/swizzle_tensor.py diff --git a/setup.py b/setup.py index 67a8d2e576..8ac7b04756 100644 --- a/setup.py +++ b/setup.py @@ -69,6 +69,7 @@ def use_debug_mode(): import torch from torch.utils.cpp_extension import ( CUDA_HOME, + ROCM_HOME, IS_WINDOWS, BuildExtension, CppExtension, @@ -203,22 +204,31 @@ def get_extensions(): print( "PyTorch GPU support is not available. Skipping compilation of CUDA extensions" ) - if CUDA_HOME is None and torch.cuda.is_available(): + if CUDA_HOME is None and torch.cuda.is_available() and torch.version.cuda: print("CUDA toolkit is not available. Skipping compilation of CUDA extensions") print( "If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit" ) + if ROCM_HOME is None and torch.cuda.is_available() and torch.version.hip: + print("ROCm is not available. Skipping compilation of ROCm extensions") + print( + "If you'd like to compile ROCm extensions locally please install ROCm" + ) use_cuda = torch.cuda.is_available() and CUDA_HOME is not None - extension = CUDAExtension if use_cuda else CppExtension + use_rocm = torch.cuda.is_available() and ROCM_HOME is not None + extension = CUDAExtension if (use_cuda or use_rocm) else CppExtension + + nvcc_args = [ + "-O3" if not debug_mode else "-O0", + "-t=0", + ], + rocm_args = ["-O3" if not debug_mode else "-O0"] extra_link_args = [] extra_compile_args = { "cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"], - "nvcc": [ - "-O3" if not debug_mode else "-O0", - "-t=0", - ], + "nvcc": nvcc_args if use_cuda else rocm_args } if not IS_WINDOWS: @@ -245,12 +255,18 @@ def get_extensions(): sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True)) extensions_cuda_dir = os.path.join(extensions_dir, "cuda") + extensions_rocm_dir = os.path.join(extensions_dir, "rocm") cuda_sources = list( glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True) ) + rocm_sources = list( + glob.glob(os.path.join(extensions_rocm_dir, "**/*.hip"), recursive=True) + ) if use_cuda: sources += cuda_sources + if use_rocm: + sources += rocm_sources use_cutlass = False if use_cuda and not IS_WINDOWS: diff --git a/test/test_ops.py b/test/test_ops.py index 54efefb026..24c25dfdc8 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -20,6 +20,9 @@ from torchao.sparsity.marlin import inject_24, marlin_24_workspace, pack_to_marlin_24 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, compute_max_diff, is_fbcode +IS_CUDA = torch.cuda.is_available() and torch.version.cuda +IS_ROCM = torch.cuda.is_available() and torch.version.hip + if is_fbcode(): pytest.skip( "Skipping the test in fbcode since we don't have TARGET file for kernels" @@ -49,7 +52,7 @@ def _create_floatx_inputs( fp16_act = torch.rand(BS, 
IC).to(dtype) + 0.5 return floatx_weight.to(device), scale.to(device), fp16_act.to(device) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") @parametrize("ebits,mbits", [(3, 2), (2, 2)]) @parametrize("dtype", [torch.half, torch.bfloat16]) def test_quant_llm_linear(self, ebits, mbits, dtype): @@ -79,7 +82,7 @@ def test_quant_llm_linear(self, ebits, mbits, dtype): test_utils=test_utils, ) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)]) @parametrize("ebits,mbits", [(3, 2), (2, 2)]) @parametrize("dtype", [torch.half, torch.bfloat16]) @@ -136,7 +139,7 @@ def make_test_id(param): return f"tiles_{param}" -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id) def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles): @@ -154,7 +157,7 @@ def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles): # TODO: Fix "test_aot_dispatch_dynamic" test failure -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id) def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles): @@ -200,7 +203,7 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16): return dq.reshape(n, k) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize( "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str @@ -268,7 +271,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant( # This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize( "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str @@ -334,7 +337,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant( assert diff_op_ao < 1e-1 -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize( "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str @@ -445,7 +448,7 @@ def reshape_w(w): ) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") @pytest.mark.parametrize( "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, @@ 
-535,7 +538,7 @@ def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_facto ) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") @pytest.mark.parametrize( "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, @@ -614,5 +617,26 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_fact ) +@pytest.mark.skipif(not IS_ROCM, reason="ROCm not available") +def test_swizzle_stub(): + test_utils = [ + "test_schema", + "test_autograd_registration", + "test_faketensor", + ] + + # TODO: Figure out why test fails unless torch >= 2.5 + if TORCH_VERSION_AT_LEAST_2_5: + test_utils.append("test_aot_dispatch_dynamic") + + t = torch.randint(0, 16, dtype=torch.int, size=(16,16), device="cuda") + + opcheck( + torch.ops.torchao.swizzle_stub, + (t,), + test_utils=test_utils, + ) + + if __name__ == "__main__": pytest.main(sys.argv) diff --git a/torchao/__init__.py b/torchao/__init__.py index 11716da62e..9db71b1471 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -55,12 +55,13 @@ quantize_, ) -from . import dtypes, testing +from . import dtypes, swizzle, testing __all__ = [ "dtypes", "autoquant", "quantize_", + "swizzle", "testing", "ops", ] diff --git a/torchao/csrc/rocm/swizzle/swizzle.hip b/torchao/csrc/rocm/swizzle/swizzle.hip new file mode 100644 index 0000000000..3423370093 --- /dev/null +++ b/torchao/csrc/rocm/swizzle/swizzle.hip @@ -0,0 +1,23 @@ +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +using at::Scalar; +using at::Tensor; +using at::TensorArg; +using c10::IntArrayRef; + +Tensor swizzle_stub(const Tensor& w) { + return at::zeros_like(w); +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::swizzle_stub", &swizzle_stub); +} diff --git a/torchao/ops.py b/torchao/ops.py index 8b573876f2..037bda4f56 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -25,6 +25,9 @@ lib.define( "rowwise_scaled_linear_cutlass_s8s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor" ) +lib.define( + "swizzle_stub(Tensor input) -> Tensor" +) def register_custom_op(name): @@ -592,3 +595,18 @@ def _( bias: Tensor, ) -> Tensor: return input_scale.new_empty(*input.shape[:-1], weight.shape[0]) + + +def swizzle_stub(w: Tensor) -> Tensor: + """ + Returns empty tensor like w. + + """ + return torch.ops.torchao.swizzle_stub.default( + w=w + ) + + +@register_custom_op("torchao::swizzle_stub") +def _(w: Tensor) -> Tensor: + return torch.zeros_like(w) diff --git a/torchao/swizzle/__init__.py b/torchao/swizzle/__init__.py new file mode 100644 index 0000000000..b5135532ef --- /dev/null +++ b/torchao/swizzle/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .swizzle_tensor import SwizzleTensor + +__all__ = [ + "SwizzleTensor" +] diff --git a/torchao/swizzle/swizzle_ops.py b/torchao/swizzle/swizzle_ops.py new file mode 100644 index 0000000000..a7e7640ef9 --- /dev/null +++ b/torchao/swizzle/swizzle_ops.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
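+
+# Ops registered below via @implements populate SWIZZLE_OPS_TABLE, which
+# SwizzleTensor.__torch_dispatch__ consults; e.g. torch.mm on a SwizzleTensor
+# argument is redirected to the swizzle_mm handler in this module.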
+from typing import Any, Dict, Tuple + +import torch + +from torchao.swizzle.swizzle_tensor import SwizzleTensor + +aten = torch.ops.aten +SWIZZLE_OPS_TABLE: Dict[Any, Any] = {} + + +def implements(aten_ops): + """Register aten ops to the swizzle op table""" + + def decorator(func): + for op in aten_ops: + SWIZZLE_OPS_TABLE[op] = func + return func + + return decorator + + +@implements([aten.mm.default]) +def swizzle_mm(aten_op, args, kwargs=None): + a = args[0] + b = args[1] + + a = a.unswizzle() if isinstance(a, SwizzleTensor) else a + b = b.unswizzle() if isinstance(b, SwizzleTensor) else b + tensor_out = aten_op(a, b, **kwargs) + return tensor_out + + +@implements([aten.bmm.default]) +def swizzle_mm(aten_op, args, kwargs=None): + a = args[0] + b = args[1] + + a = a.unswizzle() if isinstance(a, SwizzleTensor) else a + b = b.unswizzle() if isinstance(b, SwizzleTensor) else b + tensor_out = aten_op(a, b, **kwargs) + return tensor_out + + +@implements([aten.addmm.default]) +def swizzle_addmm(aten_op, args, kwargs=None): + bias = args[0] + a = args[1] + b = args[2] + a = a.unswizzle() if isinstance(a, SwizzleTensor) else a + b = b.unswizzle() if isinstance(b, SwizzleTensor) else b + return aten_op(bias, a, b, args[3:], **kwargs) + + diff --git a/torchao/swizzle/swizzle_tensor.py b/torchao/swizzle/swizzle_tensor.py new file mode 100644 index 0000000000..d06574d994 --- /dev/null +++ b/torchao/swizzle/swizzle_tensor.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch.utils._pytree import tree_map + +# copied from float8_utils.py +def _get_min_alignment(size: int, alignment_value: int) -> int: + return (1 + ((size - 1) // alignment_value)) * alignment_value + +class SwizzleTensor(torch.Tensor): + """ + A Python-only swizzled tensor subclass. + + Intended usage of this abstraction: + Swizzle weight Tensor to avoid LDS use during GEMMs on ROCm hardware. 
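+
+    Illustrative usage sketch (not from the original patch; names, shapes and
+    dtypes are examples only, assuming 2-D float tensors on a ROCm device):
+
+        x = torch.randn(64, 256, device="cuda", dtype=torch.bfloat16)
+        w = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)
+        w_s = SwizzleTensor(w)  # pads M/K up to multiples of 16/32 and stores
+                                # the data in a blocked 16x32-tile layout
+        y = torch.mm(x, w_s)    # aten.mm is intercepted via __torch_dispatch__
+                                # and routed through the swizzle op table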
+ """ + + def __new__( + cls, + original: torch.Tensor, + ): + wrapper = torch.empty_like(original, device="meta") + return torch.Tensor._make_subclass(cls, wrapper) + + def __init__(self, original): + assert original.ndim == 2 or original.ndim == 3 # (M, K) or (B, M, K) + if original.ndim == 2: + M, K = original.shape + B = 0 + if original.ndim == 3: + B, M, K = original.shape + alignedM = _get_min_alignment(M, 16) + alignedK = _get_min_alignment(K, 32) + paddedM = alignedM - M + paddedK = alignedK - K + x = torch.nn.functional.pad(original, (0, paddedK, 0, paddedM), "constant", 0) + if original.ndim == 2: + x = x.view(alignedM//16, 16, alignedK//32, 4, 8) + x = x.permute(0, 2, 3, 1, 4) + if original.ndim == 3: + x = x.view(B, alignedM//16, 16, alignedK//32, 4, 8) + x = x.permute(0, 1, 3, 4, 2, 5) + self.x = x.contiguous() + self.B = B + self.M = M + self.K = K + self.alignedM = alignedM + self.alignedK = alignedK + self.paddedM = paddedM + self.paddedK = paddedK + self.original_ndim = original.ndim + + def __repr__(self): + return f"{self.__class__.__name__}(original={self.unswizzle()})" + + def unswizzle(self): + if self.original_ndim == 2: + undone = self.x.permute(0, 3, 1, 2, 4).contiguous() + undone = undone.reshape(self.alignedM, self.alignedK) + undone = undone[0:self.M, 0:self.K] + return undone.reshape(self.M, self.K) + if self.original_ndim == 3: + undone = self.x.permute(0, 1, 4, 2, 3, 5).contiguous() + undone = undone.reshape(self.B, self.alignedM, self.alignedK) + undone = undone[0:self.B, 0:self.M, 0:self.K] + return undone.reshape(self.B, self.M, self.K) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + # Lazy import to avoid circular dependency + from torchao.swizzle.swizzle_ops import SWIZZLE_OPS_TABLE + if func in SWIZZLE_OPS_TABLE: + return SWIZZLE_OPS_TABLE[func](func, args, kwargs) + + def unwrap(e): + return e.unswizzle() if isinstance(e, SwizzleTensor) else e + + def wrap(e): + return SwizzleTensor(e) if isinstance(e, torch.Tensor) else e + + return tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))) + + # Do not force the SwizzleTensor type on the returned tensor + __torch_function__ = torch._C._disabled_torch_function_impl From 308e0c9791952d61649a2c2def4f01714ec16c28 Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Tue, 11 Feb 2025 00:34:43 +0000 Subject: [PATCH 02/11] add custom gemm op --- setup.py | 5 +- test/test_ops.py | 9 +- torchao/csrc/rocm/swizzle/swizzle.cpp | 457 ++++++++++++++++++++++++++ torchao/csrc/rocm/swizzle/swizzle.hip | 23 -- torchao/ops.py | 16 +- torchao/swizzle/swizzle_ops.py | 9 +- 6 files changed, 480 insertions(+), 39 deletions(-) create mode 100644 torchao/csrc/rocm/swizzle/swizzle.cpp delete mode 100644 torchao/csrc/rocm/swizzle/swizzle.hip diff --git a/setup.py b/setup.py index 8ac7b04756..a9d8e5c822 100644 --- a/setup.py +++ b/setup.py @@ -222,7 +222,7 @@ def get_extensions(): nvcc_args = [ "-O3" if not debug_mode else "-O0", "-t=0", - ], + ] rocm_args = ["-O3" if not debug_mode else "-O0"] extra_link_args = [] @@ -262,6 +262,9 @@ def get_extensions(): rocm_sources = list( glob.glob(os.path.join(extensions_rocm_dir, "**/*.hip"), recursive=True) ) + rocm_sources += list( + glob.glob(os.path.join(extensions_rocm_dir, "**/*.cpp"), recursive=True) + ) if use_cuda: sources += cuda_sources diff --git a/test/test_ops.py b/test/test_ops.py index 24c25dfdc8..5b380b28eb 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -618,7 +618,7 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, 
num_bits, group_size, mnk_fact @pytest.mark.skipif(not IS_ROCM, reason="ROCm not available") -def test_swizzle_stub(): +def test_swizzle_mm(): test_utils = [ "test_schema", "test_autograd_registration", @@ -629,11 +629,12 @@ def test_swizzle_stub(): if TORCH_VERSION_AT_LEAST_2_5: test_utils.append("test_aot_dispatch_dynamic") - t = torch.randint(0, 16, dtype=torch.int, size=(16,16), device="cuda") + mat1 = torch.randint(0, 16, dtype=torch.float, size=(16,32), device="cuda") + mat2 = torch.randint(0, 16, dtype=torch.float, size=(32,16), device="cuda") opcheck( - torch.ops.torchao.swizzle_stub, - (t,), + torch.ops.torchao.swizzle_mm, + (mat1, mat2), test_utils=test_utils, ) diff --git a/torchao/csrc/rocm/swizzle/swizzle.cpp b/torchao/csrc/rocm/swizzle/swizzle.cpp new file mode 100644 index 0000000000..902afb71f3 --- /dev/null +++ b/torchao/csrc/rocm/swizzle/swizzle.cpp @@ -0,0 +1,457 @@ +#include + +#include +#include +#include +#include +#include +#include +#include + +using at::Scalar; +using at::Tensor; +using at::TensorArg; +using c10::IntArrayRef; + +// +// copied from aten/src/ATen/cuda/CUDABlas.cpp +// +namespace { + +static hipblasOperation_t _cublasOpFromChar(char op) { + // NOLINTNEXTLINE(bugprone-switch-missing-default-case) + switch (op) { + case 'n': + case 'N': + return HIPBLAS_OP_N; + case 't': + case 'T': + return HIPBLAS_OP_T; + case 'c': + case 'C': + return HIPBLAS_OP_C; + } + TORCH_CHECK(false, + "_cublasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`"); +} + +static void _cublasAdjustLdLevel3( + char transa, + char transb, + int64_t m, + int64_t n, + int64_t k, + int64_t* lda, + int64_t* ldb, + int64_t* ldc) { + bool transa_ = ((transa != 'n') && (transa != 'N')); + bool transb_ = ((transb != 'n') && (transb != 'N')); + + // Note: leading dimensions generally are checked that they are > 0 + // and at least as big the result requires (even if the value won't + // be used). + if (n <= 1) + *ldc = std::max(m, 1); + + if (transa_) { + if (m <= 1) + *lda = std::max(k, 1); + } else { + if (k <= 1) + *lda = std::max(m, 1); + } + + if (transb_) { + if (k <= 1) + *ldb = std::max(n, 1); + } else { + if (n <= 1) + *ldb = std::max(k, 1); + } +} + +// Following the pattern of CuSparseDescriptor +// Defined here for now because this is the only place cublas_lt interface is +// used but can be moved to a header once cublas_lt interface is used in +// multiple places. 
+template +struct HipBlasLtDeleter { + void operator()(T* x) { + if (x != nullptr) { + TORCH_CUDABLAS_CHECK(destructor(x)); + } + } +}; + +template +class HipBlasLtDescriptor { + public: + T* descriptor() const { + return descriptor_.get(); + } + T* descriptor() { + return descriptor_.get(); + } + + protected: + std::unique_ptr> descriptor_; +}; + +class HipBlasLtMatmulDescriptor : public HipBlasLtDescriptor< + hipblasLtMatmulDescOpaque_t, + &hipblasLtMatmulDescDestroy> { + public: + HipBlasLtMatmulDescriptor( + hipblasComputeType_t compute_type, + hipDataType scale_type) { + hipblasLtMatmulDesc_t raw_descriptor = nullptr; + TORCH_CUDABLAS_CHECK( + hipblasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type)); + descriptor_.reset(raw_descriptor); + } + template + inline void setAttribute(hipblasLtMatmulDescAttributes_t attr, const T value) { + // NOLINTNEXTLINE(bugprone-sizeof-expression) + TORCH_CUDABLAS_CHECK(::hipblasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(value))); + } +}; + +class HipBlasLtMatrixLayout : public HipBlasLtDescriptor< + hipblasLtMatrixLayoutOpaque_t, + &hipblasLtMatrixLayoutDestroy> { + public: + HipBlasLtMatrixLayout( + hipDataType type, + uint64_t rows, + uint64_t cols, + int64_t ld, + bool t = false) { + hipblasLtMatrixLayout_t raw_descriptor = nullptr; + TORCH_CUDABLAS_CHECK( + hipblasLtMatrixLayoutCreate(&raw_descriptor, type, t ? cols : rows, t ? rows : cols, ld)); + descriptor_.reset(raw_descriptor); + } + template + inline void setAttribute(hipblasLtMatrixLayoutAttribute_t attr, const T value) { + TORCH_CUDABLAS_CHECK(::hipblasLtMatrixLayoutSetAttribute(descriptor(), attr, &value, sizeof(T))); + } +}; + +class HipBlasLtMatmulPreference : public HipBlasLtDescriptor< + hipblasLtMatmulPreferenceOpaque_t, + &hipblasLtMatmulPreferenceDestroy> { + public: + HipBlasLtMatmulPreference() { + hipblasLtMatmulPreference_t raw_descriptor = nullptr; + TORCH_CUDABLAS_CHECK(hipblasLtMatmulPreferenceCreate(&raw_descriptor)); + descriptor_.reset(raw_descriptor); + } + template + inline void setAttribute(hipblasLtMatmulPreferenceAttributes_t attr, const T value) { + TORCH_CUDABLAS_CHECK(::hipblasLtMatmulPreferenceSetAttribute(descriptor(), attr, &value, sizeof(T))); + } +}; + +static size_t _parseChosenWorkspaceSize() { + auto val = c10::utils::get_env("CUBLASLT_WORKSPACE_SIZE"); +#ifdef USE_ROCM + if (!val.has_value()) { + // accept either env var + val = c10::utils::get_env("HIPBLASLT_WORKSPACE_SIZE"); + } + size_t workspace_size = 76*1024; /* Use 76 MB for hipBLASLt */ +#else + size_t workspace_size = 1024; /* default size in KiB according to #73328 */ +#endif + + if (val.has_value()) { + try { + workspace_size = std::stoi(val.value()); + } catch(std::invalid_argument const& e) { + TORCH_WARN("invalid CUBLASLT_WORKSPACE_SIZE,", + " using default workspace size of ", workspace_size, " KiB."); + } catch(std::out_of_range const& e) { + TORCH_WARN("CUBLASLT_WORKSPACE_SIZE out of range,", + " using default workspace size of ", workspace_size, " KiB."); + } + } + return workspace_size * 1024; +} + +static size_t _getWorkspaceSize() { + static size_t workspace_size = _parseChosenWorkspaceSize(); + return workspace_size; +} + +} // namespace + +// +// copied from aten/src/ATen/native/cuda/Blas.cpp +// +namespace { + +// TODO: https://github.com/pytorch/pytorch/pull/59380#pullrequestreview-725310492 +c10::MaybeOwned inline resolve_conj_if_indicated(const Tensor& tensor, bool resolve_conj) { + if (resolve_conj && tensor.is_conj()) { + return 
c10::MaybeOwned::owned(tensor.resolve_conj()); + } else { + return c10::MaybeOwned::borrowed(tensor); + } +} + +c10::MaybeOwned inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor, bool transpose_result) { + if (tensor.is_non_overlapping_and_dense()) { // common case + transpose_tensor = tensor.is_contiguous(); + return resolve_conj_if_indicated(tensor, transpose_result ? transpose_tensor : !transpose_tensor); + } + IntArrayRef tensor_strides = tensor.strides(); + IntArrayRef tensor_sizes = tensor.sizes(); + if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max(1, tensor_sizes[0]))) { + transpose_tensor = false; + return resolve_conj_if_indicated(tensor, !transpose_result); + } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max(1, tensor_sizes[1]))) { + transpose_tensor = true; + return resolve_conj_if_indicated(tensor, transpose_result); + } else { + transpose_tensor = true; + return c10::MaybeOwned::owned(tensor.clone(at::MemoryFormat::Contiguous)); + } +} + +c10::MaybeOwned inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor) { + if (tensor.is_non_overlapping_and_dense()) { // common case + transpose_tensor = tensor.is_contiguous(); + return resolve_conj_if_indicated(tensor, true); + } + + IntArrayRef tensor_strides = tensor.strides(); + IntArrayRef tensor_sizes = tensor.sizes(); + if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max(1, tensor_sizes[0]))) { + transpose_tensor = false; + return resolve_conj_if_indicated(tensor, true); + } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max(1, tensor_sizes[1]))) { + transpose_tensor = true; + return resolve_conj_if_indicated(tensor, true); + } else { + transpose_tensor = true; + return c10::MaybeOwned::owned(tensor.clone(at::MemoryFormat::Contiguous)); + } +} + +struct cublasCommonArgs { + cublasCommonArgs(const Tensor& mat1, const Tensor& mat2, Tensor& c) { + bool transpose_result = false, transpose_mat1 = false, transpose_mat2 = false; + result = prepare_matrix_for_cublas(c, transpose_result); + mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_mat1, transpose_result); + matb = prepare_matrix_for_cublas(transpose_result ? mat1 : mat2, transpose_mat2, transpose_result); + auto mat1_sizes = mat1.sizes(); + auto mat2_sizes = mat2.sizes(); + if (transpose_result) { + transpose_mat1 = !transpose_mat1; + transpose_mat2 = !transpose_mat2; + mat1_sizes = mata->sizes(); + mat2_sizes = matb->sizes(); + } + + m = mat1_sizes[transpose_result ? 1 : 0]; + k = mat1_sizes[transpose_result ? 0 : 1]; + n = mat2_sizes[transpose_result ? 0 : 1]; + lda = mata->stride((transpose_mat1 == transpose_result) ? 1 : 0); + ldb = matb->stride((transpose_mat2 == transpose_result) ? 1 : 0); + result_ld = result->stride(transpose_result ? 0 : 1); + transa = transpose_mat1 ? mata->is_conj() ? 'c' : 't' : 'n'; + transb = transpose_mat2 ? matb->is_conj() ? 
'c' : 't' : 'n'; + } + char transa, transb; + int64_t m, n, k; + int64_t lda, ldb, result_ld; + c10::MaybeOwned mata, matb, result; +}; + +} // namespace + +template +inline void bgemm_hipblaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { + hipDataType abcType = HIP_R_32F; + hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F; + hipDataType scaleType = HIP_R_32F; + if constexpr (std::is_same_v) { + abcType = HIP_R_64F; + computeType = HIPBLAS_COMPUTE_64F; + scaleType = HIP_R_64F; + } else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v>) { + abcType = HIP_C_64F; + computeType = HIPBLAS_COMPUTE_64F; + scaleType = HIP_C_64F; + } else if constexpr (std::is_same_v>) { + abcType = HIP_C_32F; + scaleType = HIP_C_32F; + } else if constexpr (std::is_same_v) { + abcType = HIP_R_16F; + } else if constexpr (std::is_same_v) { + abcType = HIP_R_16BF; + } else { + static_assert(false && sizeof(Dtype), "at::cuda::blas::bgemm_internal_cublaslt: not implemented"); + } + + hipblasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle(); + hipblasOperation_t opa = _cublasOpFromChar(transa); + hipblasOperation_t opb = _cublasOpFromChar(transb); + _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); + + HipBlasLtMatmulDescriptor computeDesc(computeType, scaleType); + computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSA, opa); + computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, opb); + HipBlasLtMatrixLayout Adesc(abcType, m, k, lda, opa == HIPBLAS_OP_T); + HipBlasLtMatrixLayout Bdesc(abcType, k, n, ldb, opb == HIPBLAS_OP_T); + HipBlasLtMatrixLayout Cdesc(abcType, m, n, ldc); + + if (num_batches > 1) { + int num_batches_as_int = static_cast(num_batches); + Adesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, num_batches_as_int); + Bdesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, num_batches_as_int); + Cdesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, num_batches_as_int); + Adesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stridea); + Bdesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, strideb); + Cdesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stridec); + } + + HipBlasLtMatmulPreference preference; + // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind + // setting this to 1M. 
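+  // _getWorkspaceSize() reads CUBLASLT_WORKSPACE_SIZE (or, failing that,
+  // HIPBLASLT_WORKSPACE_SIZE) in KiB and defaults to 76 MiB for hipBLASLt;
+  // see _parseChosenWorkspaceSize() above.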
+ size_t workspaceSize = _getWorkspaceSize(); + preference.setAttribute(HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize); + +#ifndef USE_ROCM + uint32_t a_alignment = _getAlignment(reinterpret_cast(a)); + uint32_t b_alignment = _getAlignment(reinterpret_cast(b)); + uint32_t c_alignment = _getAlignment(reinterpret_cast(c)); + preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, a_alignment); + preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, b_alignment); + preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment); +#endif + + auto workspace = at::empty(static_cast(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA)); + + hipblasLtMatmulHeuristicResult_t heuristicResult = {}; + int returnedResult = 0; + TORCH_CUDABLAS_CHECK(hipblasLtMatmulAlgoGetHeuristic( + ltHandle, + computeDesc.descriptor(), + Adesc.descriptor(), + Bdesc.descriptor(), + Cdesc.descriptor(), + Cdesc.descriptor(), + preference.descriptor(), + 1, + &heuristicResult, + &returnedResult)); + if (returnedResult == 0) { + TORCH_CUDABLAS_CHECK(HIPBLAS_STATUS_NOT_SUPPORTED); + } + + hipblasStatus_t cublasStatus = hipblasLtMatmul( + ltHandle, + computeDesc.descriptor(), + &alpha, + a, + Adesc.descriptor(), + b, + Bdesc.descriptor(), + &beta, + c, + Cdesc.descriptor(), + c, + Cdesc.descriptor(), + &heuristicResult.algo, + workspace.mutable_data_ptr(), + workspaceSize, + at::hip::getCurrentHIPStreamMasqueradingAsCUDA()); + TORCH_CHECK( + cublasStatus == HIPBLAS_STATUS_SUCCESS, + "CUDA error: ", + at::cuda::blas::_cublasGetErrorEnum(cublasStatus), + " when calling hipblasLtMatmul with transpose_mat1 ", + (opa == HIPBLAS_OP_T), + " transpose_mat2 ", + (opb == HIPBLAS_OP_T), + " m ", + m, + " n ", + n, + " k ", + k, + " lda ", + lda, + " ldb ", + ldb, + " ldc ", + ldc, + " abcType ", + abcType, + " computeType ", + computeType, + " scaleType ", + scaleType); +} +template +inline void gemm_hipblaslt(CUDABLAS_GEMM_ARGTYPES(Dtype)) { + // forward to bgemm implementation but set strides and batches to 0 + bgemm_hipblaslt(transa, transb, m, n, k, alpha, a, lda, 0, b, ldb, 0, beta, c, ldc, 0, 0); +} + + +Tensor swizzle_mm(const Tensor& mat1, const Tensor& mat2) { + TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); + TORCH_CHECK( + mat1.dtype() == mat2.dtype(), + "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype() + ); + + // NOLINTNEXTLINE(*c-array*) + TensorArg targs[]{{mat1, "mat1", 0}, {mat2, "mat2", 1}}; + checkAllSameGPU(__func__, targs); + + Tensor meta_mat1 = mat1.to("meta"); + Tensor meta_mat2 = mat2.to("meta"); + Tensor meta_result = at::mm(mat1, mat2); + Tensor result = at::empty_like(meta_result, mat1.device()); + at::ScalarType scalar_type = result.scalar_type(); + + cublasCommonArgs args(mat1, mat2, result); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + scalar_type, + "addmm_cuda", + [&] { + using opmath_t = at::opmath_type; + opmath_t alpha_val = opmath_t(1.0); + opmath_t beta_val = opmath_t(0.0); + const scalar_t* mat1_ptr = args.mata->const_data_ptr(); + const scalar_t* mat2_ptr = args.matb->const_data_ptr(); + scalar_t* result_ptr = args.result->mutable_data_ptr(); + gemm_hipblaslt( + args.transa, + args.transb, + args.m, + args.n, + args.k, + alpha_val, + mat1_ptr, + args.lda, + mat2_ptr, + args.ldb, + beta_val, + result_ptr, + args.result_ld); + }); + + return result; +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + 
m.impl("torchao::swizzle_mm", &swizzle_mm); +} diff --git a/torchao/csrc/rocm/swizzle/swizzle.hip b/torchao/csrc/rocm/swizzle/swizzle.hip deleted file mode 100644 index 3423370093..0000000000 --- a/torchao/csrc/rocm/swizzle/swizzle.hip +++ /dev/null @@ -1,23 +0,0 @@ -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -using at::Scalar; -using at::Tensor; -using at::TensorArg; -using c10::IntArrayRef; - -Tensor swizzle_stub(const Tensor& w) { - return at::zeros_like(w); -} - -TORCH_LIBRARY_IMPL(torchao, CUDA, m) { - m.impl("torchao::swizzle_stub", &swizzle_stub); -} diff --git a/torchao/ops.py b/torchao/ops.py index 037bda4f56..73d85581f1 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -26,7 +26,7 @@ "rowwise_scaled_linear_cutlass_s8s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor" ) lib.define( - "swizzle_stub(Tensor input) -> Tensor" + "swizzle_mm(Tensor mat1, Tensor mat2) -> Tensor" ) @@ -597,16 +597,16 @@ def _( return input_scale.new_empty(*input.shape[:-1], weight.shape[0]) -def swizzle_stub(w: Tensor) -> Tensor: +def swizzle_mm(mat1: Tensor, mat2: Tensor) -> Tensor: """ - Returns empty tensor like w. + Similar to torch.mm but Tensor inputs can be SwizzleTensor instances. """ - return torch.ops.torchao.swizzle_stub.default( - w=w + return torch.ops.torchao.swizzle_mm.default( + mat1, mat2 ) -@register_custom_op("torchao::swizzle_stub") -def _(w: Tensor) -> Tensor: - return torch.zeros_like(w) +@register_custom_op("torchao::swizzle_mm") +def _(mat1: Tensor, mat2: Tensor) -> Tensor: + return mat1.new_empty(mat1.shape[0], mat2.shape[1]) diff --git a/torchao/swizzle/swizzle_ops.py b/torchao/swizzle/swizzle_ops.py index a7e7640ef9..c8a7dc723d 100644 --- a/torchao/swizzle/swizzle_ops.py +++ b/torchao/swizzle/swizzle_ops.py @@ -7,6 +7,7 @@ import torch +import torchao.ops from torchao.swizzle.swizzle_tensor import SwizzleTensor aten = torch.ops.aten @@ -31,7 +32,10 @@ def swizzle_mm(aten_op, args, kwargs=None): a = a.unswizzle() if isinstance(a, SwizzleTensor) else a b = b.unswizzle() if isinstance(b, SwizzleTensor) else b - tensor_out = aten_op(a, b, **kwargs) + if torch.is_floating_point(a): + tensor_out = torchao.ops.swizzle_mm(a, b) + else: + tensor_out = aten_op(a, b, **kwargs) return tensor_out @@ -42,8 +46,7 @@ def swizzle_mm(aten_op, args, kwargs=None): a = a.unswizzle() if isinstance(a, SwizzleTensor) else a b = b.unswizzle() if isinstance(b, SwizzleTensor) else b - tensor_out = aten_op(a, b, **kwargs) - return tensor_out + return aten_op(a, b, **kwargs) @implements([aten.addmm.default]) From 75b69035e5773522026028aee5fbca5d598b3cbc Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Tue, 11 Feb 2025 20:53:03 +0000 Subject: [PATCH 03/11] pass through swizzled --- setup.py | 17 +++++++++++++++++ torchao/csrc/rocm/swizzle/swizzle.cpp | 23 +++++++++++++++++------ torchao/ops.py | 8 ++++---- torchao/swizzle/swizzle_ops.py | 16 ++++++++++++---- torchao/swizzle/swizzle_tensor.py | 6 ++++++ 5 files changed, 56 insertions(+), 14 deletions(-) diff --git a/setup.py b/setup.py index a9d8e5c822..0acfe3aa95 100644 --- a/setup.py +++ b/setup.py @@ -250,6 +250,23 @@ def get_extensions(): extra_compile_args["nvcc"].append("-g") extra_link_args.append("/DEBUG") + if use_rocm: + # naive search for hipblalst.h, if any found contain HIPBLASLT_ORDER_COL16 + found = False + print("ROCM_HOME", ROCM_HOME) + hipblaslt_headers = list(glob.glob(os.path.join(ROCM_HOME, "include", "hipblaslt", "hipblaslt.h"))) + 
print("hipblaslt_headers", hipblaslt_headers) + for header in hipblaslt_headers: + with open(header) as f: + if "HIPBLASLT_ORDER_COL16" in f.read(): + found = True + break + if found: + extra_compile_args["cxx"].append("-DHIPBLASLT_HAS_ORDER_COL16") + print("hipblaslt found extended col order enums") + else: + print("hipblaslt does not have extended col order enums") + this_dir = os.path.dirname(os.path.curdir) extensions_dir = os.path.join(this_dir, "torchao", "csrc") sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True)) diff --git a/torchao/csrc/rocm/swizzle/swizzle.cpp b/torchao/csrc/rocm/swizzle/swizzle.cpp index 902afb71f3..294cd19e05 100644 --- a/torchao/csrc/rocm/swizzle/swizzle.cpp +++ b/torchao/csrc/rocm/swizzle/swizzle.cpp @@ -271,7 +271,7 @@ struct cublasCommonArgs { } // namespace template -inline void bgemm_hipblaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { +inline void bgemm_hipblaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype), bool mat1_is_swizzled, bool mat2_is_swizzled) { hipDataType abcType = HIP_R_32F; hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F; hipDataType scaleType = HIP_R_32F; @@ -306,6 +306,14 @@ inline void bgemm_hipblaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { HipBlasLtMatrixLayout Adesc(abcType, m, k, lda, opa == HIPBLAS_OP_T); HipBlasLtMatrixLayout Bdesc(abcType, k, n, ldb, opb == HIPBLAS_OP_T); HipBlasLtMatrixLayout Cdesc(abcType, m, n, ldc); +#ifdef HIPBLASLT_HAS_ORDER_COL16 + if (mat1_is_swizzled) { + Adesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_ORDER, HIPBLASLT_ORDER_COL16_4R8); + } + if (mat2_is_swizzled) { + Bdesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_ORDER, HIPBLASLT_ORDER_COL16_4R8); + } +#endif if (num_batches > 1) { int num_batches_as_int = static_cast(num_batches); @@ -395,15 +403,16 @@ inline void bgemm_hipblaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { " scaleType ", scaleType); } + + template -inline void gemm_hipblaslt(CUDABLAS_GEMM_ARGTYPES(Dtype)) { +inline void gemm_hipblaslt(CUDABLAS_GEMM_ARGTYPES(Dtype), bool mat1_is_swizzled, bool mat2_is_swizzled) { // forward to bgemm implementation but set strides and batches to 0 - bgemm_hipblaslt(transa, transb, m, n, k, alpha, a, lda, 0, b, ldb, 0, beta, c, ldc, 0, 0); + bgemm_hipblaslt(transa, transb, m, n, k, alpha, a, lda, 0, b, ldb, 0, beta, c, ldc, 0, 0, mat1_is_swizzled, mat2_is_swizzled); } -Tensor swizzle_mm(const Tensor& mat1, const Tensor& mat2) { - TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); +Tensor swizzle_mm(const Tensor& mat1, const Tensor& mat2, bool mat1_is_swizzled, bool mat2_is_swizzled) { TORCH_CHECK( mat1.dtype() == mat2.dtype(), "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype() @@ -446,7 +455,9 @@ Tensor swizzle_mm(const Tensor& mat1, const Tensor& mat2) { args.ldb, beta_val, result_ptr, - args.result_ld); + args.result_ld, + mat1_is_swizzled, + mat2_is_swizzled); }); return result; diff --git a/torchao/ops.py b/torchao/ops.py index 73d85581f1..179732776a 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -26,7 +26,7 @@ "rowwise_scaled_linear_cutlass_s8s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor" ) lib.define( - "swizzle_mm(Tensor mat1, Tensor mat2) -> Tensor" + "swizzle_mm(Tensor mat1, Tensor mat2, bool mat1_is_swizzled, bool mat2_is_swizzled) -> Tensor" ) @@ -597,16 +597,16 @@ def _( return input_scale.new_empty(*input.shape[:-1], weight.shape[0]) -def swizzle_mm(mat1: Tensor, mat2: Tensor) -> Tensor: +def swizzle_mm(mat1: Tensor, mat2: Tensor, 
mat1_is_swizzled: bool, mat2_is_swizzled: bool) -> Tensor: """ Similar to torch.mm but Tensor inputs can be SwizzleTensor instances. """ return torch.ops.torchao.swizzle_mm.default( - mat1, mat2 + mat1, mat2, mat1_is_swizzled, mat2_is_swizzled ) @register_custom_op("torchao::swizzle_mm") -def _(mat1: Tensor, mat2: Tensor) -> Tensor: +def _(mat1: Tensor, mat2: Tensor, mat1_is_swizzled: bool, mat2_is_swizzled: bool) -> Tensor: return mat1.new_empty(mat1.shape[0], mat2.shape[1]) diff --git a/torchao/swizzle/swizzle_ops.py b/torchao/swizzle/swizzle_ops.py index c8a7dc723d..1128115633 100644 --- a/torchao/swizzle/swizzle_ops.py +++ b/torchao/swizzle/swizzle_ops.py @@ -30,11 +30,19 @@ def swizzle_mm(aten_op, args, kwargs=None): a = args[0] b = args[1] - a = a.unswizzle() if isinstance(a, SwizzleTensor) else a - b = b.unswizzle() if isinstance(b, SwizzleTensor) else b - if torch.is_floating_point(a): - tensor_out = torchao.ops.swizzle_mm(a, b) + if torch.is_floating_point(a) and torch.is_floating_point(b): + a_is_swizzled = False + b_is_swizzled = False + if isinstance(a, SwizzleTensor): + a = a.as_tensor() + a_is_swizzled = True + if isinstance(b, SwizzleTensor): + b = b.as_tensor() + b_is_swizzled = True + tensor_out = torchao.ops.swizzle_mm(a, b, a_is_swizzled, b_is_swizzled) else: + a = a.unswizzle() if isinstance(a, SwizzleTensor) else a + b = b.unswizzle() if isinstance(b, SwizzleTensor) else b tensor_out = aten_op(a, b, **kwargs) return tensor_out diff --git a/torchao/swizzle/swizzle_tensor.py b/torchao/swizzle/swizzle_tensor.py index d06574d994..0b1cb18643 100644 --- a/torchao/swizzle/swizzle_tensor.py +++ b/torchao/swizzle/swizzle_tensor.py @@ -69,6 +69,12 @@ def unswizzle(self): undone = undone[0:self.B, 0:self.M, 0:self.K] return undone.reshape(self.B, self.M, self.K) + def as_tensor(self): + if self.original_ndim == 2: + return self.x.reshape(self.alignedM, self.alignedK) + if self.original_ndim == 3: + return self.x.reshape(self.B, self.alignedM, self.alignedK) + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): # Lazy import to avoid circular dependency From 5a1080363aefb60aef40a894196d0d88737e251a Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Wed, 5 Mar 2025 23:37:38 +0000 Subject: [PATCH 04/11] copy paste bug causing extra matmul to execute --- torchao/csrc/rocm/swizzle/swizzle.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/csrc/rocm/swizzle/swizzle.cpp b/torchao/csrc/rocm/swizzle/swizzle.cpp index 294cd19e05..b4af465e15 100644 --- a/torchao/csrc/rocm/swizzle/swizzle.cpp +++ b/torchao/csrc/rocm/swizzle/swizzle.cpp @@ -424,7 +424,7 @@ Tensor swizzle_mm(const Tensor& mat1, const Tensor& mat2, bool mat1_is_swizzled, Tensor meta_mat1 = mat1.to("meta"); Tensor meta_mat2 = mat2.to("meta"); - Tensor meta_result = at::mm(mat1, mat2); + Tensor meta_result = at::mm(meta_mat1, meta_mat2); Tensor result = at::empty_like(meta_result, mat1.device()); at::ScalarType scalar_type = result.scalar_type(); From ae9ca6ae57ae651b30dcc0665fef99e6a16ffbb0 Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Tue, 11 Mar 2025 19:54:39 +0000 Subject: [PATCH 05/11] correct transpose and permute logic --- torchao/csrc/rocm/swizzle/swizzle.cpp | 11 ++++++++--- torchao/swizzle/swizzle_tensor.py | 7 +++++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/torchao/csrc/rocm/swizzle/swizzle.cpp b/torchao/csrc/rocm/swizzle/swizzle.cpp index b4af465e15..c94dc4f86f 100644 --- a/torchao/csrc/rocm/swizzle/swizzle.cpp +++ 
b/torchao/csrc/rocm/swizzle/swizzle.cpp @@ -240,7 +240,8 @@ c10::MaybeOwned inline prepare_matrix_for_cublas(const Tensor& tensor, b struct cublasCommonArgs { cublasCommonArgs(const Tensor& mat1, const Tensor& mat2, Tensor& c) { - bool transpose_result = false, transpose_mat1 = false, transpose_mat2 = false; + bool transpose_mat1 = false, transpose_mat2 = false; + transpose_result = false; result = prepare_matrix_for_cublas(c, transpose_result); mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_mat1, transpose_result); matb = prepare_matrix_for_cublas(transpose_result ? mat1 : mat2, transpose_mat2, transpose_result); @@ -266,6 +267,7 @@ struct cublasCommonArgs { int64_t m, n, k; int64_t lda, ldb, result_ld; c10::MaybeOwned mata, matb, result; + bool transpose_result; }; } // namespace @@ -325,6 +327,9 @@ inline void bgemm_hipblaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype), bool mat1_is_swizzle Cdesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stridec); } + hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT; + computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, epilogue); + HipBlasLtMatmulPreference preference; // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind // setting this to 1M. @@ -456,8 +461,8 @@ Tensor swizzle_mm(const Tensor& mat1, const Tensor& mat2, bool mat1_is_swizzled, beta_val, result_ptr, args.result_ld, - mat1_is_swizzled, - mat2_is_swizzled); + args.transpose_result ? mat2_is_swizzled : mat1_is_swizzled, + args.transpose_result ? mat1_is_swizzled : mat2_is_swizzled); }); return result; diff --git a/torchao/swizzle/swizzle_tensor.py b/torchao/swizzle/swizzle_tensor.py index 0b1cb18643..d86f7a9982 100644 --- a/torchao/swizzle/swizzle_tensor.py +++ b/torchao/swizzle/swizzle_tensor.py @@ -27,6 +27,8 @@ def __new__( return torch.Tensor._make_subclass(cls, wrapper) def __init__(self, original): + # we pass in weights.T, but permute should be done on correct data + original = original.T assert original.ndim == 2 or original.ndim == 3 # (M, K) or (B, M, K) if original.ndim == 2: M, K = original.shape @@ -70,10 +72,11 @@ def unswizzle(self): return undone.reshape(self.B, self.M, self.K) def as_tensor(self): + # note the transpose because this causes col major hipblaslt op to be TN if self.original_ndim == 2: - return self.x.reshape(self.alignedM, self.alignedK) + return self.x.reshape(self.alignedM, self.alignedK).T if self.original_ndim == 3: - return self.x.reshape(self.B, self.alignedM, self.alignedK) + return self.x.reshape(self.B, self.alignedM, self.alignedK).T @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): From 2dff63b9bd1bcf88814cea55bcc2ce85a210d5c3 Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Wed, 12 Mar 2025 16:38:13 +0000 Subject: [PATCH 06/11] swizzle.cpp is rocm-only, remove #ifndef USE_ROCM --- torchao/csrc/rocm/swizzle/swizzle.cpp | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/torchao/csrc/rocm/swizzle/swizzle.cpp b/torchao/csrc/rocm/swizzle/swizzle.cpp index c94dc4f86f..dbfa06fab5 100644 --- a/torchao/csrc/rocm/swizzle/swizzle.cpp +++ b/torchao/csrc/rocm/swizzle/swizzle.cpp @@ -154,15 +154,11 @@ class HipBlasLtMatmulPreference : public HipBlasLtDescriptor< static size_t _parseChosenWorkspaceSize() { auto val = c10::utils::get_env("CUBLASLT_WORKSPACE_SIZE"); -#ifdef USE_ROCM if (!val.has_value()) { // accept either env var val = c10::utils::get_env("HIPBLASLT_WORKSPACE_SIZE"); } size_t workspace_size = 76*1024; /* Use 76 MB for hipBLASLt */ -#else - 
size_t workspace_size = 1024; /* default size in KiB according to #73328 */ -#endif if (val.has_value()) { try { @@ -336,15 +332,6 @@ inline void bgemm_hipblaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype), bool mat1_is_swizzle size_t workspaceSize = _getWorkspaceSize(); preference.setAttribute(HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize); -#ifndef USE_ROCM - uint32_t a_alignment = _getAlignment(reinterpret_cast(a)); - uint32_t b_alignment = _getAlignment(reinterpret_cast(b)); - uint32_t c_alignment = _getAlignment(reinterpret_cast(c)); - preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, a_alignment); - preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_B_BYTES, b_alignment); - preference.setAttribute(CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, c_alignment); -#endif - auto workspace = at::empty(static_cast(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA)); hipblasLtMatmulHeuristicResult_t heuristicResult = {}; From fe461afbb79810a377e47e77ffa22f4c6c2ff712 Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Thu, 13 Mar 2025 22:11:02 +0000 Subject: [PATCH 07/11] transpose is shallow, don't unswizzle/swizzle --- test/test_ops.py | 2 +- torchao/swizzle/swizzle_ops.py | 7 +++++ torchao/swizzle/swizzle_tensor.py | 45 ++++++++++++++++++++++++++----- 3 files changed, 47 insertions(+), 7 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 35b23df3ce..f5c933e13f 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -629,7 +629,7 @@ def test_swizzle_mm(): opcheck( torch.ops.torchao.swizzle_mm, - (mat1, mat2), + (mat1, mat2, False, False), test_utils=test_utils, ) diff --git a/torchao/swizzle/swizzle_ops.py b/torchao/swizzle/swizzle_ops.py index 1128115633..390c0c6ba3 100644 --- a/torchao/swizzle/swizzle_ops.py +++ b/torchao/swizzle/swizzle_ops.py @@ -67,3 +67,10 @@ def swizzle_addmm(aten_op, args, kwargs=None): return aten_op(bias, a, b, args[3:], **kwargs) +@implements([aten.permute.default]) +def swizzle_permute(aten_op, args, kwargs=None): + tensor = args[0] + dims = args[1] + if len(dims) == 2 and dims[0] == 1 and dims[1] == 0: + return tensor.shallow_transpose() + return aten_op(tensor.unswizzle(), dims) diff --git a/torchao/swizzle/swizzle_tensor.py b/torchao/swizzle/swizzle_tensor.py index d86f7a9982..65bc7192b8 100644 --- a/torchao/swizzle/swizzle_tensor.py +++ b/torchao/swizzle/swizzle_tensor.py @@ -22,14 +22,16 @@ class SwizzleTensor(torch.Tensor): def __new__( cls, original: torch.Tensor, + shallow: bool = False, ): wrapper = torch.empty_like(original, device="meta") return torch.Tensor._make_subclass(cls, wrapper) - def __init__(self, original): - # we pass in weights.T, but permute should be done on correct data - original = original.T - assert original.ndim == 2 or original.ndim == 3 # (M, K) or (B, M, K) + def __init__(self, original, shallow=False): + if shallow: + return + #assert original.ndim == 2 or original.ndim == 3 # (M, K) or (B, M, K) + assert original.ndim == 2, "SwizzleTensor only supports ndim 2" if original.ndim == 2: M, K = original.shape B = 0 @@ -55,6 +57,7 @@ def __init__(self, original): self.paddedM = paddedM self.paddedK = paddedK self.original_ndim = original.ndim + self.is_transposed = False def __repr__(self): return f"{self.__class__.__name__}(original={self.unswizzle()})" @@ -74,9 +77,39 @@ def unswizzle(self): def as_tensor(self): # note the transpose because this causes col major hipblaslt op to be TN if self.original_ndim == 2: - return self.x.reshape(self.alignedM, self.alignedK).T + tmp = 
self.x.reshape(self.alignedM, self.alignedK) + if self.is_transposed: + tmp = tmp.T + return tmp if self.original_ndim == 3: - return self.x.reshape(self.B, self.alignedM, self.alignedK).T + tmp = self.x.reshape(self.B, self.alignedM, self.alignedK) + if self.is_transposed: + tmp = tmp.T + return tmp + + def shallow_transpose(self): + shape = (self.M, self.K) if self.original_ndim == 2 else (self.B, self.M, self.K), + new_obj = SwizzleTensor( + torch.empty(*shape, dtype=self.dtype, layout=self.layout, device="meta"), + True) + new_obj.x = self.x + new_obj.B = self.B + new_obj.M = self.M + new_obj.K = self.K + new_obj.alignedM = self.alignedM + new_obj.alignedK = self.alignedK + new_obj.paddedM = self.paddedM + new_obj.paddedK = self.paddedK + new_obj.original_ndim = self.original_ndim + new_obj.is_transposed = not self.is_transposed + return new_obj + + @property + def shape(self): + return torch.Size((self.K, self.M) if self.is_transposed else (self.M, self.K)) + + def stride(self): + return (1, self.K) if self.is_transposed else (self.K, 1) @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): From b087d9206a9ec9230a03da4f2223479e77f0035c Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Fri, 14 Mar 2025 17:11:04 +0000 Subject: [PATCH 08/11] add fp8 swizzle --- torchao/csrc/rocm/swizzle/swizzle.cpp | 491 ++++++++++++++++++++++++-- torchao/ops.py | 21 +- torchao/swizzle/swizzle_ops.py | 22 ++ torchao/swizzle/swizzle_tensor.py | 17 +- 4 files changed, 523 insertions(+), 28 deletions(-) diff --git a/torchao/csrc/rocm/swizzle/swizzle.cpp b/torchao/csrc/rocm/swizzle/swizzle.cpp index dbfa06fab5..ae7b7b4e2c 100644 --- a/torchao/csrc/rocm/swizzle/swizzle.cpp +++ b/torchao/csrc/rocm/swizzle/swizzle.cpp @@ -1,17 +1,25 @@ #include +#include #include #include #include #include +#include #include +#include +#include +#include #include #include using at::Scalar; using at::Tensor; using at::TensorArg; +using c10::kFloat; +using c10::ScalarType; using c10::IntArrayRef; +using at::cuda::ScalarTypeToCudaDataType; // // copied from aten/src/ATen/cuda/CUDABlas.cpp @@ -179,6 +187,19 @@ static size_t _getWorkspaceSize() { return workspace_size; } +static bool _scaled_mm_is_fnuz() { + auto dprops = at::cuda::getCurrentDeviceProperties(); + std::string device_arch = dprops->gcnArchName; + static const std::vector archs = {"gfx940", "gfx941", "gfx942"}; + for (std::string arch : archs) { + size_t substring = device_arch.find(arch); + if (substring != std::string::npos) { + return true; + } + } + return false; +} + } // namespace // @@ -235,37 +256,143 @@ c10::MaybeOwned inline prepare_matrix_for_cublas(const Tensor& tensor, b } struct cublasCommonArgs { - cublasCommonArgs(const Tensor& mat1, const Tensor& mat2, Tensor& c) { - bool transpose_mat1 = false, transpose_mat2 = false; - transpose_result = false; + cublasCommonArgs( + const Tensor& mat1, + const Tensor& mat2, + bool swizzle1, + bool swizzle2, + Tensor& c, + const std::optional& scale_a = std::nullopt, + const std::optional& scale_b = std::nullopt, + const std::optional& scale_result = std::nullopt) { + bool transpose_result = false, transpose_a = false, transpose_b = false; result = prepare_matrix_for_cublas(c, transpose_result); - mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_mat1, transpose_result); - matb = prepare_matrix_for_cublas(transpose_result ? 
mat1 : mat2, transpose_mat2, transpose_result); - auto mat1_sizes = mat1.sizes(); - auto mat2_sizes = mat2.sizes(); + mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_a, transpose_result); + matb = prepare_matrix_for_cublas(transpose_result ? mat1 : mat2, transpose_b, transpose_result); + + // Handle scale tensors if provided + if (scale_a && scale_b) { + // By default since we return in row-major we run the gemm + // as B.T @ A.T, check transpose_result to determine if we flip the scales + scale_mata_ptr = transpose_result ? scale_b->data_ptr() : scale_a->data_ptr(); + scale_mata_dtype = transpose_result ? scale_b->scalar_type() : scale_a->scalar_type(); + scale_matb_ptr = transpose_result ? scale_a->data_ptr() : scale_b->data_ptr(); + scale_matb_dtype = transpose_result ? scale_a->scalar_type() : scale_b->scalar_type(); + } + + if (scale_result) { + scale_result_ptr = scale_result->data_ptr(); + scale_result_dtype = scale_result->scalar_type(); + } + + // Update transpose flags if (transpose_result) { - transpose_mat1 = !transpose_mat1; - transpose_mat2 = !transpose_mat2; - mat1_sizes = mata->sizes(); - mat2_sizes = matb->sizes(); + transpose_a = !transpose_a; + transpose_b = !transpose_b; } - m = mat1_sizes[transpose_result ? 1 : 0]; - k = mat1_sizes[transpose_result ? 0 : 1]; - n = mat2_sizes[transpose_result ? 0 : 1]; - lda = mata->stride((transpose_mat1 == transpose_result) ? 1 : 0); - ldb = matb->stride((transpose_mat2 == transpose_result) ? 1 : 0); + auto sizes_a = mata->sizes(); + auto sizes_b = matb->sizes(); + + m = sizes_a[transpose_result ? 1 : 0]; + k = sizes_a[transpose_result ? 0 : 1]; + n = sizes_b[transpose_result ? 0 : 1]; + lda = mata->stride((transpose_a == transpose_result) ? 1 : 0); + ldb = matb->stride((transpose_b == transpose_result) ? 1 : 0); result_ld = result->stride(transpose_result ? 0 : 1); - transa = transpose_mat1 ? mata->is_conj() ? 'c' : 't' : 'n'; - transb = transpose_mat2 ? matb->is_conj() ? 'c' : 't' : 'n'; + transa = transpose_a ? mata->is_conj() ? 'c' : 't' : 'n'; + transb = transpose_b ? matb->is_conj() ? 'c' : 't' : 'n'; + + mata_is_swizzled = transpose_result ? swizzle2 : swizzle1; + matb_is_swizzled = transpose_result ? 
swizzle1 : swizzle2; } + + // Matrix members char transa, transb; int64_t m, n, k; int64_t lda, ldb, result_ld; c10::MaybeOwned mata, matb, result; - bool transpose_result; + + // Scale members + void* scale_mata_ptr = nullptr; + void* scale_matb_ptr = nullptr; + void* scale_result_ptr = nullptr; + std::optional scale_mata_dtype; + std::optional scale_matb_dtype; + std::optional scale_result_dtype; + + // swizzle members + bool mata_is_swizzled; + bool matb_is_swizzled; }; +enum class ScalingType { + TensorWise, + RowWise, + Error +}; + +ScalingType get_scaling_type( + const at::Tensor& scale_a, + const at::Tensor& scale_b, + int64_t dim_m, + int64_t dim_n) { + // Both Per-Tensor and Row-wise scaling expect fp32 tensors + TORCH_CHECK( + scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat, + "Both scale_a and scale_b must be float (fp32) tensors."); + + // Check the singluar scale case for per-tensor scaling + if (scale_a.numel() == 1 && scale_b.numel() == 1) { + return ScalingType::TensorWise; + } + + // For non-TensorWise scaling, enforce 2D input tensors + TORCH_CHECK( + scale_a.dim() == 2 && scale_b.dim() == 2, + "For non-TensorWise scaling, scale tensors must be 2-dimensional, " + "but got scale_a.dim()=", + scale_a.dim(), + " and scale_b.dim()=", + scale_b.dim()); + + // Check for RowWise scaling + if (scale_a.size(0) == dim_m && scale_a.size(1) == 1 && + scale_b.size(0) == 1 && scale_b.size(1) == dim_n) { +#if defined(HIPBLASLT_VEC_EXT) + TORCH_CHECK( + scale_a.is_contiguous() && scale_b.is_contiguous(), + "Both scale_a and scale_b must be contiguous for RowWise scaling."); + return ScalingType::RowWise; +#else + TORCH_CHECK(false, "Per-row scaling is not supported for this platform!"); + return ScalingType::Error; +#endif + } + + // If we reach here, the input doesn't match any valid scaling type + TORCH_CHECK( + false, + "Invalid scaling configuration. For TensorWise scaling, both scales should be scalar. " + "For RowWise scaling, scale_a should be (", + dim_m, + ", 1) and scale_b should be (1, ", + dim_n, + "). " + "Got scale_a.size()=(", + scale_a.size(0), + ", ", + scale_a.size(1), + ") and ", + "scale_b.size()=(", + scale_b.size(0), + ", ", + scale_b.size(1), + ")"); + + return ScalingType::Error; +} + } // namespace template @@ -420,7 +547,7 @@ Tensor swizzle_mm(const Tensor& mat1, const Tensor& mat2, bool mat1_is_swizzled, Tensor result = at::empty_like(meta_result, mat1.device()); at::ScalarType scalar_type = result.scalar_type(); - cublasCommonArgs args(mat1, mat2, result); + cublasCommonArgs args(mat1, mat2, mat1_is_swizzled, mat2_is_swizzled, result); AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, @@ -448,13 +575,333 @@ Tensor swizzle_mm(const Tensor& mat1, const Tensor& mat2, bool mat1_is_swizzled, beta_val, result_ptr, args.result_ld, - args.transpose_result ? mat2_is_swizzled : mat1_is_swizzled, - args.transpose_result ? 
mat1_is_swizzled : mat2_is_swizzled); + args.mata_is_swizzled, + args.matb_is_swizzled); }); return result; } +void _scaled_gemm( + char transa, + char transb, + int64_t m, + int64_t n, + int64_t k, + const void* mat1_ptr, + const void* mat1_scale_ptr, + int64_t mat1_ld, + ScalarType mat1_dtype, + ScalarType mat1_scale_dtype, + bool mat1_is_swizzled, + const void* mat2_ptr, + const void* mat2_scale_ptr, + int64_t mat2_ld, + ScalarType mat2_dtype, + ScalarType mat2_scale_dtype, + bool mat2_is_swizzled, + const void* bias_ptr, + ScalarType bias_dtype, + void* result_ptr, + const void *result_scale_ptr, + int64_t result_ld, + ScalarType result_dtype, + bool use_rowwise) { + const auto computeType = HIPBLAS_COMPUTE_32F; + const auto scaleType = HIP_R_32F; + const float alpha_val = 1.0; + const float beta_val = 0.0; + HipBlasLtMatmulDescriptor computeDesc(computeType, scaleType); + computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSA, _cublasOpFromChar(transa)); + computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb)); + hipblasLtMatmulDescAttributes_t matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER; + hipblasLtMatmulDescAttributes_t matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER; +#if defined(HIPBLASLT_VEC_EXT) + if (use_rowwise) { + matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT; + matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT; + } +#else + // rowwise isn't supported using cublaslt or older hipblaslt + TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt"); +#endif + computeDesc.setAttribute(matmulDescA, mat1_scale_ptr); + computeDesc.setAttribute(matmulDescB, mat2_scale_ptr); + if (result_scale_ptr != nullptr) { + computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); + } + HipBlasLtMatrixLayout Adesc(ScalarTypeToCudaDataType(mat1_dtype), m, k, mat1_ld, transa == 't'); + HipBlasLtMatrixLayout Bdesc(ScalarTypeToCudaDataType(mat2_dtype), k, n, mat2_ld, transb == 't'); + // Cdesc is unused, beta is 0. But hipblaslt needs this set to something reasonable. 
+  HipBlasLtMatrixLayout Cdesc(ScalarTypeToCudaDataType(result_dtype), m, n, result_ld);
+  HipBlasLtMatrixLayout Ddesc(ScalarTypeToCudaDataType(result_dtype), m, n, result_ld);
+  if (bias_ptr) {
+    computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr);
+    computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_BIAS);
+    computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype));
+  }
+
+#ifdef HIPBLASLT_HAS_ORDER_COL16
+  if (mat1_is_swizzled) {
+    Adesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_ORDER, HIPBLASLT_ORDER_COL16_4R16);
+  }
+  if (mat2_is_swizzled) {
+    Bdesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_ORDER, HIPBLASLT_ORDER_COL16_4R16);
+  }
+#endif
+
+  auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA();
+  size_t workspaceSize = _getWorkspaceSize();
+  auto& allocator = *::c10::hip::HIPCachingAllocatorMasqueradingAsCUDA::get();
+  auto workspace = allocator.allocate(workspaceSize);
+  auto workspace_ptr = workspace.mutable_get();
+  TORCH_CHECK(workspace_ptr != nullptr, "OOM trying to allocate workspace for cublaslt");
+
+  HipBlasLtMatmulPreference preference;
+  preference.setAttribute(HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize);
+  hipblasLtMatmulHeuristicResult_t heuristicResult = {};
+  int returnedResult = 0;
+  hipblasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle();
+  TORCH_CUDABLAS_CHECK(hipblasLtMatmulAlgoGetHeuristic(
+      ltHandle,
+      computeDesc.descriptor(),
+      Adesc.descriptor(),
+      Bdesc.descriptor(),
+      Cdesc.descriptor(),
+      Ddesc.descriptor(),
+      preference.descriptor(),
+      1,
+      &heuristicResult,
+      &returnedResult));
+  if (returnedResult == 0) {
+    // hipblaslt might be able to recover by returning all algos
+    std::vector<hipblasLtMatmulHeuristicResult_t> all_algos;
+    TORCH_CUDABLAS_CHECK(hipblaslt_ext::getAllAlgos(
+        ltHandle,
+        hipblaslt_ext::GemmType::HIPBLASLT_GEMM,
+        _cublasOpFromChar(transa),
+        _cublasOpFromChar(transb),
+        ScalarTypeToCudaDataType(mat1_dtype),
+        ScalarTypeToCudaDataType(mat2_dtype),
+        // C is nullptr and beta=0, so set to something reasonable. See above. 
+        //ScalarTypeToCudaDataType(bias_dtype),
+        ScalarTypeToCudaDataType(result_dtype),
+        ScalarTypeToCudaDataType(result_dtype),
+        HIPBLAS_COMPUTE_32F,
+        all_algos));
+    if (all_algos.size() == 0) {
+      TORCH_CUDABLAS_CHECK(HIPBLAS_STATUS_NOT_SUPPORTED);
+    }
+    // pick first valid solution
+    bool found = false;
+    for (size_t i = 0; i < all_algos.size(); i++) {
+      size_t ret_workspace_size = 0;
+      auto is_valid_status = hipblaslt_ext::matmulIsAlgoSupported(
+          ltHandle,
+          computeDesc.descriptor(),
+          &alpha_val,
+          Adesc.descriptor(),
+          Bdesc.descriptor(),
+          &beta_val,
+          Cdesc.descriptor(),
+          Ddesc.descriptor(),
+          all_algos[i].algo,
+          ret_workspace_size);
+      if (is_valid_status == HIPBLAS_STATUS_SUCCESS) {
+        if (ret_workspace_size <= workspaceSize) {
+          heuristicResult = all_algos[i];
+          found = true;
+          break;
+        }
+      }
+    }
+    TORCH_CHECK(found, "could not find valid hipblaslt solution");
+  }
+  hipblasStatus_t cublasStatus = hipblasLtMatmul(
+      ltHandle,
+      computeDesc.descriptor(),
+      &alpha_val,
+      mat1_ptr,
+      Adesc.descriptor(),
+      mat2_ptr,
+      Bdesc.descriptor(),
+      &beta_val,
+      result_ptr, // unused, since beta_val is 0, but hipblaslt can't handle nullptr
+      Cdesc.descriptor(),
+      result_ptr,
+      Ddesc.descriptor(),
+      &heuristicResult.algo,
+      workspace_ptr,
+      workspaceSize,
+      stream);
+  TORCH_CHECK(
+      cublasStatus == HIPBLAS_STATUS_SUCCESS,
+      "CUDA error: ",
+      at::cuda::blas::_cublasGetErrorEnum(cublasStatus),
+      " when calling hipblasLtMatmul with transpose_mat1 ",
+      transa,
+      " transpose_mat2 ",
+      transb,
+      " m ",
+      m,
+      " n ",
+      n,
+      " k ",
+      k,
+      " mat1_ld ",
+      mat1_ld,
+      " mat2_ld ",
+      mat2_ld,
+      " result_ld ",
+      result_ld,
+      " computeType ",
+      computeType,
+      " scaleType ",
+      scaleType);
+  return;
+}
+
+Tensor&
+_scaled_mm_out(const Tensor& mat1, const Tensor& mat2,
+          bool mat1_is_swizzled,
+          bool mat2_is_swizzled,
+          const Tensor& scale_a,
+          const Tensor& scale_b,
+          const std::optional<at::Tensor>& bias,
+          const std::optional<at::Tensor>& scale_result,
+          std::optional<c10::ScalarType> out_dtype,
+          Tensor& out) {
+  // Check sizes
+  TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
+  TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
+  TORCH_CHECK(
+      mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (",
+      mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
+
+  // Check what type of scaling we are doing based on inputs
+  ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat2.size(1));
+  TORCH_INTERNAL_ASSERT(scaling_choice != ScalingType::Error, "Scaling type not supported");
+
+  TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat),
+       "scale_result must be a float scalar");
+  TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
+       " but got ", bias->numel());
+  TORCH_CHECK(
+      mat1.sizes()[1] % 16 == 0,
+      "Expected trailing dimension of mat1 to be divisible by 16 ",
+      "but got mat1 shape: (",
+      mat1.sizes()[0],
+      "x",
+      mat1.sizes()[1],
+      ").");
+  TORCH_CHECK(mat2.sizes()[0] % 16 == 0 && mat2.sizes()[1] % 16 == 0, "mat2 shape (", mat2.sizes()[0], "x",
+       mat2.sizes()[1], ") must be divisible by 16");
+  // Check types
+  TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
+  TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
+  TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
+  if (bias) {
+    
TORCH_CHECK(out.scalar_type() != kFloat, "Bias is not supported when out_dtype is set to Float32"); + TORCH_CHECK(bias->scalar_type() == ScalarType::BFloat16 || bias->scalar_type() == ScalarType::Half, + "Bias must be either Half or BFloat16, but got ", bias->scalar_type()); + TORCH_CHECK((out.scalar_type() != kFloat && out.scalar_type() != ScalarType::BFloat16) || + bias->scalar_type() == ScalarType::BFloat16, + "Bias must be BFloat16 to compute ", out.scalar_type(), " output, but got ", bias->scalar_type()); + TORCH_CHECK(out.scalar_type() != ScalarType::Half || bias->scalar_type() == ScalarType::Half, + "Bias must be Float16 to compute ", out.scalar_type(), " output, but got ", bias->scalar_type()); + } + { + auto bias_ = bias.value_or(Tensor()); + auto scale_result_ = scale_result.value_or(Tensor()); + + // NOLINTNEXTLINE(*c-array*) + TensorArg targs[]{{out, "out", 0}, {mat1, "mat1", 1}, {mat2, "mat2", 2}, + {bias_, "bias", 3}, {scale_a, "scale_a", 4}, {scale_b, "scale_b", 5}, + {scale_result_, "scale_result", 6}}; + checkAllSameGPU(__func__, targs); + } + // Validation checks have passed lets resize the output to actual size + IntArrayRef mat1_sizes = mat1.sizes(); + IntArrayRef mat2_sizes = mat2.sizes(); + at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]}); + + // If any of M, K, N is 0 - return early (the tensorwise/rowwise float8 gemm kernels + // do not support this case). + if (mat1_sizes[0] == 0 || mat1_sizes[1] == 0 || mat2_sizes[1] == 0) { + // `out` was created with `at::empty`. In the case where we are multiplying + // MxK by KxN and K is the zero dim, we need to initialize here to properly + // return a tensor of zeros. + if (mat1_sizes[1] == 0) { + out.zero_(); + } + + return out; + } + + if (scaling_choice == ScalingType::RowWise) { + // For ROCm, match behavior of f8f8bf16_rowwise type checking, for unit test purposes. + Tensor b = mat2; + if (_scaled_mm_is_fnuz()) { + TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fnuz); + } + else { + TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fn); + } + // Until more than bf16 is supported. + TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16, + "hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type()); + } + + cublasCommonArgs args(mat1, mat2, mat1_is_swizzled, mat2_is_swizzled, out, scale_a, scale_b, scale_result); + const auto out_dtype_ = args.result->scalar_type(); + TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt"); + + { + _scaled_gemm( + args.transa, + args.transb, + args.m, + args.n, + args.k, + args.mata->data_ptr(), + args.scale_mata_ptr, + args.lda, + args.mata->scalar_type(), + args.scale_mata_dtype.value(), + args.mata_is_swizzled, + args.matb->data_ptr(), + args.scale_matb_ptr, + args.ldb, + args.matb->scalar_type(), + args.scale_matb_dtype.value(), + args.matb_is_swizzled, + bias ? bias->data_ptr(): nullptr, + bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? 
at::ScalarType::Half : out_dtype_,
+        args.result->data_ptr(),
+        args.scale_result_ptr,
+        args.result_ld,
+        out_dtype_,
+        scaling_choice == ScalingType::RowWise);
+  }
+
+  return out;
+}
+
+Tensor
+swizzle_scaled_mm(const Tensor& mat_a, const Tensor& mat_b,
+          bool mat1_is_swizzled,
+          bool mat2_is_swizzled,
+          const Tensor& scale_a,
+          const Tensor& scale_b,
+          const std::optional<at::Tensor>& bias,
+          const std::optional<at::Tensor>& scale_result,
+          std::optional<c10::ScalarType> out_dtype) {
+  const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type());
+  Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_));
+  return _scaled_mm_out(mat_a, mat_b, mat1_is_swizzled, mat2_is_swizzled, scale_a, scale_b, bias, scale_result, out_dtype, out);
+}
+
 TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
   m.impl("torchao::swizzle_mm", &swizzle_mm);
+  m.impl("torchao::swizzle_scaled_mm", &swizzle_scaled_mm);
 }
diff --git a/torchao/ops.py b/torchao/ops.py
index cc8606e0d8..55745b41ac 100644
--- a/torchao/ops.py
+++ b/torchao/ops.py
@@ -1,7 +1,8 @@
 import functools
+from typing import Optional
 
 import torch
-from torch import Tensor
+from torch import dtype, Tensor
 
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
 
@@ -30,6 +31,9 @@
 lib.define(
     "swizzle_mm(Tensor mat1, Tensor mat2, bool mat1_is_swizzled, bool mat2_is_swizzled) -> Tensor"
 )
+lib.define(
+    "swizzle_scaled_mm(Tensor mat1, Tensor mat2, bool mat1_is_swizzled, bool mat2_is_swizzled, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None) -> Tensor"
+)
 # Note: we need to add the `torch._C.Tag.needs_fixed_stride_order` tag in order for inductor
 # to honor the layout constraints for `b` in the two ops below.
 lib.define(
@@ -634,6 +638,21 @@ def _(mat1: Tensor, mat2: Tensor, mat1_is_swizzled: bool, mat2_is_swizzled: bool
     return mat1.new_empty(mat1.shape[0], mat2.shape[1])
 
 
+def swizzle_scaled_mm(mat1: Tensor, mat2: Tensor, mat1_is_swizzled: bool, mat2_is_swizzled: bool, scale_a: Tensor, scale_b: Tensor, bias: Optional[Tensor], scale_result: Optional[Tensor], out_dtype: Optional[dtype]) -> Tensor:
+    """
+    Similar to torch._scaled_mm but Tensor inputs can be SwizzleTensor instances.
+ + """ + return torch.ops.torchao.swizzle_scaled_mm.default( + mat1, mat2, mat1_is_swizzled, mat2_is_swizzled, scale_a, scale_b, bias, scale_result, out_dtype + ) + + +@register_custom_op("torchao::swizzle_scaled_mm") +def _(mat1: Tensor, mat2: Tensor, mat1_is_swizzled: bool, mat2_is_swizzled: bool, scale_a: Tensor, scale_b: Tensor, bias: Optional[Tensor], scale_result: Optional[Tensor], out_dtype: Optional[dtype]) -> Tensor: + return mat1.new_empty(mat1.shape[0], mat2.shape[1]) + + @functools.lru_cache() def _get_dtypes(): """TODO: when e8m0 is hardened and major release lets remove uint8 support""" diff --git a/torchao/swizzle/swizzle_ops.py b/torchao/swizzle/swizzle_ops.py index 390c0c6ba3..678d47abae 100644 --- a/torchao/swizzle/swizzle_ops.py +++ b/torchao/swizzle/swizzle_ops.py @@ -67,6 +67,28 @@ def swizzle_addmm(aten_op, args, kwargs=None): return aten_op(bias, a, b, args[3:], **kwargs) +@implements([aten._scaled_mm.default]) +def swizzle_scaled_mm(aten_op, args, kwargs=None): + print("in swizzle_scaled_mm") + a = args[0] + b = args[1] + scale_a = args[2] + scale_b = args[3] + bias = None if len(args) <= 4 else args[4] + scale_result = None if len(args) <= 5 else args[5] + out_dtype = None if len(args) <= 6 else args[6] + + a_is_swizzled = False + b_is_swizzled = False + if isinstance(a, SwizzleTensor): + a = a.as_tensor() + a_is_swizzled = True + if isinstance(b, SwizzleTensor): + b = b.as_tensor() + b_is_swizzled = True + return torchao.ops.swizzle_scaled_mm(a, b, a_is_swizzled, b_is_swizzled, scale_a, scale_b, bias, scale_result, out_dtype, **kwargs) + + @implements([aten.permute.default]) def swizzle_permute(aten_op, args, kwargs=None): tensor = args[0] diff --git a/torchao/swizzle/swizzle_tensor.py b/torchao/swizzle/swizzle_tensor.py index 65bc7192b8..0839f7b99d 100644 --- a/torchao/swizzle/swizzle_tensor.py +++ b/torchao/swizzle/swizzle_tensor.py @@ -32,21 +32,24 @@ def __init__(self, original, shallow=False): return #assert original.ndim == 2 or original.ndim == 3 # (M, K) or (B, M, K) assert original.ndim == 2, "SwizzleTensor only supports ndim 2" + assert original.itemsize == 1 or original.itemsize == 2 + kdiv = 32 if original.itemsize == 2 else 64 + lastdim = 8 if original.itemsize == 2 else 16 if original.ndim == 2: M, K = original.shape B = 0 if original.ndim == 3: B, M, K = original.shape alignedM = _get_min_alignment(M, 16) - alignedK = _get_min_alignment(K, 32) + alignedK = _get_min_alignment(K, kdiv) paddedM = alignedM - M paddedK = alignedK - K x = torch.nn.functional.pad(original, (0, paddedK, 0, paddedM), "constant", 0) if original.ndim == 2: - x = x.view(alignedM//16, 16, alignedK//32, 4, 8) + x = x.view(alignedM//16, 16, alignedK//kdiv, 4, lastdim) x = x.permute(0, 2, 3, 1, 4) if original.ndim == 3: - x = x.view(B, alignedM//16, 16, alignedK//32, 4, 8) + x = x.view(B, alignedM//16, 16, alignedK//kdiv, 4, lastdim) x = x.permute(0, 1, 3, 4, 2, 5) self.x = x.contiguous() self.B = B @@ -63,16 +66,20 @@ def __repr__(self): return f"{self.__class__.__name__}(original={self.unswizzle()})" def unswizzle(self): + undone = None if self.original_ndim == 2: undone = self.x.permute(0, 3, 1, 2, 4).contiguous() undone = undone.reshape(self.alignedM, self.alignedK) undone = undone[0:self.M, 0:self.K] - return undone.reshape(self.M, self.K) + undone = undone.reshape(self.M, self.K) + if self.is_transposed: + undone = undone.T if self.original_ndim == 3: undone = self.x.permute(0, 1, 4, 2, 3, 5).contiguous() undone = undone.reshape(self.B, self.alignedM, self.alignedK) undone = 
undone[0:self.B, 0:self.M, 0:self.K] - return undone.reshape(self.B, self.M, self.K) + undone = undone.reshape(self.B, self.M, self.K) + return undone def as_tensor(self): # note the transpose because this causes col major hipblaslt op to be TN From 640f00e51e9552e52e1de49ccb9c8ebd54b2543d Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Wed, 19 Mar 2025 01:47:36 +0000 Subject: [PATCH 09/11] remove print statement --- torchao/swizzle/swizzle_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchao/swizzle/swizzle_ops.py b/torchao/swizzle/swizzle_ops.py index 678d47abae..4483529d99 100644 --- a/torchao/swizzle/swizzle_ops.py +++ b/torchao/swizzle/swizzle_ops.py @@ -69,7 +69,6 @@ def swizzle_addmm(aten_op, args, kwargs=None): @implements([aten._scaled_mm.default]) def swizzle_scaled_mm(aten_op, args, kwargs=None): - print("in swizzle_scaled_mm") a = args[0] b = args[1] scale_a = args[2] From 4343c78b06ee92098c319189cbbab20587fdc8e4 Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Wed, 19 Mar 2025 05:08:25 +0000 Subject: [PATCH 10/11] setup.py missing check for vec ext --- setup.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index 20a3c2433b..c972e339da 100644 --- a/setup.py +++ b/setup.py @@ -310,21 +310,29 @@ def get_extensions(): extra_link_args.append("/DEBUG") if use_rocm: - # naive search for hipblalst.h, if any found contain HIPBLASLT_ORDER_COL16 - found = False + # naive search for hipblalst.h, if any found contain HIPBLASLT_ORDER_COL16 and VEC_EXT + found_col16 = False + found_vec_ext = False print("ROCM_HOME", ROCM_HOME) hipblaslt_headers = list(glob.glob(os.path.join(ROCM_HOME, "include", "hipblaslt", "hipblaslt.h"))) print("hipblaslt_headers", hipblaslt_headers) for header in hipblaslt_headers: with open(header) as f: - if "HIPBLASLT_ORDER_COL16" in f.read(): - found = True - break - if found: + text = f.read() + if "HIPBLASLT_ORDER_COL16" in text: + found_col16 = True + if "HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT" in text: + found_vec_ext = True + if found_col16: extra_compile_args["cxx"].append("-DHIPBLASLT_HAS_ORDER_COL16") print("hipblaslt found extended col order enums") else: print("hipblaslt does not have extended col order enums") + if found_vec_ext: + extra_compile_args["cxx"].append("-DHIPBLASLT_VEC_EXT") + print("hipblaslt found vec ext") + else: + print("hipblaslt does not have vec ext") curdir = os.path.dirname(os.path.curdir) extensions_dir = os.path.join(curdir, "torchao", "csrc") From 8b574242440652cb6db09f3bfa877e995e369b5c Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Tue, 1 Apr 2025 21:11:21 +0000 Subject: [PATCH 11/11] remove merge mistake --- setup.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/setup.py b/setup.py index a69ac81e64..c181267f3d 100644 --- a/setup.py +++ b/setup.py @@ -280,8 +280,6 @@ def get_extensions(): "-DNDEBUG" if not debug_mode else "-DDEBUG", "-O3" if not debug_mode else "-O0", "-std=c++17", - "-U__HIP_NO_HALF_CONVERSIONS__", - "-U__HIP_NO_HALF_OPERATORS__", ] extra_link_args = []
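
For reference, a minimal usage sketch of the new torchao::swizzle_scaled_mm op follows. It is illustrative only and not part of the patches above: it assumes a ROCm build of torchao with the swizzle extension compiled, PyTorch float8 dtypes available, and an fnuz-flavored fp8 GPU (swap float8_e4m3fnuz for float8_e4m3fn on other hardware); shapes, scale values, and variable names are arbitrary.

# Illustrative sketch only; assumes a ROCm torchao build with the swizzle
# extension and fp8 support. Not part of the patch series.
import torch
from torchao.ops import swizzle_scaled_mm

M, K, N = 64, 128, 32                 # GEMM dims kept divisible by 16
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fnuz)      # row-major activations
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fnuz).t()  # column-major K x N weights
scale_a = torch.tensor(1.0, device="cuda")  # fp32 scalars -> TensorWise scaling path
scale_b = torch.tensor(1.0, device="cuda")

# Both swizzle flags are False here, so the op behaves like a plain
# hipblaslt fp8 GEMM; routing a SwizzleTensor through torch._scaled_mm
# would instead set the corresponding flag and use the COL16 layout.
out = swizzle_scaled_mm(
    a, b,
    False, False,        # mat1_is_swizzled, mat2_is_swizzled
    scale_a, scale_b,
    None, None,          # bias, scale_result
    torch.bfloat16,      # out_dtype (the rowwise path requires bf16)
)
print(out.shape)  # torch.Size([64, 32])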