From 7473acaed9eecdf4ee464d09104a26be1b209498 Mon Sep 17 00:00:00 2001
From: drisspg
Date: Tue, 28 Jan 2025 21:03:44 -0800
Subject: [PATCH] Add mx_fp8_bf16 kernel

stack-info: PR: https://github.com/pytorch/ao/pull/1637, branch: drisspg/stack/31
---
 setup.py                                    |  12 +-
 test/prototype/mx_formats/test_mx_mm.py     | 102 ++++++++
 torchao/csrc/cuda/mx_kernels/mx_fp8_bf16.cu | 251 ++++++++++++++++++++
 torchao/ops.py                              |  29 +++
 torchao/prototype/mx_formats/utils.py       |  53 +++++
 5 files changed, 442 insertions(+), 5 deletions(-)
 create mode 100644 test/prototype/mx_formats/test_mx_mm.py
 create mode 100644 torchao/csrc/cuda/mx_kernels/mx_fp8_bf16.cu
 create mode 100644 torchao/prototype/mx_formats/utils.py

diff --git a/setup.py b/setup.py
index 8628dc7ef4..d2832d93a5 100644
--- a/setup.py
+++ b/setup.py
@@ -215,10 +215,7 @@ def get_extensions():
     extra_link_args = []
     extra_compile_args = {
         "cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"],
-        "nvcc": [
-            "-O3" if not debug_mode else "-O0",
-            "-t=0",
-        ],
+        "nvcc": ["-O3" if not debug_mode else "-O0", "-t=0", "-std=c++17"],
     }
 
     if not IS_WINDOWS:
@@ -243,13 +240,18 @@ def get_extensions():
     use_cutlass = False
     if use_cuda and not IS_WINDOWS:
         use_cutlass = True
+
+    if use_cutlass:
         cutlass_dir = os.path.join(third_party_path, "cutlass")
         cutlass_include_dir = os.path.join(cutlass_dir, "include")
-    if use_cutlass:
+        cutlass_tools_include_dir = os.path.join(
+            cutlass_dir, "tools", "util", "include"
+        )
         extra_compile_args["nvcc"].extend(
             [
                 "-DTORCHAO_USE_CUTLASS",
                 "-I" + cutlass_include_dir,
+                "-I" + cutlass_tools_include_dir,
             ]
         )
 
diff --git a/test/prototype/mx_formats/test_mx_mm.py b/test/prototype/mx_formats/test_mx_mm.py
new file mode 100644
index 0000000000..dca1b26c05
--- /dev/null
+++ b/test/prototype/mx_formats/test_mx_mm.py
@@ -0,0 +1,102 @@
+import pytest
+import torch
+
+from torchao.float8.float8_utils import compute_error
+from torchao.ops import mx_fp8_bf16
+from torchao.prototype.mx_formats.mx_tensor import MXTensor
+from torchao.prototype.mx_formats.utils import to_blocked
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_4,
+    is_sm_at_least_100,
+)
+
+if not TORCH_VERSION_AT_LEAST_2_4:
+    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+
+
+def run_matrix_test(M: int, K: int, N: int) -> float:
+    """
+    Run matrix multiplication test with given dimensions.
+
+    Args:
+        M, K, N: Matrix dimensions
+
+    Returns:
+        float: SQNR (Signal-to-Quantization-Noise Ratio) value
+    """
+    dtype = torch.bfloat16
+    device = torch.device("cuda")
+
+    # Initialize matrices
+    a = torch.rand((M, K), dtype=dtype, device=device)
+    b = torch.rand((N, K), dtype=dtype, device=device)
+
+    # Convert to MX format
+    a_mx = MXTensor.to_mx(a, torch.float8_e4m3fn, 32)
+    b_mx = MXTensor.to_mx(b, torch.float8_e4m3fn, 32)
+
+    a_fp8 = a_mx._data
+    b_fp8 = b_mx._data
+    assert b_fp8.is_contiguous()
+    b_fp8 = b_fp8.transpose(-1, -2)
+
+    # Get scales
+    a_scale_e8 = a_mx._scale_e8m0.view(M, K // 32)
+    b_scale_e8 = b_mx._scale_e8m0.view(N, K // 32)
+
+    a_scale_block = to_blocked(a_scale_e8)
+    b_scale_block = to_blocked(b_scale_e8)
+
+    # Get reference output
+    out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(
+        -1, -2
+    )
+
+    # Run implementation
+    out_e8_fp8 = mx_fp8_bf16(a_fp8, b_fp8, a_scale_block, b_scale_block)
+
+    # Calculate metrics
+    sqnr = compute_error(out_hp, out_e8_fp8)
+
+    return sqnr.item()
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8"
+)
+@pytest.mark.parametrize(
+    "size",
+    [
+        # Small matrices
+        (128, 128, 128),
+        (256, 256, 256),
+        (384, 384, 384),
+        # Medium matrices
+        (512, 512, 512),
+        (640, 640, 640),
+        (768, 768, 768),
+        # Large matrices
+        (896, 896, 896),
+        (1024, 1024, 1024),
+        # Very large matrices
+        (8192, 8192, 8192),
+        # Non-square matrices
+        (128, 256, 384),
+        (256, 384, 512),
+        (384, 512, 640),
+        # Non-aligned matrices
+        (129, 256, 384),
+        (256, 384, 536),
+        (133, 512, 528),
+    ],
+    ids=lambda x: f"{x[0]}x{x[1]}x{x[2]}",
+)
+def test_matrix_multiplication(size):
+    """
+    Test matrix multiplication with various dimensions.
+    Verifies that the SQNR meets the minimum quality threshold.
+    """
+    M, K, N = size
+    sqnr = run_matrix_test(M, K, N)
+    assert sqnr >= 80.0, f"SQNR {sqnr} below threshold for dims {M}x{K}x{N}"
diff --git a/torchao/csrc/cuda/mx_kernels/mx_fp8_bf16.cu b/torchao/csrc/cuda/mx_kernels/mx_fp8_bf16.cu
new file mode 100644
index 0000000000..887e0d59eb
--- /dev/null
+++ b/torchao/csrc/cuda/mx_kernels/mx_fp8_bf16.cu
@@ -0,0 +1,251 @@
+#include <torch/library.h>
+
+#include <ATen/ATen.h>
+#include <ATen/core/Tensor.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDAUtils.h>
+#include <c10/cuda/CUDAException.h>
+#include <c10/util/Exception.h>
+
+#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \
+    defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)
+#define BUILD_MX_KERNELS_CUTLASS
+#endif
+
+#if defined(BUILD_MX_KERNELS_CUTLASS)
+
+#include "cute/tensor.hpp"
+#include "cutlass/detail/sm100_blockscaled_layout.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/epilogue/thread/linear_combination.h"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/util/packed_stride.hpp"
+
+
+#endif
+
+namespace torchao {
+
+#if defined(BUILD_MX_KERNELS_CUTLASS)
+namespace {
+
+using namespace cute;
+
+template <typename Element>
+constexpr int GetAlignment() {
+  if constexpr (std::is_same_v<Element, cutlass::mx_float4_t<cutlass::float_e2m1_t>>)
+    return 32;
+  return 16;
+}
+
+template <typename ElementA, typename ElementB, typename ElementD, typename MmaTileShape, typename ClusterShape, typename PerSmTileShape_MNK>
+void run_gemm(at::Tensor& a, at::Tensor& b, at::Tensor& a_scale,
+              at::Tensor& b_scale, at::Tensor& out) {
+  int M = a.size(0);
+  int K = a.size(1);
+  int N = b.size(1);
+
+  // A matrix configuration
+  using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
+  constexpr int AlignmentA = GetAlignment<ElementA>(); // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
+
+  // B matrix configuration
+  using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
+  constexpr int AlignmentB = GetAlignment<ElementB>(); // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
+
+  // C/D matrix configuration
+  using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
+  using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
+  using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
+  constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
+  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
+  // Kernel functional config
+  using ElementAccumulator = float; // Element type for internal accumulation
+  using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
+  using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag
+
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      ArchTag, OperatorClass,
+      PerSmTileShape_MNK, ClusterShape,
+      cutlass::epilogue::collective::EpilogueTileAuto,
+      ElementAccumulator, ElementAccumulator,
+      ElementC, LayoutCTag, AlignmentC,
+      ElementD, LayoutDTag, AlignmentD,
+      cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy
+      >::CollectiveOp;
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag, OperatorClass,
+      ElementA, LayoutATag, AlignmentA,
+      ElementB, LayoutBTag, AlignmentB,
+      ElementAccumulator,
+      MmaTileShape, ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy
+      >::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int, int, int, int>, // Indicates ProblemShape
+      CollectiveMainloop,
+      CollectiveEpilogue,
+      void>;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+  // Stride and scale-factor layout types used by the kernel
+  using StrideA = typename Gemm::GemmKernel::StrideA;
+  using StrideB = typename Gemm::GemmKernel::StrideB;
+  using StrideC = typename Gemm::GemmKernel::StrideC;
+  using StrideD = typename Gemm::GemmKernel::StrideD;
+  using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
+  using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
+  using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
+
+  // Initialize strides using packed stride configuration
+  auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, make_shape(M, K, 1));
+  auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, make_shape(N, K, 1));
+  auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, make_shape(M, N, 1));
+
+  // Initialize scale factor layouts using block scaled configuration
+  auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1));
+  auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1));
+
+  using DtypeA = typename ElementA::DataType;
+  using DtypeB = typename ElementB::DataType;
+  using DtypeScaleA = typename ElementA::ScaleFactorType;
+  using DtypeScaleB = typename ElementB::ScaleFactorType;
+  using DtypeOut = ElementD;
+
+  Gemm gemm;
+
+  auto A_ptr = reinterpret_cast<DtypeA*>(a.data_ptr());
+  auto B_ptr = reinterpret_cast<DtypeB*>(b.data_ptr());
+  auto SFA_ptr = reinterpret_cast<DtypeScaleA*>(a_scale.data_ptr());
+  auto SFB_ptr = reinterpret_cast<DtypeScaleB*>(b_scale.data_ptr());
+  auto out_ptr = reinterpret_cast<DtypeOut*>(out.data_ptr());
+
+  typename Gemm::Arguments arguments{
+      cutlass::gemm::GemmUniversalMode::kGemm,
+      {M, N, K, 1},
+      { // Mainloop arguments
+          A_ptr, stride_A,
+          B_ptr, stride_B,
+          SFA_ptr, layout_SFA,
+          SFB_ptr, layout_SFB
+      },
+      { // Epilogue arguments
+          {1.0, 0.0},
+          nullptr, StrideC{}, // No bias for now
+          out_ptr, stride_D
+      }
+  };
+
+  // arguments.scheduler.max_swizzle_size = 8;
+
+  // Check whether the problem size is supported
+  cutlass::Status status = gemm.can_implement(arguments);
+  TORCH_CHECK(status == cutlass::Status::kSuccess, "Cutlass cannot implement");
+  // Allocate workspace memory
+  size_t workspace_size = Gemm::get_workspace_size(arguments);
+  auto workspace = a.new_empty(
+      {static_cast<int64_t>(workspace_size)},
+      at::TensorOptions().dtype(at::kByte));
+
+
+  // Initialize CUTLASS kernel with arguments and workspace pointer
+  status = gemm.initialize(arguments, workspace.data_ptr());
+  TORCH_CHECK(status == cutlass::Status::kSuccess, "Cutlass cannot initialize");
+
+  status = gemm.run(at::cuda::getCurrentCUDAStream());
+  TORCH_CHECK(status == cutlass::Status::kSuccess, "Cutlass cannot run ", cutlass::cutlassGetStatusString(status));
+
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+}
+}
+#endif
+
+void validate(at::Tensor a, at::Tensor b, at::Tensor a_scale, at::Tensor b_scale){
+  TORCH_CHECK(a.is_cuda(), "a must be CUDA tensor");
+  TORCH_CHECK(b.is_cuda(), "b must be CUDA tensor");
+  TORCH_CHECK(a_scale.is_cuda(), "a_scale must be CUDA tensor");
+  TORCH_CHECK(b_scale.is_cuda(), "b_scale must be CUDA tensor");
+
+  // Check matrix dimensions
+  TORCH_CHECK(a.dim() == 2, "a must be a matrix");
+  TORCH_CHECK(b.dim() == 2, "b must be a matrix");
+
+  // Get dimensions
+  auto M = a.size(0);
+  auto K = a.size(1);
+  auto N = b.size(1);
+
+  TORCH_CHECK(b.size(0) == K,
+              "Incompatible matrix dimensions: a is ", M, "x", K, " but b is ", b.size(0), "x", N);
+
+  // Needed for TMA store
+  TORCH_CHECK(N % 8 == 0, "N must be a multiple of 8 but got ", N);
+
+  // Check 16-byte alignment for input tensors
+  TORCH_CHECK(
+      reinterpret_cast<uintptr_t>(a.data_ptr()) % 16 == 0,
+      "Input tensor 'a' must be 16-byte aligned");
+  TORCH_CHECK(
+      reinterpret_cast<uintptr_t>(b.data_ptr()) % 16 == 0,
+      "Input tensor 'b' must be 16-byte aligned");
+
+  auto ceil_div = [](auto a, auto b) { return (a + b - 1) / b; };
+  auto num_k_blocks = ceil_div(K, 32);
+  // For a_scale, we expect 128 * ceil(M/128) * ceil(K/32) elements (rows padded to the 128-row block)
+  auto expected_a_scale_size = 128 * ceil_div(M, 128) * num_k_blocks;
+  TORCH_CHECK(a_scale.numel() == expected_a_scale_size, "Expected a_scale.numel() to be ", expected_a_scale_size, " but got ", a_scale.numel());
+
+  // For b_scale, we expect 128 * ceil(N/128) * ceil(K/32) elements (rows padded to the 128-row block)
+  auto expected_b_scale_size = 128 * ceil_div(N, 128) * num_k_blocks;
+  TORCH_CHECK(b_scale.numel() == expected_b_scale_size, "Expected b_scale.numel() to be ", expected_b_scale_size, " but got ", b_scale.numel());
+
+  // Check tensor strides for optimal memory layout
+  TORCH_CHECK(
+      a.stride(1) == 1,
+      "Input tensor 'a' must be contiguous in the K dimension (row-major)");
+  TORCH_CHECK(
+      b.stride(0) == 1,
+      "Input tensor 'b' must be contiguous in the K dimension (column-major)");
+}
+
+
+at::Tensor mx_fp8_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
+                       at::Tensor b_scale) {
+#if defined(BUILD_MX_KERNELS_CUTLASS)
+  validate(a, b, a_scale, b_scale);
+
+  auto out =
+      at::empty({a.size(0), b.size(1)}, a.options().dtype(at::kBFloat16));
+  using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
+  using ElementB = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
+  using ElementD = cutlass::bfloat16_t;
+
+  using MmaTileShape = Shape<_128,_128,_128>;
+  using ClusterShape = Shape<_2,_1,_1>;
+  using PerSmTileShape_MNK = Shape<_128,_128,_128>;
+
+  run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape, PerSmTileShape_MNK>(a, b, a_scale, b_scale, out);
+  return out;
+#else
+  TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
+  return at::Tensor{};
+#endif
+}
+
+TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
+  m.impl("torchao::mx_fp8_bf16", &mx_fp8_bf16);
+}
+
+} // namespace torchao
diff --git a/torchao/ops.py b/torchao/ops.py
index f4b55c4951..5845c6cbc6 100644
--- a/torchao/ops.py
+++ b/torchao/ops.py
@@ -22,6 +22,7 @@
 lib.define(
     "s8s4_linear_cutlass(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor"
 )
+lib.define("mx_fp8_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor")
 
 
 def register_custom_op(name):
@@ -615,3 +616,31 @@ def _(
         dtype=input_scale.dtype,
         device=input.device,
     )
+
+
+def mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
+    """Defines a matmul between two fp8 tensors with MX scales in E8M0 and returns a bf16 tensor.
+
+    This op is a prototype and subject to change.
+
+    Note: The MX scales are E8M0 values stored in uint8 tensors (for now).
+    The layout of the scales is very particular, see:
+    https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+
+    Args:
+        A: fp8 tensor
+        B: fp8 tensor
+        A_scale: E8M0 scale tensor for A with groupsize=32 in swizzled layout
+        B_scale: E8M0 scale tensor for B with groupsize=32 in swizzled layout
+
+    Returns:
+        M x N bf16 Tensor
+
+    """
+    return torch.ops.torchao.mx_fp8_bf16.default(A, B, A_scale, B_scale)
+
+
+@register_custom_op("torchao::mx_fp8_bf16")
+def meta_mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
+    """Meta impl for mx_fp8_bf16"""
+    return torch.empty((A.size(0), B.size(1)), dtype=torch.bfloat16, device=A.device)
diff --git a/torchao/prototype/mx_formats/utils.py b/torchao/prototype/mx_formats/utils.py
new file mode 100644
index 0000000000..4cdc26109d
--- /dev/null
+++ b/torchao/prototype/mx_formats/utils.py
@@ -0,0 +1,53 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+
+Tensor = torch.Tensor
+
+
+def ceil_div(a, b):
+    return (a + b - 1) // b
+
+
+def to_blocked(input_matrix) -> Tensor:
+    """
+    Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.
+
+    See:
+        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+
+    Args:
+        input_matrix: Input tensor of shape (H, W)
+
+    Returns:
+        Rearranged tensor, flattened, with 32 * ceil_div(H, 128) * 16 * ceil_div(W, 4) elements
+    """
+    rows, cols = input_matrix.shape
+    n_row_blocks = ceil_div(rows, 128)
+    n_col_blocks = ceil_div(cols, 4)
+
+    # Pad out and view as tiles of (128, 4)
+    padded = F.pad(input_matrix, (0, -cols % 4, 0, -rows % 128))
+    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
+
+    # Rearrange all tiles
+    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
+
+    # Lay out the rearranged tiles contiguously (flattened blocked layout)
+    return rearranged.flatten()
+
+
+def _to_blocked_single(scales: Tensor) -> Tensor:
+    """Assume that we have a 128x4 block of scales in K Major order
+
+    To see more information on the individual tile layout:
+    https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+    """
+    assert scales.shape == (128, 4)
+    scales_tiled = scales.view(4, 32, 4)  # view as 4 - (32, 4) tiles
+    return scales_tiled.transpose(0, 1).reshape(32, 16)  # Interleave tiles
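
Note (not part of the patch): a minimal sketch of the scale-swizzle bookkeeping the new op expects, assuming the helpers added above (to_blocked, _to_blocked_single, ceil_div in torchao/prototype/mx_formats/utils.py). For a single 128x4 tile of E8M0 scales, to_blocked needs no padding and should agree with _to_blocked_single flattened, and the expected swizzled-scale sizes mirror the checks in validate() in the CUDA kernel. The dimensions below are illustrative only.

    import torch

    from torchao.prototype.mx_formats.utils import _to_blocked_single, ceil_div, to_blocked

    # Layout check on CPU: this exercises only the swizzle helpers, not the CUDA kernel.
    scales = torch.randint(0, 256, (128, 4), dtype=torch.uint8)
    assert torch.equal(to_blocked(scales), _to_blocked_single(scales).flatten())

    # Swizzled-scale sizes expected for an (M, K) @ (K, N) matmul with groupsize 32:
    # rows are padded up to a multiple of 128, K is split into 32-wide groups.
    M, K, N = 129, 256, 384  # illustrative; matches one of the "non-aligned" test cases
    expected_a_scale = 128 * ceil_div(M, 128) * ceil_div(K, 32)  # 2048
    expected_b_scale = 128 * ceil_div(N, 128) * ceil_div(K, 32)  # 3072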