Add mx_fp8_bf16 kernel
stack-info: PR: #1637, branch: drisspg/stack/31
drisspg committed Feb 4, 2025
1 parent 0bc3abe commit 7473aca
Showing 5 changed files with 442 additions and 5 deletions.
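
For context, a minimal usage sketch of the new op, mirroring the test added below in test/prototype/mx_formats/test_mx_mm.py (assumes an SM100-class GPU, CUDA 12.8+, and the MXTensor / to_blocked helpers from torchao.prototype.mx_formats):

import torch
from torchao.ops import mx_fp8_bf16
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.utils import to_blocked

M, K, N = 256, 256, 256
a = torch.rand(M, K, dtype=torch.bfloat16, device="cuda")
b = torch.rand(N, K, dtype=torch.bfloat16, device="cuda")

# Quantize to MX fp8 (e4m3) with a scaling block size of 32 along K
a_mx = MXTensor.to_mx(a, torch.float8_e4m3fn, 32)
b_mx = MXTensor.to_mx(b, torch.float8_e4m3fn, 32)

# Swizzle the e8m0 scales into the blocked layout the kernel expects
a_scale = to_blocked(a_mx._scale_e8m0.view(M, K // 32))
b_scale = to_blocked(b_mx._scale_e8m0.view(N, K // 32))

# b is passed K-major (the N x K fp8 data viewed as K x N via transpose)
out = mx_fp8_bf16(a_mx._data, b_mx._data.transpose(-1, -2), a_scale, b_scale)  # (M, N), bf16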
12 changes: 7 additions & 5 deletions setup.py
@@ -215,10 +215,7 @@ def get_extensions():
extra_link_args = []
extra_compile_args = {
"cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"],
"nvcc": [
"-O3" if not debug_mode else "-O0",
"-t=0",
],
"nvcc": ["-O3" if not debug_mode else "-O0", "-t=0", "-std=c++17"],
}

if not IS_WINDOWS:
@@ -243,13 +240,18 @@ def get_extensions():
use_cutlass = False
if use_cuda and not IS_WINDOWS:
use_cutlass = True

if use_cutlass:
cutlass_dir = os.path.join(third_party_path, "cutlass")
cutlass_include_dir = os.path.join(cutlass_dir, "include")
if use_cutlass:
cutlass_tools_include_dir = os.path.join(
cutlass_dir, "tools", "util", "include"
)
extra_compile_args["nvcc"].extend(
[
"-DTORCHAO_USE_CUTLASS",
"-I" + cutlass_include_dir,
"-I" + cutlass_tools_include_dir,
]
)

102 changes: 102 additions & 0 deletions test/prototype/mx_formats/test_mx_mm.py
@@ -0,0 +1,102 @@
import pytest
import torch

from torchao.float8.float8_utils import compute_error
from torchao.ops import mx_fp8_bf16
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.utils import to_blocked
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_4,
is_sm_at_least_100,
)

if not TORCH_VERSION_AT_LEAST_2_4:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)


def run_matrix_test(M: int, K: int, N: int) -> float:
"""
Run matrix multiplication test with given dimensions.
Args:
M, K, N: Matrix dimensions
Returns:
float: SQNR (Signal-to-Quantization-Noise Ratio) value
"""
dtype = torch.bfloat16
device = torch.device("cuda")

# Initialize matrices
a = torch.rand((M, K), dtype=dtype, device=device)
b = torch.rand((N, K), dtype=dtype, device=device)

# Convert to MX format
a_mx = MXTensor.to_mx(a, torch.float8_e4m3fn, 32)
b_mx = MXTensor.to_mx(b, torch.float8_e4m3fn, 32)

a_fp8 = a_mx._data
b_fp8 = b_mx._data
assert b_fp8.is_contiguous()
b_fp8 = b_fp8.transpose(-1, -2)

# Get scales
a_scale_e8 = a_mx._scale_e8m0.view(M, K // 32)
b_scale_e8 = b_mx._scale_e8m0.view(N, K // 32)

a_scale_block = to_blocked(a_scale_e8)
b_scale_block = to_blocked(b_scale_e8)

# Get reference output
out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(
-1, -2
)

# Run implementation
out_e8_fp8 = mx_fp8_bf16(a_fp8, b_fp8, a_scale_block, b_scale_block)

# Calculate metrics
sqnr = compute_error(out_hp, out_e8_fp8)

return sqnr.item()


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8"
)
@pytest.mark.parametrize(
"size",
[
# Small matrices
(128, 128, 128),
(256, 256, 256),
(384, 384, 384),
# Medium matrices
(512, 512, 512),
(640, 640, 640),
(768, 768, 768),
# Large matrices
(896, 896, 896),
(1024, 1024, 1024),
# Very large matrices
(8192, 8192, 8192),
# Non-square matrices
(128, 256, 384),
(256, 384, 512),
(384, 512, 640),
# Non-aligned matrices
(129, 256, 384),
(256, 384, 536),
(133, 512, 528),
],
ids=lambda x: f"{x[0]}x{x[1]}x{x[2]}",
)
def test_matrix_multiplication(size):
"""
Test matrix multiplication with various dimensions.
Verifies that the SQNR meets minimum quality threshold.
"""
M, K, N = size
sqnr = run_matrix_test(M, K, N)
assert sqnr >= 80.0, f"SQNR {sqnr} below threshold for dims {M}x{K}x{N}"
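
The swizzled scale tensors fed to mx_fp8_bf16 are padded: the kernel's validate() routine in the .cu file below expects one e8m0 scale per 32-element block of K, with the row dimension (M for a_scale, N for b_scale) rounded up to a multiple of 128. A small sketch of that element count, mirroring the check (expected_scale_numel is a hypothetical helper for illustration):

def expected_scale_numel(rows: int, K: int) -> int:
    # One e8m0 scale per 32-wide K block; rows padded up to a multiple of 128,
    # matching the size check in validate() below.
    ceil_div = lambda a, b: (a + b - 1) // b
    return 128 * ceil_div(rows, 128) * ceil_div(K, 32)

# Non-aligned example from the test: M=129, K=256 -> rows pad to 256, 8 K blocks -> 2048 scales
assert expected_scale_numel(129, 256) == 2048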
251 changes: 251 additions & 0 deletions torchao/csrc/cuda/mx_kernels/mx_fp8_bf16.cu
@@ -0,0 +1,251 @@
#include <torch/library.h>

#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAUtils.h>
#include <c10/util/Exception.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>

#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \
defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)
#define BUILD_MX_KERNELS_CUTLASS
#endif

#if defined(BUILD_MX_KERNELS_CUTLASS)

#include "cute/tensor.hpp"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/util/packed_stride.hpp"


#endif

namespace torchao {

#if defined(BUILD_MX_KERNELS_CUTLASS)
namespace {

using namespace cute;

template<typename Element>
constexpr int GetAlignment() {
if constexpr (std::is_same_v<Element, cutlass::nv_float4_t<cutlass::float_e2m1_t>>)
return 32;
return 16;
}

template <typename ElementA,
typename ElementB,
typename ElementD,
typename MmaTileShape,
typename ClusterShape,
typename PerSmTileShape_MNK>
void run_gemm(at::Tensor& a, at::Tensor& b, at::Tensor& a_scale,
at::Tensor& b_scale, at::Tensor& out) {
int M = a.size(0);
int K = a.size(1);
int N = b.size(1);

// A matrix configuration
using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = GetAlignment<ElementA>(); // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)

// B matrix configuration
using LayoutBTag = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = GetAlignment<ElementB>(); // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)

// C/D matrix configuration
using ElementC = cutlass::bfloat16_t; // Element type for C matrix operand
using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
// Kernel functional config
using ElementAccumulator = float; // Element type for internal accumulation
using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag


using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass,
PerSmTileShape_MNK, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator, ElementAccumulator,
ElementC, LayoutCTag, AlignmentC,
ElementD, LayoutDTag, AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy
>::CollectiveOp;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutATag, AlignmentA,
ElementB, LayoutBTag, AlignmentB,
ElementAccumulator,
MmaTileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy
>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int,int>, // Indicates ProblemShape
CollectiveMainloop,
CollectiveEpilogue,
void>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

// Reference device GEMM implementation type
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;

// Initialize strides using packed stride configuration
auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, make_shape(M, K, 1));
auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, make_shape(N, K, 1));
auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, make_shape(M, N, 1));

// Initialize scale factor layouts using block scaled configuration
auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1));
auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1));

using DtypeA = typename ElementA::DataType;
using DtypeB = typename ElementB::DataType;
using DtypeScaleA = typename ElementA::ScaleFactorType;
using DtypeScaleB = typename ElementB::ScaleFactorType;
using DtypeOut = ElementD;

Gemm gemm;

auto A_ptr = reinterpret_cast<DtypeA*>(a.data_ptr());
auto B_ptr = reinterpret_cast<DtypeB*>(b.data_ptr());
auto SFA_ptr = reinterpret_cast<DtypeScaleA*>(a_scale.data_ptr());
auto SFB_ptr = reinterpret_cast<DtypeScaleB*>(b_scale.data_ptr());
auto out_ptr = reinterpret_cast<DtypeOut*>(out.data_ptr());

typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
{M, N, K, 1},
{ // Mainloop arguments
A_ptr, stride_A,
B_ptr, stride_B,
SFA_ptr, layout_SFA,
SFB_ptr, layout_SFB
},
{ // Epilogue arguments
{1.0, 0.0},
nullptr, StrideC{}, // No bias for now
out_ptr, stride_D
}
};

// arguments.scheduler.max_swizzle_size = 8;

// Check the problem size is supported or not
cutlass::Status status = gemm.can_implement(arguments);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Cutlass cannot implement");
// Allocate workspace memory
size_t workspace_size = Gemm::get_workspace_size(arguments);
auto workspace = a.new_empty(
{static_cast<int64_t>(workspace_size)},
at::TensorOptions().dtype(at::kByte));


// Initialize CUTLASS kernel with arguments and workspace pointer
status = gemm.initialize(arguments, workspace.data_ptr());
TORCH_CHECK(status == cutlass::Status::kSuccess, "Cutlass cannot initialize");

status = gemm.run(at::cuda::getCurrentCUDAStream());
TORCH_CHECK(status == cutlass::Status::kSuccess, "Cutlass cannot run: ", cutlass::cutlassGetStatusString(status));

C10_CUDA_KERNEL_LAUNCH_CHECK();

}
}
#endif

void validate(at::Tensor a, at::Tensor b, at::Tensor a_scale, at::Tensor b_scale){
TORCH_CHECK(a.is_cuda(), "a must be CUDA tensor");
TORCH_CHECK(b.is_cuda(), "b must be CUDA tensor");
TORCH_CHECK(a_scale.is_cuda(), "a_scale must be CUDA tensor");
TORCH_CHECK(b_scale.is_cuda(), "b_scale must be CUDA tensor");

// Check matrix dimensions
TORCH_CHECK(a.dim() == 2, "a must be a matrix");
TORCH_CHECK(b.dim() == 2, "b must be a matrix");

// Get dimensions
auto M = a.size(0);
auto K = a.size(1);
auto N = b.size(1);

TORCH_CHECK(b.size(0) == K,
"Incompatible matrix dimensions: a is ", M, "x", K, " but b is ", b.size(0), "x", N);

// Needed for TMA store
TORCH_CHECK(N % 8 == 0, "N must be a multiple of 8 but got ", N);

// Check 16-byte alignment for input tensors
TORCH_CHECK(
reinterpret_cast<std::uintptr_t>(a.data_ptr()) % 16 == 0,
"Input tensor 'a' must be 16-byte aligned");
TORCH_CHECK(
reinterpret_cast<std::uintptr_t>(b.data_ptr()) % 16 == 0,
"Input tensor 'b' must be 16-byte aligned");

auto ceil_div = [](auto a, auto b) { return (a + b - 1) / b; };
auto num_k_blocks = ceil_div(K, 32);
// For a_scale, we expect ceil(M/128) * 128 * ceil(K/32) elements (rows padded to a multiple of 128)
auto expected_a_scale_size = 128 * ceil_div(M, 128) * num_k_blocks;
TORCH_CHECK(a_scale.numel() == expected_a_scale_size, "Expected a_scale.numel() to be ", expected_a_scale_size, " but got ", a_scale.numel());

// For b_scale, we expect ceil(N/128) * 128 * ceil(K/32) elements (rows padded to a multiple of 128)
auto expected_b_scale_size = 128 * ceil_div(N, 128) * num_k_blocks;
TORCH_CHECK(b_scale.numel() == expected_b_scale_size, "Expected b_scale.numel() to be ", expected_b_scale_size, " but got ", b_scale.numel());

// Check tensor strides for optimal memory layout
TORCH_CHECK(
a.stride(1) == 1,
"Input tensor 'a' must be contiguous in the K dimension (row-major)");
TORCH_CHECK(
b.stride(0) == 1,
"Input tensor 'b' must be contiguous in the K dimension (column-major)");
}


at::Tensor mx_fp8_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
at::Tensor b_scale) {
#if defined(BUILD_MX_KERNELS_CUTLASS)
validate(a, b, a_scale, b_scale);

auto out =
at::empty({a.size(0), b.size(1)}, a.options().dtype(at::kBFloat16));
using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
using ElementB = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
using ElementD = cutlass::bfloat16_t;

using MmaTileShape = Shape<_128,_128,_128>;
using ClusterShape = Shape<_2,_1,_1>;
using PerSmTileShape_MNK = Shape<_128,_128,_128>;

run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape, PerSmTileShape_MNK>(a, b, a_scale, b_scale, out);
return out;
#else
TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
return at::Tensor{};
#endif
}

TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
m.impl("torchao::mx_fp8_bf16", &mx_fp8_bf16);
}

} // namespace torchao
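
The TORCH_LIBRARY_IMPL block above only registers the CUDA implementation; the op schema itself is defined on the Python side in torchao/ops.py (one of the changed files not shown in this excerpt, and the module the test imports mx_fp8_bf16 from). A rough sketch of what such a schema definition typically looks like with torch.library; the exact registration in torchao/ops.py may differ:

import torch

# Hypothetical schema registration matching the C++ signature above
lib = torch.library.Library("torchao", "FRAGMENT")
lib.define("mx_fp8_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor")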
