diff --git a/setup.py b/setup.py index 0b5de7e855..55ec3644d2 100644 --- a/setup.py +++ b/setup.py @@ -83,8 +83,6 @@ def use_debug_mode(): _get_cuda_arch_flags, ) -IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None) - class BuildOptions: def __init__(self): @@ -259,28 +257,37 @@ def get_extensions(): print( "PyTorch GPU support is not available. Skipping compilation of CUDA extensions" ) - if (CUDA_HOME is None and ROCM_HOME is None) and torch.cuda.is_available(): - print( - "CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions" - ) + if CUDA_HOME is None and torch.cuda.is_available() and torch.version.cuda: + print("CUDA toolkit is not available. Skipping compilation of CUDA extensions") print( "If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit" ) + if ROCM_HOME is None and torch.cuda.is_available() and torch.version.hip: + print("ROCm is not available. Skipping compilation of ROCm extensions") + print("If you'd like to compile ROCm extensions locally please install ROCm") - use_cuda = torch.cuda.is_available() and ( - CUDA_HOME is not None or ROCM_HOME is not None + use_cuda = ( + torch.cuda.is_available() and torch.version.cuda and CUDA_HOME is not None ) - extension = CUDAExtension if use_cuda else CppExtension + use_hip = torch.cuda.is_available() and torch.version.hip and ROCM_HOME is not None + extension = CUDAExtension if (use_cuda or use_hip) else CppExtension + + nvcc_args = [ + "-DNDEBUG" if not debug_mode else "-DDEBUG", + "-O3" if not debug_mode else "-O0", + "-t=0", + "-std=c++17", + ] + hip_args = [ + "-DNDEBUG" if not debug_mode else "-DDEBUG", + "-O3" if not debug_mode else "-O0", + "-std=c++17", + ] extra_link_args = [] extra_compile_args = { "cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"], - "nvcc": [ - "-DNDEBUG" if not debug_mode else "-DDEBUG", - "-O3" if not debug_mode else "-O0", - "-t=0", - "-std=c++17", - ], + "nvcc": nvcc_args if use_cuda else hip_args, } if not IS_WINDOWS: @@ -314,6 +321,38 @@ def get_extensions(): extra_compile_args["nvcc"].append("-g") extra_link_args.append("/DEBUG") + hip_sparse_marlin_supported = True + if use_hip: + # naive search for hipblaslt.h; check whether any found header defines HIPBLASLT_ORDER_COL16 and VEC_EXT + found_col16 = False + found_vec_ext = False + print("ROCM_HOME", ROCM_HOME) + hipblaslt_headers = list( + glob.glob(os.path.join(ROCM_HOME, "include", "hipblaslt", "hipblaslt.h")) + ) + print("hipblaslt_headers", hipblaslt_headers) + for header in hipblaslt_headers: + with open(header) as f: + text = f.read() + if "HIPBLASLT_ORDER_COL16" in text: + found_col16 = True + if "HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT" in text: + found_vec_ext = True + if found_col16: + extra_compile_args["cxx"].append("-DHIPBLASLT_HAS_ORDER_COL16") + print("hipblaslt found extended col order enums") + else: + print("hipblaslt does not have extended col order enums") + if found_vec_ext: + extra_compile_args["cxx"].append("-DHIPBLASLT_VEC_EXT") + print("hipblaslt found vec ext") + else: + print("hipblaslt does not have vec ext") + + # sparse_marlin depends on a feature introduced in ROCm 6.4, __builtin_amdgcn_global_load_lds + ROCM_VERSION = tuple(int(v) for v in torch.version.hip.split(".")[:2]) + hip_sparse_marlin_supported = ROCM_VERSION >= (6, 4) + # Get base directory and source paths curdir = os.path.dirname(os.path.curdir) extensions_dir = os.path.join(curdir, "torchao", "csrc") @@ -327,41 +366,50 @@ def get_extensions(): ) sources = [s for s in
sources if s not in excluded_sources] + # Collect CUDA source files extensions_cuda_dir = os.path.join(extensions_dir, "cuda") cuda_sources = list( glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True) ) + # Collect HIP source files extensions_hip_dir = os.path.join( extensions_dir, "cuda", "tensor_core_tiled_layout" ) hip_sources = list( glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True) ) - extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin") + if hip_sparse_marlin_supported: + extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin") + hip_sources += list( + glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True) + ) + extensions_hip_dir = os.path.join(extensions_dir, "rocm") hip_sources += list( - glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True) + glob.glob(os.path.join(extensions_hip_dir, "**/*.hip"), recursive=True) + ) + hip_sources += list( + glob.glob(os.path.join(extensions_hip_dir, "**/*.cpp"), recursive=True) ) - # Collect CUDA source files if needed - if not IS_ROCM and use_cuda: + # Add CUDA source files if needed + if use_cuda: sources += cuda_sources - # TOOD: Remove this and use what CUDA has once we fix all the builds. - if IS_ROCM and use_cuda: + # TODO: Remove this and use what CUDA has once we fix all the builds. + # Add HIP source files if needed + if use_hip: # Add ROCm GPU architecture check gpu_arch = torch.cuda.get_device_properties(0).name if gpu_arch != "gfx942": print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}") - print( - "Currently only gfx942 is supported. Skipping compilation of ROCm extensions" - ) - else: - sources += hip_sources + print("Currently only gfx942 is supported. Compiling only for gfx942.") + extra_compile_args["nvcc"].append("--offload-arch=gfx942") + sources += hip_sources use_cutlass = False cutlass_90a_sources = None - if use_cuda and not IS_ROCM and not IS_WINDOWS: + if use_cuda and not IS_WINDOWS: use_cutlass = True cutlass_dir = os.path.join(third_party_path, "cutlass") cutlass_include_dir = os.path.join(cutlass_dir, "include") diff --git a/test/test_ops.py b/test/test_ops.py index 646d1c76af..5025b8a19b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -31,8 +31,8 @@ compute_max_diff, ) -if torch.version.hip is not None: - pytest.skip("Skipping the test in ROCm", allow_module_level=True) +IS_CUDA = torch.cuda.is_available() and torch.version.cuda +IS_ROCM = torch.cuda.is_available() and torch.version.hip try: import torchao.ops @@ -58,7 +58,7 @@ def _create_floatx_inputs( fp16_act = torch.rand(BS, IC).to(dtype) + 0.5 return floatx_weight.to(device), scale.to(device), fp16_act.to(device) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") @parametrize("ebits,mbits", [(3, 2), (2, 2)]) @parametrize("dtype", [torch.half, torch.bfloat16]) def test_quant_llm_linear(self, ebits, mbits, dtype): @@ -88,7 +88,7 @@ def test_quant_llm_linear(self, ebits, mbits, dtype): test_utils=test_utils, ) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)]) @parametrize("ebits,mbits", [(3, 2), (2, 2)]) @parametrize("dtype", [torch.half, torch.bfloat16]) @@ -274,7 +274,7 @@ def make_test_id(param): return f"tiles_{param}" -@pytest.mark.skipif(not torch.cuda.is_available(), 
reason="CUDA not available") +@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id) def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles): @@ -292,7 +292,7 @@ def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles): # TODO: Fix "test_aot_dispatch_dynamic" test failure -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=make_test_id) def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles): @@ -338,7 +338,7 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16): return dq.reshape(n, k) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize( "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str @@ -406,7 +406,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant( # This test differs from one above in that it uses `unpack_tensor_core_tiled_layout` to unpack then dequantize -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize( "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str @@ -472,7 +472,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant( assert diff_op_ao < 1e-1 -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") # @pytest.mark.skipif(TORCH_VERSION_AT_LEAST_2_5, reason="weight packing is updated in 2.5+") @pytest.mark.parametrize( "shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str @@ -583,7 +583,7 @@ def reshape_w(w): ) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") @pytest.mark.parametrize( "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, @@ -673,7 +673,7 @@ def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_facto ) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not IS_CUDA, reason="CUDA not available") @pytest.mark.parametrize( "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors", MARLIN_TEST_PARAMS, @@ -752,5 +752,27 @@ def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_fact ) +@pytest.mark.skipif(not IS_ROCM, reason="ROCm not available") +def test_swizzle_mm(): + test_utils = [ + "test_schema", + "test_autograd_registration", + "test_faketensor", + ] + + # TODO: Figure out why test fails unless torch >= 2.5 + if TORCH_VERSION_AT_LEAST_2_5: + test_utils.append("test_aot_dispatch_dynamic") + + mat1 = torch.randint(0, 16, dtype=torch.float, size=(16, 32), device="cuda") + mat2 = torch.randint(0, 16, dtype=torch.float, size=(32, 16), 
device="cuda") + + opcheck( + torch.ops.torchao.swizzle_mm, + (mat1, mat2, False, False), + test_utils=test_utils, + ) + + if __name__ == "__main__": pytest.main(sys.argv) diff --git a/torchao/__init__.py b/torchao/__init__.py index 752aa94a4f..fb96282d77 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -43,13 +43,14 @@ quantize_, ) -from . import dtypes, optim, testing +from . import dtypes, optim, swizzle, testing __all__ = [ "dtypes", "autoquant", "optim", "quantize_", + "swizzle", "testing", "ops", ] diff --git a/torchao/csrc/rocm/swizzle/swizzle.cpp b/torchao/csrc/rocm/swizzle/swizzle.cpp new file mode 100644 index 0000000000..bfaf6bf466 --- /dev/null +++ b/torchao/csrc/rocm/swizzle/swizzle.cpp @@ -0,0 +1,911 @@ +// setup.py glob includes all *.cpp files +// but only build this for ROCm +#ifdef USE_ROCM +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using at::Scalar; +using at::Tensor; +using at::TensorArg; +using c10::kFloat; +using c10::ScalarType; +using c10::IntArrayRef; +using at::cuda::ScalarTypeToCudaDataType; + +// +// copied from aten/src/ATen/cuda/CUDABlas.cpp +// +namespace { + +static hipblasOperation_t _cublasOpFromChar(char op) { + // NOLINTNEXTLINE(bugprone-switch-missing-default-case) + switch (op) { + case 'n': + case 'N': + return HIPBLAS_OP_N; + case 't': + case 'T': + return HIPBLAS_OP_T; + case 'c': + case 'C': + return HIPBLAS_OP_C; + } + TORCH_CHECK(false, + "_cublasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`"); +} + +static void _cublasAdjustLdLevel3( + char transa, + char transb, + int64_t m, + int64_t n, + int64_t k, + int64_t* lda, + int64_t* ldb, + int64_t* ldc) { + bool transa_ = ((transa != 'n') && (transa != 'N')); + bool transb_ = ((transb != 'n') && (transb != 'N')); + + // Note: leading dimensions generally are checked that they are > 0 + // and at least as big the result requires (even if the value won't + // be used). + if (n <= 1) + *ldc = std::max(m, 1); + + if (transa_) { + if (m <= 1) + *lda = std::max(k, 1); + } else { + if (k <= 1) + *lda = std::max(m, 1); + } + + if (transb_) { + if (k <= 1) + *ldb = std::max(n, 1); + } else { + if (n <= 1) + *ldb = std::max(k, 1); + } +} + +// Following the pattern of CuSparseDescriptor +// Defined here for now because this is the only place cublas_lt interface is +// used but can be moved to a header once cublas_lt interface is used in +// multiple places. 
+template <typename T, hipblasStatus_t (*destructor)(T*)> +struct HipBlasLtDeleter { + void operator()(T* x) { + if (x != nullptr) { + TORCH_CUDABLAS_CHECK(destructor(x)); + } + } +}; + +template <typename T, hipblasStatus_t (*destructor)(T*)> +class HipBlasLtDescriptor { + public: + T* descriptor() const { + return descriptor_.get(); + } + T* descriptor() { + return descriptor_.get(); + } + + protected: + std::unique_ptr<T, HipBlasLtDeleter<T, destructor>> descriptor_; +}; + +class HipBlasLtMatmulDescriptor : public HipBlasLtDescriptor< + hipblasLtMatmulDescOpaque_t, + &hipblasLtMatmulDescDestroy> { + public: + HipBlasLtMatmulDescriptor( + hipblasComputeType_t compute_type, + hipDataType scale_type) { + hipblasLtMatmulDesc_t raw_descriptor = nullptr; + TORCH_CUDABLAS_CHECK( + hipblasLtMatmulDescCreate(&raw_descriptor, compute_type, scale_type)); + descriptor_.reset(raw_descriptor); + } + template <typename T> + inline void setAttribute(hipblasLtMatmulDescAttributes_t attr, const T value) { + // NOLINTNEXTLINE(bugprone-sizeof-expression) + TORCH_CUDABLAS_CHECK(::hipblasLtMatmulDescSetAttribute(descriptor(), attr, &value, sizeof(value))); + } +}; + +class HipBlasLtMatrixLayout : public HipBlasLtDescriptor< + hipblasLtMatrixLayoutOpaque_t, + &hipblasLtMatrixLayoutDestroy> { + public: + HipBlasLtMatrixLayout( + hipDataType type, + uint64_t rows, + uint64_t cols, + int64_t ld, + bool t = false) { + hipblasLtMatrixLayout_t raw_descriptor = nullptr; + TORCH_CUDABLAS_CHECK( + hipblasLtMatrixLayoutCreate(&raw_descriptor, type, t ? cols : rows, t ? rows : cols, ld)); + descriptor_.reset(raw_descriptor); + } + template <typename T> + inline void setAttribute(hipblasLtMatrixLayoutAttribute_t attr, const T value) { + TORCH_CUDABLAS_CHECK(::hipblasLtMatrixLayoutSetAttribute(descriptor(), attr, &value, sizeof(T))); + } +}; + +class HipBlasLtMatmulPreference : public HipBlasLtDescriptor< + hipblasLtMatmulPreferenceOpaque_t, + &hipblasLtMatmulPreferenceDestroy> { + public: + HipBlasLtMatmulPreference() { + hipblasLtMatmulPreference_t raw_descriptor = nullptr; + TORCH_CUDABLAS_CHECK(hipblasLtMatmulPreferenceCreate(&raw_descriptor)); + descriptor_.reset(raw_descriptor); + } + template <typename T> + inline void setAttribute(hipblasLtMatmulPreferenceAttributes_t attr, const T value) { + TORCH_CUDABLAS_CHECK(::hipblasLtMatmulPreferenceSetAttribute(descriptor(), attr, &value, sizeof(T))); + } +}; + +static size_t _parseChosenWorkspaceSize() { + auto val = c10::utils::get_env("CUBLASLT_WORKSPACE_SIZE"); + if (!val.has_value()) { + // accept either env var + val = c10::utils::get_env("HIPBLASLT_WORKSPACE_SIZE"); + } + size_t workspace_size = 76*1024; /* Use 76 MB for hipBLASLt */ + + if (val.has_value()) { + try { + workspace_size = std::stoi(val.value()); + } catch(std::invalid_argument const& e) { + TORCH_WARN("invalid CUBLASLT_WORKSPACE_SIZE,", + " using default workspace size of ", workspace_size, " KiB."); + } catch(std::out_of_range const& e) { + TORCH_WARN("CUBLASLT_WORKSPACE_SIZE out of range,", + " using default workspace size of ", workspace_size, " KiB."); + } + } + return workspace_size * 1024; +} + +static size_t _getWorkspaceSize() { + static size_t workspace_size = _parseChosenWorkspaceSize(); + return workspace_size; +} + +static bool _scaled_mm_is_fnuz() { + auto dprops = at::cuda::getCurrentDeviceProperties(); + std::string device_arch = dprops->gcnArchName; + static const std::vector<std::string> archs = {"gfx940", "gfx941", "gfx942"}; + for (std::string arch : archs) { + size_t substring = device_arch.find(arch); + if (substring != std::string::npos) { + return true; + } + } + return false; +} + +} // namespace + +// +// copied from aten/src/ATen/native/cuda/Blas.cpp +//
+namespace { + +// TODO: https://github.com/pytorch/pytorch/pull/59380#pullrequestreview-725310492 +c10::MaybeOwned inline resolve_conj_if_indicated(const Tensor& tensor, bool resolve_conj) { + if (resolve_conj && tensor.is_conj()) { + return c10::MaybeOwned::owned(tensor.resolve_conj()); + } else { + return c10::MaybeOwned::borrowed(tensor); + } +} + +c10::MaybeOwned inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor, bool transpose_result) { + if (tensor.is_non_overlapping_and_dense()) { // common case + transpose_tensor = tensor.is_contiguous(); + return resolve_conj_if_indicated(tensor, transpose_result ? transpose_tensor : !transpose_tensor); + } + IntArrayRef tensor_strides = tensor.strides(); + IntArrayRef tensor_sizes = tensor.sizes(); + if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max(1, tensor_sizes[0]))) { + transpose_tensor = false; + return resolve_conj_if_indicated(tensor, !transpose_result); + } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max(1, tensor_sizes[1]))) { + transpose_tensor = true; + return resolve_conj_if_indicated(tensor, transpose_result); + } else { + transpose_tensor = true; + return c10::MaybeOwned::owned(tensor.clone(at::MemoryFormat::Contiguous)); + } +} + +c10::MaybeOwned inline prepare_matrix_for_cublas(const Tensor& tensor, bool& transpose_tensor) { + if (tensor.is_non_overlapping_and_dense()) { // common case + transpose_tensor = tensor.is_contiguous(); + return resolve_conj_if_indicated(tensor, true); + } + + IntArrayRef tensor_strides = tensor.strides(); + IntArrayRef tensor_sizes = tensor.sizes(); + if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max(1, tensor_sizes[0]))) { + transpose_tensor = false; + return resolve_conj_if_indicated(tensor, true); + } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max(1, tensor_sizes[1]))) { + transpose_tensor = true; + return resolve_conj_if_indicated(tensor, true); + } else { + transpose_tensor = true; + return c10::MaybeOwned::owned(tensor.clone(at::MemoryFormat::Contiguous)); + } +} + +struct cublasCommonArgs { + cublasCommonArgs( + const Tensor& mat1, + const Tensor& mat2, + bool swizzle1, + bool swizzle2, + Tensor& c, + const std::optional& scale_a = std::nullopt, + const std::optional& scale_b = std::nullopt, + const std::optional& scale_result = std::nullopt) { + bool transpose_result = false, transpose_a = false, transpose_b = false; + result = prepare_matrix_for_cublas(c, transpose_result); + mata = prepare_matrix_for_cublas(transpose_result ? mat2 : mat1, transpose_a, transpose_result); + matb = prepare_matrix_for_cublas(transpose_result ? mat1 : mat2, transpose_b, transpose_result); + + // Handle scale tensors if provided + if (scale_a && scale_b) { + // By default since we return in row-major we run the gemm + // as B.T @ A.T, check transpose_result to determine if we flip the scales + scale_mata_ptr = transpose_result ? scale_b->data_ptr() : scale_a->data_ptr(); + scale_mata_dtype = transpose_result ? scale_b->scalar_type() : scale_a->scalar_type(); + scale_matb_ptr = transpose_result ? scale_a->data_ptr() : scale_b->data_ptr(); + scale_matb_dtype = transpose_result ? 
scale_a->scalar_type() : scale_b->scalar_type(); + } + + if (scale_result) { + scale_result_ptr = scale_result->data_ptr(); + scale_result_dtype = scale_result->scalar_type(); + } + + // Update transpose flags + if (transpose_result) { + transpose_a = !transpose_a; + transpose_b = !transpose_b; + } + + auto sizes_a = mata->sizes(); + auto sizes_b = matb->sizes(); + + m = sizes_a[transpose_result ? 1 : 0]; + k = sizes_a[transpose_result ? 0 : 1]; + n = sizes_b[transpose_result ? 0 : 1]; + lda = mata->stride((transpose_a == transpose_result) ? 1 : 0); + ldb = matb->stride((transpose_b == transpose_result) ? 1 : 0); + result_ld = result->stride(transpose_result ? 0 : 1); + transa = transpose_a ? mata->is_conj() ? 'c' : 't' : 'n'; + transb = transpose_b ? matb->is_conj() ? 'c' : 't' : 'n'; + + mata_is_swizzled = transpose_result ? swizzle2 : swizzle1; + matb_is_swizzled = transpose_result ? swizzle1 : swizzle2; + } + + // Matrix members + char transa, transb; + int64_t m, n, k; + int64_t lda, ldb, result_ld; + c10::MaybeOwned<Tensor> mata, matb, result; + + // Scale members + void* scale_mata_ptr = nullptr; + void* scale_matb_ptr = nullptr; + void* scale_result_ptr = nullptr; + std::optional<c10::ScalarType> scale_mata_dtype; + std::optional<c10::ScalarType> scale_matb_dtype; + std::optional<c10::ScalarType> scale_result_dtype; + + // swizzle members + bool mata_is_swizzled; + bool matb_is_swizzled; +}; + +enum class ScalingType { + TensorWise, + RowWise, + Error +}; + +ScalingType get_scaling_type( + const at::Tensor& scale_a, + const at::Tensor& scale_b, + int64_t dim_m, + int64_t dim_n) { + // Both Per-Tensor and Row-wise scaling expect fp32 tensors + TORCH_CHECK( + scale_a.scalar_type() == kFloat && scale_b.scalar_type() == kFloat, + "Both scale_a and scale_b must be float (fp32) tensors."); + + // Check the singular scale case for per-tensor scaling + if (scale_a.numel() == 1 && scale_b.numel() == 1) { + return ScalingType::TensorWise; + } + + // For non-TensorWise scaling, enforce 2D input tensors + TORCH_CHECK( + scale_a.dim() == 2 && scale_b.dim() == 2, + "For non-TensorWise scaling, scale tensors must be 2-dimensional, " + "but got scale_a.dim()=", + scale_a.dim(), + " and scale_b.dim()=", + scale_b.dim()); + + // Check for RowWise scaling + if (scale_a.size(0) == dim_m && scale_a.size(1) == 1 && + scale_b.size(0) == 1 && scale_b.size(1) == dim_n) { +#if defined(HIPBLASLT_VEC_EXT) + TORCH_CHECK( + scale_a.is_contiguous() && scale_b.is_contiguous(), + "Both scale_a and scale_b must be contiguous for RowWise scaling."); + return ScalingType::RowWise; +#else + TORCH_CHECK(false, "Per-row scaling is not supported for this platform!"); + return ScalingType::Error; +#endif + } + + // If we reach here, the input doesn't match any valid scaling type + TORCH_CHECK( + false, + "Invalid scaling configuration. For TensorWise scaling, both scales should be scalar. " + "For RowWise scaling, scale_a should be (", + dim_m, + ", 1) and scale_b should be (1, ", + dim_n, + ").
" + "Got scale_a.size()=(", + scale_a.size(0), + ", ", + scale_a.size(1), + ") and ", + "scale_b.size()=(", + scale_b.size(0), + ", ", + scale_b.size(1), + ")"); + + return ScalingType::Error; +} + +} // namespace + +template +inline void bgemm_hipblaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype), bool mat1_is_swizzled, bool mat2_is_swizzled) { + hipDataType abcType = HIP_R_32F; + hipblasComputeType_t computeType = HIPBLAS_COMPUTE_32F; + hipDataType scaleType = HIP_R_32F; + if constexpr (std::is_same_v) { + abcType = HIP_R_64F; + computeType = HIPBLAS_COMPUTE_64F; + scaleType = HIP_R_64F; + } else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v>) { + abcType = HIP_C_64F; + computeType = HIPBLAS_COMPUTE_64F; + scaleType = HIP_C_64F; + } else if constexpr (std::is_same_v>) { + abcType = HIP_C_32F; + scaleType = HIP_C_32F; + } else if constexpr (std::is_same_v) { + abcType = HIP_R_16F; + } else if constexpr (std::is_same_v) { + abcType = HIP_R_16BF; + } else { + static_assert(false && sizeof(Dtype), "at::cuda::blas::bgemm_internal_cublaslt: not implemented"); + } + + hipblasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle(); + hipblasOperation_t opa = _cublasOpFromChar(transa); + hipblasOperation_t opb = _cublasOpFromChar(transb); + _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); + + HipBlasLtMatmulDescriptor computeDesc(computeType, scaleType); + computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSA, opa); + computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, opb); + HipBlasLtMatrixLayout Adesc(abcType, m, k, lda, opa == HIPBLAS_OP_T); + HipBlasLtMatrixLayout Bdesc(abcType, k, n, ldb, opb == HIPBLAS_OP_T); + HipBlasLtMatrixLayout Cdesc(abcType, m, n, ldc); +#ifdef HIPBLASLT_HAS_ORDER_COL16 + if (mat1_is_swizzled) { + Adesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_ORDER, HIPBLASLT_ORDER_COL16_4R8); + } + if (mat2_is_swizzled) { + Bdesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_ORDER, HIPBLASLT_ORDER_COL16_4R8); + } +#endif + + if (num_batches > 1) { + int num_batches_as_int = static_cast(num_batches); + Adesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, num_batches_as_int); + Bdesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, num_batches_as_int); + Cdesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_BATCH_COUNT, num_batches_as_int); + Adesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stridea); + Bdesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, strideb); + Cdesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, stridec); + } + + hipblasLtEpilogue_t epilogue = HIPBLASLT_EPILOGUE_DEFAULT; + computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, epilogue); + + HipBlasLtMatmulPreference preference; + // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind + // setting this to 1M. 
+ size_t workspaceSize = _getWorkspaceSize(); + preference.setAttribute(HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize); + + auto workspace = at::empty(static_cast(workspaceSize), at::TensorOptions().dtype(at::kByte).device(at::kCUDA)); + + hipblasLtMatmulHeuristicResult_t heuristicResult = {}; + int returnedResult = 0; + TORCH_CUDABLAS_CHECK(hipblasLtMatmulAlgoGetHeuristic( + ltHandle, + computeDesc.descriptor(), + Adesc.descriptor(), + Bdesc.descriptor(), + Cdesc.descriptor(), + Cdesc.descriptor(), + preference.descriptor(), + 1, + &heuristicResult, + &returnedResult)); + if (returnedResult == 0) { + TORCH_CUDABLAS_CHECK(HIPBLAS_STATUS_NOT_SUPPORTED); + } + + hipblasStatus_t cublasStatus = hipblasLtMatmul( + ltHandle, + computeDesc.descriptor(), + &alpha, + a, + Adesc.descriptor(), + b, + Bdesc.descriptor(), + &beta, + c, + Cdesc.descriptor(), + c, + Cdesc.descriptor(), + &heuristicResult.algo, + workspace.mutable_data_ptr(), + workspaceSize, + at::hip::getCurrentHIPStreamMasqueradingAsCUDA()); + TORCH_CHECK( + cublasStatus == HIPBLAS_STATUS_SUCCESS, + "CUDA error: ", + at::cuda::blas::_cublasGetErrorEnum(cublasStatus), + " when calling hipblasLtMatmul with transpose_mat1 ", + (opa == HIPBLAS_OP_T), + " transpose_mat2 ", + (opb == HIPBLAS_OP_T), + " m ", + m, + " n ", + n, + " k ", + k, + " lda ", + lda, + " ldb ", + ldb, + " ldc ", + ldc, + " abcType ", + abcType, + " computeType ", + computeType, + " scaleType ", + scaleType); +} + + +template +inline void gemm_hipblaslt(CUDABLAS_GEMM_ARGTYPES(Dtype), bool mat1_is_swizzled, bool mat2_is_swizzled) { + // forward to bgemm implementation but set strides and batches to 0 + bgemm_hipblaslt(transa, transb, m, n, k, alpha, a, lda, 0, b, ldb, 0, beta, c, ldc, 0, 0, mat1_is_swizzled, mat2_is_swizzled); +} + + +Tensor swizzle_mm(const Tensor& mat1, const Tensor& mat2, bool mat1_is_swizzled, bool mat2_is_swizzled) { + TORCH_CHECK( + mat1.dtype() == mat2.dtype(), + "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype() + ); + + // NOLINTNEXTLINE(*c-array*) + TensorArg targs[]{{mat1, "mat1", 0}, {mat2, "mat2", 1}}; + checkAllSameGPU(__func__, targs); + + Tensor meta_mat1 = mat1.to("meta"); + Tensor meta_mat2 = mat2.to("meta"); + Tensor meta_result = at::mm(meta_mat1, meta_mat2); + Tensor result = at::empty_like(meta_result, mat1.device()); + at::ScalarType scalar_type = result.scalar_type(); + + cublasCommonArgs args(mat1, mat2, mat1_is_swizzled, mat2_is_swizzled, result); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, + at::ScalarType::BFloat16, + scalar_type, + "addmm_cuda", + [&] { + using opmath_t = at::opmath_type; + opmath_t alpha_val = opmath_t(1.0); + opmath_t beta_val = opmath_t(0.0); + const scalar_t* mat1_ptr = args.mata->const_data_ptr(); + const scalar_t* mat2_ptr = args.matb->const_data_ptr(); + scalar_t* result_ptr = args.result->mutable_data_ptr(); + gemm_hipblaslt( + args.transa, + args.transb, + args.m, + args.n, + args.k, + alpha_val, + mat1_ptr, + args.lda, + mat2_ptr, + args.ldb, + beta_val, + result_ptr, + args.result_ld, + args.mata_is_swizzled, + args.matb_is_swizzled); + }); + + return result; +} + +void _scaled_gemm( + char transa, + char transb, + int64_t m, + int64_t n, + int64_t k, + const void* mat1_ptr, + const void* mat1_scale_ptr, + int64_t mat1_ld, + ScalarType mat1_dtype, + ScalarType mat1_scale_dtype, + bool mat1_is_swizzled, + const void* mat2_ptr, + const void* mat2_scale_ptr, + int64_t mat2_ld, + ScalarType mat2_dtype, + ScalarType 
mat2_scale_dtype, + bool mat2_is_swizzled, + const void* bias_ptr, + ScalarType bias_dtype, + void* result_ptr, + const void *result_scale_ptr, + int64_t result_ld, + ScalarType result_dtype, + bool use_rowwise) { + const auto computeType = HIPBLAS_COMPUTE_32F; + const auto scaleType = HIP_R_32F; + const float alpha_val = 1.0; + const float beta_val = 0.0; + HipBlasLtMatmulDescriptor computeDesc(computeType, scaleType); + computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSA, _cublasOpFromChar(transa)); + computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_TRANSB, _cublasOpFromChar(transb)); + hipblasLtMatmulDescAttributes_t matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER; + hipblasLtMatmulDescAttributes_t matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER; +#if defined(HIPBLASLT_VEC_EXT) + if (use_rowwise) { + matmulDescA = HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER_VEC_EXT; + matmulDescB = HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER_VEC_EXT; + } +#else + // rowwise isn't supported using cublaslt or older hipblaslt + TORCH_INTERNAL_ASSERT(use_rowwise == false, "rowwise scaled_gemm not supported with blaslt"); +#endif + computeDesc.setAttribute(matmulDescA, mat1_scale_ptr); + computeDesc.setAttribute(matmulDescB, mat2_scale_ptr); + if (result_scale_ptr != nullptr) { + computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr); + } + HipBlasLtMatrixLayout Adesc(ScalarTypeToCudaDataType(mat1_dtype), m, k, mat1_ld, transa == 't'); + HipBlasLtMatrixLayout Bdesc(ScalarTypeToCudaDataType(mat2_dtype), k, n, mat2_ld, transb == 't'); + // Cdesc is unused, beta is 0. But hipblaslt needs this set to something reasonable. + HipBlasLtMatrixLayout Cdesc(ScalarTypeToCudaDataType(result_dtype), m, n, result_ld); + HipBlasLtMatrixLayout Ddesc(ScalarTypeToCudaDataType(result_dtype), m, n, result_ld); + if (bias_ptr) { + computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_POINTER, bias_ptr); + computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_EPILOGUE, HIPBLASLT_EPILOGUE_BIAS); + computeDesc.setAttribute(HIPBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, ScalarTypeToCudaDataType(bias_dtype)); + } + +#ifdef HIPBLASLT_HAS_ORDER_COL16 + if (mat1_is_swizzled) { + Adesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_ORDER, HIPBLASLT_ORDER_COL16_4R16); + } + if (mat2_is_swizzled) { + Bdesc.setAttribute(HIPBLASLT_MATRIX_LAYOUT_ORDER, HIPBLASLT_ORDER_COL16_4R16); + } +#endif + + auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + size_t workspaceSize = _getWorkspaceSize(); + auto& allocator = *::c10::hip::HIPCachingAllocatorMasqueradingAsCUDA::get(); + auto workspace = allocator.allocate(workspaceSize); + auto workspace_ptr = workspace.mutable_get(); + TORCH_CHECK(workspace_ptr != nullptr, "OOM trying to allocate workspace for cublaslt"); + + HipBlasLtMatmulPreference preference; + preference.setAttribute(HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, workspaceSize); + hipblasLtMatmulHeuristicResult_t heuristicResult = {}; + int returnedResult = 0; + hipblasLtHandle_t ltHandle = at::cuda::getCurrentCUDABlasLtHandle(); + TORCH_CUDABLAS_CHECK(hipblasLtMatmulAlgoGetHeuristic( + ltHandle, + computeDesc.descriptor(), + Adesc.descriptor(), + Bdesc.descriptor(), + Cdesc.descriptor(), + Ddesc.descriptor(), + preference.descriptor(), + 1, + &heuristicResult, + &returnedResult)); + if (returnedResult == 0) { + // hipblaslt might be able to recover by returning all algos + std::vector all_algos; + TORCH_CUDABLAS_CHECK(hipblaslt_ext::getAllAlgos( + ltHandle, + hipblaslt_ext::GemmType::HIPBLASLT_GEMM, + _cublasOpFromChar(transa), + 
_cublasOpFromChar(transb), + ScalarTypeToCudaDataType(mat1_dtype), + ScalarTypeToCudaDataType(mat2_dtype), + // C is nullptr and beta=0, so set to something reasonable. See above. + //ScalarTypeToCudaDataType(bias_dtype), + ScalarTypeToCudaDataType(result_dtype), + ScalarTypeToCudaDataType(result_dtype), + HIPBLAS_COMPUTE_32F, + all_algos)); + if (all_algos.size() == 0) { + TORCH_CUDABLAS_CHECK(HIPBLAS_STATUS_NOT_SUPPORTED); + } + // pick first valid solution + bool found = false; + for (size_t i = 0; i < all_algos.size(); i++) { + size_t ret_workspace_size = 0; + auto is_valid_status = hipblaslt_ext::matmulIsAlgoSupported( + ltHandle, + computeDesc.descriptor(), + &alpha_val, + Adesc.descriptor(), + Bdesc.descriptor(), + &beta_val, + Cdesc.descriptor(), + Ddesc.descriptor(), + all_algos[i].algo, + ret_workspace_size); + if (is_valid_status == HIPBLAS_STATUS_SUCCESS) { + if (ret_workspace_size <= workspaceSize) { + heuristicResult = all_algos[i]; + found = true; + break; + } + } + } + TORCH_CHECK(found, "could not find valid hipblaslt solution"); + } + hipblasStatus_t cublasStatus = hipblasLtMatmul( + ltHandle, + computeDesc.descriptor(), + &alpha_val, + mat1_ptr, + Adesc.descriptor(), + mat2_ptr, + Bdesc.descriptor(), + &beta_val, + result_ptr, // unused, since beta_val is 0, but hipblaslt can't handle nullptr + Cdesc.descriptor(), + result_ptr, + Ddesc.descriptor(), + &heuristicResult.algo, + workspace_ptr, + workspaceSize, + stream); + TORCH_CHECK( + cublasStatus == HIPBLAS_STATUS_SUCCESS, + "CUDA error: ", + at::cuda::blas::_cublasGetErrorEnum(cublasStatus), + " when calling hipblasLtMatmul with transpose_mat1 ", + transa, + " transpose_mat2 ", + transb, + " m ", + m, + " n ", + n, + " k ", + k, + " mat1_ld ", + mat1_ld, + " mat2_ld ", + mat2_ld, + " result_ld ", + result_ld, + " computeType ", + computeType, + " scaleType ", + scaleType); + return; +} + +Tensor& +_scaled_mm_out(const Tensor& mat1, const Tensor& mat2, + bool mat1_is_swizzled, + bool mat2_is_swizzled, + const Tensor& scale_a, + const Tensor& scale_b, + const std::optional& bias, + const std::optional& scale_result, + std::optional out_dtype, + Tensor& out) { + // Check sizes + TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix"); + TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix"); + TORCH_CHECK( + mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (", + mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); + + // Check what type of scaling we are doing based on inputs + ScalingType scaling_choice = get_scaling_type(scale_a, scale_b, mat1.size(0), mat2.size(1)); + TORCH_INTERNAL_ASSERT(scaling_choice != ScalingType::Error, "Scaling type not supported"); + + TORCH_CHECK(!scale_result || (scale_result->numel() == 1 && scale_result->scalar_type() == kFloat), + "scale_result must be a float scalar"); + TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1], + " but got ", bias->numel()); + TORCH_CHECK( + mat1.sizes()[1] % 16 == 0, + "Expected trailing dimension of mat1 to be divisible by 16 ", + "but got mat1 shape: (", + mat1.sizes()[0], + "x", + mat1.sizes()[1], + ")."); + TORCH_CHECK(mat2.sizes()[0] % 16 == 0 && mat2.sizes()[1] % 16 == 0, "mat2 shape (", mat2.sizes()[0], "x", + mat2.sizes()[1], ") must be divisible by 16"); + // Check types + TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type"); + TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be 
Float8 matrix got ", mat1.scalar_type()); + TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type()); + if (bias) { + TORCH_CHECK(out.scalar_type() != kFloat, "Bias is not supported when out_dtype is set to Float32"); + TORCH_CHECK(bias->scalar_type() == ScalarType::BFloat16 || bias->scalar_type() == ScalarType::Half, + "Bias must be either Half or BFloat16, but got ", bias->scalar_type()); + TORCH_CHECK((out.scalar_type() != kFloat && out.scalar_type() != ScalarType::BFloat16) || + bias->scalar_type() == ScalarType::BFloat16, + "Bias must be BFloat16 to compute ", out.scalar_type(), " output, but got ", bias->scalar_type()); + TORCH_CHECK(out.scalar_type() != ScalarType::Half || bias->scalar_type() == ScalarType::Half, + "Bias must be Float16 to compute ", out.scalar_type(), " output, but got ", bias->scalar_type()); + } + { + auto bias_ = bias.value_or(Tensor()); + auto scale_result_ = scale_result.value_or(Tensor()); + + // NOLINTNEXTLINE(*c-array*) + TensorArg targs[]{{out, "out", 0}, {mat1, "mat1", 1}, {mat2, "mat2", 2}, + {bias_, "bias", 3}, {scale_a, "scale_a", 4}, {scale_b, "scale_b", 5}, + {scale_result_, "scale_result", 6}}; + checkAllSameGPU(__func__, targs); + } + // Validation checks have passed lets resize the output to actual size + IntArrayRef mat1_sizes = mat1.sizes(); + IntArrayRef mat2_sizes = mat2.sizes(); + at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]}); + + // If any of M, K, N is 0 - return early (the tensorwise/rowwise float8 gemm kernels + // do not support this case). + if (mat1_sizes[0] == 0 || mat1_sizes[1] == 0 || mat2_sizes[1] == 0) { + // `out` was created with `at::empty`. In the case where we are multiplying + // MxK by KxN and K is the zero dim, we need to initialize here to properly + // return a tensor of zeros. + if (mat1_sizes[1] == 0) { + out.zero_(); + } + + return out; + } + + if (scaling_choice == ScalingType::RowWise) { + // For ROCm, match behavior of f8f8bf16_rowwise type checking, for unit test purposes. + Tensor b = mat2; + if (_scaled_mm_is_fnuz()) { + TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fnuz); + } + else { + TORCH_CHECK(b.dtype() == at::kFloat8_e4m3fn); + } + // Until more than bf16 is supported. + TORCH_CHECK(out.scalar_type() == ScalarType::BFloat16, + "hipblaslt rowwise _scaled_mm only supports BFloat16 output but got ", out.scalar_type()); + } + + cublasCommonArgs args(mat1, mat2, mat1_is_swizzled, mat2_is_swizzled, out, scale_a, scale_b, scale_result); + const auto out_dtype_ = args.result->scalar_type(); + TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt"); + + { + _scaled_gemm( + args.transa, + args.transb, + args.m, + args.n, + args.k, + args.mata->data_ptr(), + args.scale_mata_ptr, + args.lda, + args.mata->scalar_type(), + args.scale_mata_dtype.value(), + args.mata_is_swizzled, + args.matb->data_ptr(), + args.scale_matb_ptr, + args.ldb, + args.matb->scalar_type(), + args.scale_matb_dtype.value(), + args.matb_is_swizzled, + bias ? bias->data_ptr(): nullptr, + bias ? bias->scalar_type() : isFloat8Type(out_dtype_) ? 
at::ScalarType::Half : out_dtype_, + args.result->data_ptr(), + args.scale_result_ptr, + args.result_ld, + out_dtype_, + scaling_choice == ScalingType::RowWise); + } + + return out; +} + +Tensor +swizzle_scaled_mm(const Tensor& mat_a, const Tensor& mat_b, + bool mat1_is_swizzled, + bool mat2_is_swizzled, + const Tensor& scale_a, + const Tensor& scale_b, + const std::optional& bias, + const std::optional& scale_result, + std::optional out_dtype) { + const auto out_dtype_ = out_dtype.value_or(mat_a.scalar_type()); + Tensor out = at::empty({0}, mat_a.options().dtype(out_dtype_)); + return _scaled_mm_out(mat_a, mat_b, mat1_is_swizzled, mat2_is_swizzled, scale_a, scale_b, bias, scale_result, out_dtype, out); +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::swizzle_mm", &swizzle_mm); + m.impl("torchao::swizzle_scaled_mm", &swizzle_scaled_mm); +} +#endif // USE_ROCM diff --git a/torchao/ops.py b/torchao/ops.py index a94aee589f..2f8b4ae645 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -7,7 +7,7 @@ from typing import Optional import torch -from torch import Tensor +from torch import Tensor, dtype from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 @@ -39,6 +39,12 @@ lib.define( "to_sparse_semi_structured_cutlass_sm9x_f8(Tensor weight) -> (Tensor, Tensor)" ) +lib.define( + "swizzle_mm(Tensor mat1, Tensor mat2, bool mat1_is_swizzled, bool mat2_is_swizzled) -> Tensor" +) +lib.define( + "swizzle_scaled_mm(Tensor mat1, Tensor mat2, bool mat1_is_swizzled, bool mat2_is_swizzled, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None) -> Tensor" +) # Note: we need to add the `torch._C.Tag.needs_fixed_stride_order` tag in order for inductor # to honor the layout constraints for `b` in the two ops below. lib.define( @@ -820,6 +826,68 @@ def _( ) +def swizzle_mm( + mat1: Tensor, mat2: Tensor, mat1_is_swizzled: bool, mat2_is_swizzled: bool +) -> Tensor: + """ + Similar to torch.mm but Tensor inputs can be SwizzleTensor instances. + + """ + return torch.ops.torchao.swizzle_mm.default( + mat1, mat2, mat1_is_swizzled, mat2_is_swizzled + ) + + +@register_custom_op("torchao::swizzle_mm") +def _( + mat1: Tensor, mat2: Tensor, mat1_is_swizzled: bool, mat2_is_swizzled: bool +) -> Tensor: + return mat1.new_empty(mat1.shape[0], mat2.shape[1]) + + +def swizzle_scaled_mm( + mat1: Tensor, + mat2: Tensor, + mat1_is_swizzled: bool, + mat2_is_swizzled: bool, + scale_a: Tensor, + scale_b: Tensor, + bias: Optional[Tensor], + scale_result: Optional[Tensor], + out_dtype: Optional[dtype], +) -> Tensor: + """ + Similar to torch.mm but Tensor inputs can be SwizzleTensor instances. + + """ + return torch.ops.torchao.swizzle_scaled_mm.default( + mat1, + mat2, + mat1_is_swizzled, + mat2_is_swizzled, + scale_a, + scale_b, + bias, + scale_result, + out_dtype, + ) + + +@register_custom_op("torchao::swizzle_scaled_mm") +def _( + mat1: Tensor, + mat2: Tensor, + mat1_is_swizzled: bool, + mat2_is_swizzled: bool, + scale_a: Tensor, + scale_b: Tensor, + bias: Optional[Tensor], + scale_result: Optional[Tensor], + out_dtype: Optional[dtype], +) -> Tensor: + return mat1.new_empty(mat1.shape[0], mat2.shape[1]) + + @functools.lru_cache() def _get_dtypes(): """TODO: when e8m0 is hardened and major release lets remove uint8 support""" diff --git a/torchao/swizzle/__init__.py b/torchao/swizzle/__init__.py new file mode 100644 index 0000000000..7aa001267c --- /dev/null +++ b/torchao/swizzle/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .swizzle_tensor import SwizzleTensor + +__all__ = ["SwizzleTensor"] diff --git a/torchao/swizzle/swizzle_ops.py b/torchao/swizzle/swizzle_ops.py new file mode 100644 index 0000000000..7e62922a02 --- /dev/null +++ b/torchao/swizzle/swizzle_ops.py @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +from typing import Any, Dict + +import torch + +import torchao.ops +from torchao.swizzle.swizzle_tensor import SwizzleTensor + +aten = torch.ops.aten +SWIZZLE_OPS_TABLE: Dict[Any, Any] = {} + + +def implements(aten_ops): + """Register aten ops to the swizzle op table""" + + def decorator(func): + for op in aten_ops: + SWIZZLE_OPS_TABLE[op] = func + return func + + return decorator + + +@implements([aten.mm.default, aten.matmul.default]) +def swizzle_mm(aten_op, args, kwargs=None): + a = args[0] + b = args[1] + + if torch.is_floating_point(a) and torch.is_floating_point(b) and a.ndim == 2 and b.ndim == 2: + a_is_swizzled = False + b_is_swizzled = False + if isinstance(a, SwizzleTensor): + a = a.as_tensor() + a_is_swizzled = True + if isinstance(b, SwizzleTensor): + b = b.as_tensor() + b_is_swizzled = True + tensor_out = torchao.ops.swizzle_mm(a, b, a_is_swizzled, b_is_swizzled) + else: + a = a.unswizzle() if isinstance(a, SwizzleTensor) else a + b = b.unswizzle() if isinstance(b, SwizzleTensor) else b + tensor_out = aten_op(a, b, **kwargs) + return tensor_out + + +@implements([aten.bmm.default]) +def swizzle_bmm(aten_op, args, kwargs=None): + a = args[0] + b = args[1] + + a = a.unswizzle() if isinstance(a, SwizzleTensor) else a + b = b.unswizzle() if isinstance(b, SwizzleTensor) else b + return aten_op(a, b, **kwargs) + + +@implements([aten.addmm.default]) +def swizzle_addmm(aten_op, args, kwargs=None): + bias = args[0] + a = args[1] + b = args[2] + a = a.unswizzle() if isinstance(a, SwizzleTensor) else a + b = b.unswizzle() if isinstance(b, SwizzleTensor) else b + return aten_op(bias, a, b, args[3:], **kwargs) + + +@implements([aten._scaled_mm.default]) +def swizzle_scaled_mm(aten_op, args, kwargs=None): + a = args[0] + b = args[1] + scale_a = args[2] + scale_b = args[3] + bias = None if len(args) <= 4 else args[4] + scale_result = None if len(args) <= 5 else args[5] + out_dtype = None if len(args) <= 6 else args[6] + + a_is_swizzled = False + b_is_swizzled = False + if isinstance(a, SwizzleTensor): + a = a.as_tensor() + a_is_swizzled = True + if isinstance(b, SwizzleTensor): + b = b.as_tensor() + b_is_swizzled = True + return torchao.ops.swizzle_scaled_mm( + a, + b, + a_is_swizzled, + b_is_swizzled, + scale_a, + scale_b, + bias, + scale_result, + out_dtype, + **kwargs, + ) + + +@implements([aten.permute.default]) +def swizzle_permute(aten_op, args, kwargs=None): + tensor = args[0] + dims = args[1] + if len(dims) == 2 and dims[0] == 1 and dims[1] == 0: + return tensor.shallow_transpose() + return aten_op(tensor.unswizzle(), dims) + + +@implements([aten.numpy_T.default]) +def swizzle_numpy_T(aten_op, args, kwargs=None): + tensor = args[0] + return tensor.shallow_transpose() diff --git a/torchao/swizzle/swizzle_tensor.py b/torchao/swizzle/swizzle_tensor.py new file mode 100644 index 0000000000..8ddfd9308a --- /dev/null +++ b/torchao/swizzle/swizzle_tensor.py @@ -0,0 
+1,143 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch.utils._pytree import tree_map + + +# copied from float8_utils.py +def _get_min_alignment(size: int, alignment_value: int) -> int: + return (1 + ((size - 1) // alignment_value)) * alignment_value + + +class SwizzleTensor(torch.Tensor): + """ + A Python-only swizzled tensor subclass. + + Intended usage of this abstraction: + Swizzle weight Tensor to avoid LDS use during GEMMs on ROCm hardware. + """ + + def __new__( + cls, + original: torch.Tensor, + shallow: bool = False, + ): + wrapper = torch.empty_like(original, device="meta") + return torch.Tensor._make_subclass(cls, wrapper) + + def __init__(self, original, shallow=False): + if shallow: + return + # assert original.ndim == 2 or original.ndim == 3 # (M, K) or (B, M, K) + assert original.ndim == 2, "SwizzleTensor only supports ndim 2" + assert original.itemsize == 1 or original.itemsize == 2 + kdiv = 32 if original.itemsize == 2 else 64 + lastdim = 8 if original.itemsize == 2 else 16 + if original.ndim == 2: + M, K = original.shape + B = 0 + if original.ndim == 3: + B, M, K = original.shape + alignedM = _get_min_alignment(M, 16) + alignedK = _get_min_alignment(K, kdiv) + paddedM = alignedM - M + paddedK = alignedK - K + x = torch.nn.functional.pad(original, (0, paddedK, 0, paddedM), "constant", 0) + if original.ndim == 2: + x = x.view(alignedM // 16, 16, alignedK // kdiv, 4, lastdim) + x = x.permute(0, 2, 3, 1, 4) + if original.ndim == 3: + x = x.view(B, alignedM // 16, 16, alignedK // kdiv, 4, lastdim) + x = x.permute(0, 1, 3, 4, 2, 5) + self.x = x.contiguous() + self.B = B + self.M = M + self.K = K + self.alignedM = alignedM + self.alignedK = alignedK + self.paddedM = paddedM + self.paddedK = paddedK + self.original_ndim = original.ndim + self.is_transposed = False + + def __repr__(self): + return f"{self.__class__.__name__}(original={self.unswizzle()})" + + def unswizzle(self): + undone = None + if self.original_ndim == 2: + undone = self.x.permute(0, 3, 1, 2, 4).contiguous() + undone = undone.reshape(self.alignedM, self.alignedK) + undone = undone[0 : self.M, 0 : self.K] + undone = undone.reshape(self.M, self.K) + if self.is_transposed: + undone = undone.T + if self.original_ndim == 3: + undone = self.x.permute(0, 1, 4, 2, 3, 5).contiguous() + undone = undone.reshape(self.B, self.alignedM, self.alignedK) + undone = undone[0 : self.B, 0 : self.M, 0 : self.K] + undone = undone.reshape(self.B, self.M, self.K) + return undone + + def as_tensor(self): + # note the transpose because this causes col major hipblaslt op to be TN + if self.original_ndim == 2: + tmp = self.x.reshape(self.alignedM, self.alignedK) + if self.is_transposed: + tmp = tmp.T + return tmp + if self.original_ndim == 3: + tmp = self.x.reshape(self.B, self.alignedM, self.alignedK) + if self.is_transposed: + tmp = tmp.T + return tmp + + def shallow_transpose(self): + shape = ( + (self.M, self.K) if self.original_ndim == 2 else (self.B, self.M, self.K), + ) + new_obj = SwizzleTensor( + torch.empty(*shape, dtype=self.dtype, layout=self.layout, device="meta"), + True, + ) + new_obj.x = self.x + new_obj.B = self.B + new_obj.M = self.M + new_obj.K = self.K + new_obj.alignedM = self.alignedM + new_obj.alignedK = self.alignedK + new_obj.paddedM = self.paddedM + new_obj.paddedK = self.paddedK + new_obj.original_ndim = 
self.original_ndim + new_obj.is_transposed = not self.is_transposed + return new_obj + + @property + def shape(self): + return torch.Size((self.K, self.M) if self.is_transposed else (self.M, self.K)) + + def stride(self): + return (1, self.K) if self.is_transposed else (self.K, 1) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + # Lazy import to avoid circular dependency + from torchao.swizzle.swizzle_ops import SWIZZLE_OPS_TABLE + + if func in SWIZZLE_OPS_TABLE: + return SWIZZLE_OPS_TABLE[func](func, args, kwargs) + + def unwrap(e): + return e.unswizzle() if isinstance(e, SwizzleTensor) else e + + def wrap(e): + return SwizzleTensor(e) if isinstance(e, torch.Tensor) else e + + return tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))) + + # Do not force the SwizzleTensor type on the returned tensor + __torch_function__ = torch._C._disabled_torch_function_impl
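
Note (not part of the patch): a minimal usage sketch of how the new SwizzleTensor subclass is intended to interact with torch.mm through the swizzle_mm custom op registered above. It assumes a ROCm build of torchao with these extensions on a supported GPU (e.g. gfx942); the shapes, dtype, and tolerances are illustrative only.

import torch
from torchao.swizzle import SwizzleTensor

# Plain row-major operands; bf16 is one of the dtypes handled by the hipblaslt path.
a = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")
b = torch.randn(512, 128, dtype=torch.bfloat16, device="cuda")

# Wrapping an operand re-lays it out into the swizzled format; __torch_dispatch__
# then routes aten.mm through torchao.ops.swizzle_mm with the corresponding
# mat*_is_swizzled flag instead of the default GEMM.
b_swizzled = SwizzleTensor(b)
out = torch.mm(a, b_swizzled)

# The swizzled GEMM should agree with the reference result (tolerances illustrative).
torch.testing.assert_close(out, torch.mm(a, b), rtol=1e-2, atol=1e-2)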