From a4ddbcecaf46d95e89cc40a1b193506bad682359 Mon Sep 17 00:00:00 2001
From: drisspg
Date: Tue, 4 Feb 2025 11:54:58 -0800
Subject: [PATCH] Add mx_fp4_kernel

stack-info: PR: https://github.com/pytorch/ao/pull/1661, branch: drisspg/stack/34
---
 test/prototype/mx_formats/test_mx_mm.py | 90 +++++++------------
 .../{mx_fp8_bf16.cu => mx_fp_bf16.cu}   | 50 +++++++++--
 torchao/ops.py                          | 30 +++++++
 3 files changed, 103 insertions(+), 67 deletions(-)
 rename torchao/csrc/cuda/mx_kernels/{mx_fp8_bf16.cu => mx_fp_bf16.cu} (87%)

diff --git a/test/prototype/mx_formats/test_mx_mm.py b/test/prototype/mx_formats/test_mx_mm.py
index dca1b26c05..7c66c5d053 100644
--- a/test/prototype/mx_formats/test_mx_mm.py
+++ b/test/prototype/mx_formats/test_mx_mm.py
@@ -2,63 +2,45 @@
 import torch
 
 from torchao.float8.float8_utils import compute_error
-from torchao.ops import mx_fp8_bf16
-from torchao.prototype.mx_formats.mx_tensor import MXTensor
+from torchao.ops import mx_fp4_bf16, mx_fp8_bf16
+from torchao.prototype.mx_formats.mx_tensor import DTYPE_FP4, MXTensor
 from torchao.prototype.mx_formats.utils import to_blocked
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_4,
-    is_sm_at_least_100,
-)
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_100
 
 if not TORCH_VERSION_AT_LEAST_2_4:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 
-def run_matrix_test(M: int, K: int, N: int) -> float:
-    """
-    Run matrix multiplication test with given dimensions.
-
-    Args:
-        M, K, N: Matrix dimensions
-
-    Returns:
-        float: SQNR (Signal-to-Quantization-Noise Ratio) value
-    """
+def run_matrix_test(M: int, K: int, N: int, format) -> float:
     dtype = torch.bfloat16
     device = torch.device("cuda")
 
-    # Initialize matrices
     a = torch.rand((M, K), dtype=dtype, device=device)
     b = torch.rand((N, K), dtype=dtype, device=device)
 
-    # Convert to MX format
-    a_mx = MXTensor.to_mx(a, torch.float8_e4m3fn, 32)
-    b_mx = MXTensor.to_mx(b, torch.float8_e4m3fn, 32)
+    fmt = torch.float8_e4m3fn if format == "fp8" else DTYPE_FP4
+    mx_func = mx_fp8_bf16 if format == "fp8" else mx_fp4_bf16
 
-    a_fp8 = a_mx._data
-    b_fp8 = b_mx._data
-    assert b_fp8.is_contiguous()
-    b_fp8 = b_fp8.transpose(-1, -2)
+    a_mx = MXTensor.to_mx(a, fmt, 32)
+    b_mx = MXTensor.to_mx(b, fmt, 32)
 
-    # Get scales
-    a_scale_e8 = a_mx._scale_e8m0.view(M, K // 32)
-    b_scale_e8 = b_mx._scale_e8m0.view(N, K // 32)
+    a_data = a_mx._data
+    b_data = b_mx._data
+    assert b_data.is_contiguous()
+    b_data = b_data.transpose(-1, -2)
 
-    a_scale_block = to_blocked(a_scale_e8)
-    b_scale_block = to_blocked(b_scale_e8)
+    a_scale = a_mx._scale_e8m0.view(M, K // 32)
+    b_scale = b_mx._scale_e8m0.view(N, K // 32)
+
+    a_scale_block = to_blocked(a_scale)
+    b_scale_block = to_blocked(b_scale)
 
-    # Get reference output
     out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(
         -1, -2
     )
+    out = mx_func(a_data, b_data, a_scale_block, b_scale_block)
 
-    # Run implementation
-    out_e8_fp8 = mx_fp8_bf16(a_fp8, b_fp8, a_scale_block, b_scale_block)
-
-    # Calculate metrics
-    sqnr = compute_error(out_hp, out_e8_fp8)
-
-    return sqnr.item()
+    return compute_error(out_hp, out).item()
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -68,35 +50,25 @@ def run_matrix_test(M: int, K: int, N: int) -> float:
 @pytest.mark.parametrize(
     "size",
     [
-        # Small matrices
         (128, 128, 128),
         (256, 256, 256),
-        (384, 384, 384),
-        # Medium matrices
+        (384, 384, 384),  # Small
         (512, 512, 512),
-        (640, 640, 640),
-        (768, 768, 768),
-        # Large matrices
-        (896, 896, 896),
+        (768, 768, 768),  # Medium
         (1024, 1024, 1024),
-        # Very large matrices
-        (8192, 8192, 8192),
-        # Non-square matrices
+        (8192, 8192, 8192),  # Large
         (128, 256, 384),
-        (256, 384, 512),
-        (384, 512, 640),
-        # Non-aligned matrices
+        (256, 384, 512),  # Non-square
         (129, 256, 384),
-        (256, 384, 536),
-        (133, 512, 528),
+        (133, 512, 528),  # Non-aligned
     ],
     ids=lambda x: f"{x[0]}x{x[1]}x{x[2]}",
 )
-def test_matrix_multiplication(size):
-    """
-    Test matrix multiplication with various dimensions.
-    Verifies that the SQNR meets minimum quality threshold.
-    """
+@pytest.mark.parametrize("format", ["fp8", "fp4"])
+def test_matrix_multiplication(size, format):
     M, K, N = size
-    sqnr = run_matrix_test(M, K, N)
-    assert sqnr >= 80.0, f"SQNR {sqnr} below threshold for dims {M}x{K}x{N}"
+    sqnr = run_matrix_test(M, K, N, format)
+    threshold = 80.0
+    assert (
+        sqnr >= threshold
+    ), f"{format} SQNR {sqnr} below threshold for dims {M}x{K}x{N}"
diff --git a/torchao/csrc/cuda/mx_kernels/mx_fp8_bf16.cu b/torchao/csrc/cuda/mx_kernels/mx_fp_bf16.cu
similarity index 87%
rename from torchao/csrc/cuda/mx_kernels/mx_fp8_bf16.cu
rename to torchao/csrc/cuda/mx_kernels/mx_fp_bf16.cu
index 887e0d59eb..e01d363ec3 100644
--- a/torchao/csrc/cuda/mx_kernels/mx_fp8_bf16.cu
+++ b/torchao/csrc/cuda/mx_kernels/mx_fp_bf16.cu
@@ -34,7 +34,7 @@ using namespace cute;
 
 template <typename T>
 constexpr int GetAlignment() {
-  if constexpr (std::is_same_v<T, cutlass::nv_float4_t<cutlass::float_e2m1_t>>)
+  if constexpr (std::is_same_v<T, cutlass::mx_float4_t<cutlass::float_e2m1_t>>)
     return 32;
   return 16;
 }
@@ -46,11 +46,7 @@
 template <typename ElementA, typename ElementB, typename ElementD,
           typename MmaTileShape, typename ClusterShape, typename PerSmTileShape_MNK>
 void run_gemm(at::Tensor& a, at::Tensor& b, at::Tensor& a_scale,
-              at::Tensor& b_scale, at::Tensor& out) {
-  int M = a.size(0);
-  int K = a.size(1);
-  int N = b.size(1);
-
+              at::Tensor& b_scale, at::Tensor& out, int M, int K, int N) {
   // A matrix configuration
   using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
   constexpr int AlignmentA = GetAlignment<ElementA>(); // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
@@ -225,9 +221,12 @@ at::Tensor mx_fp8_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
                        at::Tensor b_scale) {
 #if defined(BUILD_MX_KERNELS_CUTLASS)
   validate(a, b, a_scale, b_scale);
+  auto M = a.size(0);
+  auto K = a.size(1);
+  auto N = b.size(1);
 
   auto out =
-      at::empty({a.size(0), b.size(1)}, a.options().dtype(at::kBFloat16));
+      at::empty({M, N}, a.options().dtype(at::kBFloat16));
   using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
   using ElementB = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
   using ElementD = cutlass::bfloat16_t;
@@ -236,7 +235,7 @@ at::Tensor mx_fp8_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
   using ClusterShape = Shape<_2,_1,_1>;
   using PerSmTileShape_MNK = Shape<_128,_128,_128>;
 
-  run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape, PerSmTileShape_MNK>(a, b, a_scale, b_scale, out);
+  run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape, PerSmTileShape_MNK>(a, b, a_scale, b_scale, out, M, K, N);
   return out;
 #else
   TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
@@ -244,8 +243,43 @@ at::Tensor mx_fp8_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
 #endif
 }
 
+at::Tensor mx_fp4_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
+                       at::Tensor b_scale) {
+#if defined(BUILD_MX_KERNELS_CUTLASS)
+  TORCH_CHECK(a.is_cuda(), "a must be CUDA tensor");
+  TORCH_CHECK(b.is_cuda(), "b must be CUDA tensor");
+  TORCH_CHECK(a_scale.is_cuda(), "a_scale must be CUDA tensor");
+  TORCH_CHECK(b_scale.is_cuda(), "b_scale must be CUDA tensor");
+
+  auto M = a.size(0);
+  auto K = a.size(1) * 2;
+  auto N = b.size(1);
+
+  auto out =
+      at::empty({M, N}, a.options().dtype(at::kBFloat16));
+  using ElementA = cutlass::mx_float4_t<cutlass::float_e2m1_t>;
+  using ElementB = cutlass::mx_float4_t<cutlass::float_e2m1_t>;
+  using ElementD = cutlass::bfloat16_t;
+
+  using MmaTileShape = Shape<_128,_128,_128>;
+  using ClusterShape = Shape<_2,_1,_1>;
+  using PerSmTileShape_MNK = Shape<_128,_128,_128>;
+
+  run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape, PerSmTileShape_MNK>(a, b, a_scale, b_scale, out, M, K, N);
+  return out;
+#else
+  TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
+  return at::Tensor{};
+#endif
+}
+
 TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
   m.impl("torchao::mx_fp8_bf16", &mx_fp8_bf16);
 }
 
+TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
+  m.impl("torchao::mx_fp4_bf16", &mx_fp4_bf16);
+}
+
+
 } // namespace torchao
diff --git a/torchao/ops.py b/torchao/ops.py
index 5845c6cbc6..c92fc0e5df 100644
--- a/torchao/ops.py
+++ b/torchao/ops.py
@@ -23,6 +23,7 @@
     "s8s4_linear_cutlass(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor"
 )
 lib.define("mx_fp8_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor")
+lib.define("mx_fp4_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor")
 
 
 def register_custom_op(name):
@@ -644,3 +645,32 @@ def mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
 def meta_mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
     """Meta impl for mx_fp8_bf16"""
     return torch.empty((A.size(0), B.size(1)), dtype=torch.bfloat16, device=A.device)
+
+
+def mx_fp4_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
+    """Defines a matmul between two fp4 tensors with MX scales in E8M0 and returns a bf16 tensor.
+
+    This op is a prototype and subject to change.
+
+    Note: The MX scales are E8M0 tensors stored in uint8 tensors (for now).
+    The layout of the scales is very particular, see:
+    https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+
+    Args:
+        A: fp4 tensor (2 fp4 elements are packed into 1 byte -> elem0|elem1)
+        B: fp4 tensor (2 fp4 elements are packed into 1 byte -> elem0|elem1)
+        A_scale: E8M0 scale tensor for A with groupsize=32 in swizzled layout
+        B_scale: E8M0 scale tensor for B with groupsize=32 in swizzled layout
+
+    Returns:
+        MxN bf16 Tensor
+
+    """
+    return torch.ops.torchao.mx_fp4_bf16.default(A, B, A_scale, B_scale)
+
+
+@register_custom_op("torchao::mx_fp4_bf16")
+def meta_mx_fp4_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
+    """Meta impl for mx_fp4_bf16"""
+    # Assume the contraction happens in the K dim, so M and N are preserved after bit packing
+    return torch.empty((A.size(0), B.size(1)), dtype=torch.bfloat16, device=A.device)
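
For reviewers, a minimal end-to-end usage sketch of the new op, distilled from test/prototype/mx_formats/test_mx_mm.py in this patch; it is not part of the patch itself. It assumes a CUDA device with SM 10.0+ and a torchao build with BUILD_MX_KERNELS_CUTLASS enabled, and the shapes are illustrative.

    import torch

    from torchao.ops import mx_fp4_bf16
    from torchao.prototype.mx_formats.mx_tensor import DTYPE_FP4, MXTensor
    from torchao.prototype.mx_formats.utils import to_blocked

    M, K, N = 256, 256, 256
    a = torch.rand((M, K), dtype=torch.bfloat16, device="cuda")
    b = torch.rand((N, K), dtype=torch.bfloat16, device="cuda")

    # Quantize to MX fp4 with groupsize 32; _data packs two fp4 elements per byte.
    a_mx = MXTensor.to_mx(a, DTYPE_FP4, 32)
    b_mx = MXTensor.to_mx(b, DTYPE_FP4, 32)

    # Swizzle the E8M0 scales into the blocked layout the kernel expects.
    a_scale_block = to_blocked(a_mx._scale_e8m0.view(M, K // 32))
    b_scale_block = to_blocked(b_mx._scale_e8m0.view(N, K // 32))

    # b is stored contiguously as (N, K/2) packed bytes and passed as its transpose.
    out = mx_fp4_bf16(a_mx._data, b_mx._data.transpose(-1, -2), a_scale_block, b_scale_block)
    assert out.shape == (M, N) and out.dtype == torch.bfloat16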