Add mx_fp4_kernel
stack-info: PR: #1661, branch: drisspg/stack/34
drisspg committed Feb 4, 2025
1 parent 7473aca commit a4ddbce
Showing 3 changed files with 103 additions and 67 deletions.
90 changes: 31 additions & 59 deletions test/prototype/mx_formats/test_mx_mm.py
@@ -2,63 +2,45 @@
 import torch

 from torchao.float8.float8_utils import compute_error
-from torchao.ops import mx_fp8_bf16
-from torchao.prototype.mx_formats.mx_tensor import MXTensor
+from torchao.ops import mx_fp4_bf16, mx_fp8_bf16
+from torchao.prototype.mx_formats.mx_tensor import DTYPE_FP4, MXTensor
 from torchao.prototype.mx_formats.utils import to_blocked
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_4,
-    is_sm_at_least_100,
-)
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_100

 if not TORCH_VERSION_AT_LEAST_2_4:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)


-def run_matrix_test(M: int, K: int, N: int) -> float:
-    """
-    Run matrix multiplication test with given dimensions.
-    Args:
-        M, K, N: Matrix dimensions
-    Returns:
-        float: SQNR (Signal-to-Quantization-Noise Ratio) value
-    """
+def run_matrix_test(M: int, K: int, N: int, format) -> float:
     dtype = torch.bfloat16
     device = torch.device("cuda")

-    # Initialize matrices
     a = torch.rand((M, K), dtype=dtype, device=device)
     b = torch.rand((N, K), dtype=dtype, device=device)

-    # Convert to MX format
-    a_mx = MXTensor.to_mx(a, torch.float8_e4m3fn, 32)
-    b_mx = MXTensor.to_mx(b, torch.float8_e4m3fn, 32)
+    fmt = torch.float8_e4m3fn if format == "fp8" else DTYPE_FP4
+    mx_func = mx_fp8_bf16 if format == "fp8" else mx_fp4_bf16

-    a_fp8 = a_mx._data
-    b_fp8 = b_mx._data
-    assert b_fp8.is_contiguous()
-    b_fp8 = b_fp8.transpose(-1, -2)
+    a_mx = MXTensor.to_mx(a, fmt, 32)
+    b_mx = MXTensor.to_mx(b, fmt, 32)

-    # Get scales
-    a_scale_e8 = a_mx._scale_e8m0.view(M, K // 32)
-    b_scale_e8 = b_mx._scale_e8m0.view(N, K // 32)
+    a_data = a_mx._data
+    b_data = b_mx._data
+    assert b_data.is_contiguous()
+    b_data = b_data.transpose(-1, -2)

-    a_scale_block = to_blocked(a_scale_e8)
-    b_scale_block = to_blocked(b_scale_e8)
+    a_scale = a_mx._scale_e8m0.view(M, K // 32)
+    b_scale = b_mx._scale_e8m0.view(N, K // 32)
+
+    a_scale_block = to_blocked(a_scale)
+    b_scale_block = to_blocked(b_scale)

-    # Get reference output
     out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(
         -1, -2
     )
+    out = mx_func(a_data, b_data, a_scale_block, b_scale_block)

-    # Run implementation
-    out_e8_fp8 = mx_fp8_bf16(a_fp8, b_fp8, a_scale_block, b_scale_block)
-
-    # Calculate metrics
-    sqnr = compute_error(out_hp, out_e8_fp8)
-
-    return sqnr.item()
+    return compute_error(out_hp, out).item()


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -68,35 +50,25 @@ def run_matrix_test(M: int, K: int, N: int) -> float:
 @pytest.mark.parametrize(
     "size",
     [
-        # Small matrices
         (128, 128, 128),
         (256, 256, 256),
-        (384, 384, 384),
-        # Medium matrices
+        (384, 384, 384),  # Small
         (512, 512, 512),
         (640, 640, 640),
-        (768, 768, 768),
-        # Large matrices
-        (896, 896, 896),
+        (768, 768, 768),  # Medium
         (1024, 1024, 1024),
-        # Very large matrices
-        (8192, 8192, 8192),
-        # Non-square matrices
+        (8192, 8192, 8192),  # Large
         (128, 256, 384),
-        (256, 384, 512),
-        (384, 512, 640),
-        # Non-aligned matrices
+        (256, 384, 512),  # Non-square
         (129, 256, 384),
         (256, 384, 536),
-        (133, 512, 528),
+        (133, 512, 528),  # Non-aligned
     ],
     ids=lambda x: f"{x[0]}x{x[1]}x{x[2]}",
 )
-def test_matrix_multiplication(size):
-    """
-    Test matrix multiplication with various dimensions.
-    Verifies that the SQNR meets minimum quality threshold.
-    """
+@pytest.mark.parametrize("format", ["fp8", "fp4"])
+def test_matrix_multiplication(size, format):
     M, K, N = size
-    sqnr = run_matrix_test(M, K, N)
-    assert sqnr >= 80.0, f"SQNR {sqnr} below threshold for dims {M}x{K}x{N}"
+    sqnr = run_matrix_test(M, K, N, format)
+    threshold = 80.0
+    assert (
+        sqnr >= threshold
+    ), f"{format} SQNR {sqnr} below threshold for dims {M}x{K}x{N}"
50 changes: 42 additions & 8 deletions torchao/csrc/cuda/mx_kernels/mx_fp_cutlass_kernels.cu
@@ -34,7 +34,7 @@ using namespace cute;

 template<typename Element>
 constexpr int GetAlignment() {
-  if constexpr (std::is_same_v<Element, cutlass::nv_float4_t<cutlass::float_e2m1_t>>)
+  if constexpr (std::is_same_v<Element, cutlass::mx_float4_t<cutlass::float_e2m1_t>>)
     return 32;
   return 16;
 }
@@ -46,11 +46,7 @@ template <typename ElementA,
           typename ClusterShape,
           typename PerSmTileShape_MNK>
 void run_gemm(at::Tensor& a, at::Tensor& b, at::Tensor& a_scale,
-              at::Tensor& b_scale, at::Tensor& out) {
-  int M = a.size(0);
-  int K = a.size(1);
-  int N = b.size(1);
-
+              at::Tensor& b_scale, at::Tensor& out, int M, int K, int N) {
   // A matrix configuration
   using LayoutATag = cutlass::layout::RowMajor;         // Layout type for A matrix operand
   constexpr int AlignmentA = GetAlignment<ElementA>();  // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
@@ -225,9 +221,12 @@ at::Tensor mx_fp8_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
                        at::Tensor b_scale) {
 #if defined(BUILD_MX_KERNELS_CUTLASS)
   validate(a, b, a_scale, b_scale);
+  auto M = a.size(0);
+  auto K = a.size(1);
+  auto N = b.size(1);

   auto out =
-      at::empty({a.size(0), b.size(1)}, a.options().dtype(at::kBFloat16));
+      at::empty({M, N}, a.options().dtype(at::kBFloat16));
   using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
   using ElementB = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
   using ElementD = cutlass::bfloat16_t;
@@ -236,16 +235,51 @@ at::Tensor mx_fp8_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
   using ClusterShape = Shape<_2,_1,_1>;
   using PerSmTileShape_MNK = Shape<_128,_128,_128>;

-  run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape, PerSmTileShape_MNK>(a, b, a_scale, b_scale, out);
+  run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape, PerSmTileShape_MNK>(a, b, a_scale, b_scale, out, M, K, N);
   return out;
 #else
   TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
   return at::Tensor{};
 #endif
 }

+at::Tensor mx_fp4_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
+                       at::Tensor b_scale) {
+#if defined(BUILD_MX_KERNELS_CUTLASS)
+  TORCH_CHECK(a.is_cuda(), "a must be CUDA tensor");
+  TORCH_CHECK(b.is_cuda(), "b must be CUDA tensor");
+  TORCH_CHECK(a_scale.is_cuda(), "a_scale must be CUDA tensor");
+  TORCH_CHECK(b_scale.is_cuda(), "b_scale must be CUDA tensor");
+
+  auto M = a.size(0);
+  auto K = a.size(1) * 2;
+  auto N = b.size(1);
+
+  auto out =
+      at::empty({M, N}, a.options().dtype(at::kBFloat16));
+  using ElementA = cutlass::mx_float4_t<cutlass::float_e2m1_t>;
+  using ElementB = cutlass::mx_float4_t<cutlass::float_e2m1_t>;
+  using ElementD = cutlass::bfloat16_t;
+
+  using MmaTileShape = Shape<_128,_128,_128>;
+  using ClusterShape = Shape<_2,_1,_1>;
+  using PerSmTileShape_MNK = Shape<_128,_128,_128>;
+
+  run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape, PerSmTileShape_MNK>(a, b, a_scale, b_scale, out, M, K, N);
+  return out;
+#else
+  TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
+  return at::Tensor{};
+#endif
+}
+
 TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
   m.impl("torchao::mx_fp8_bf16", &mx_fp8_bf16);
 }
+TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
+  m.impl("torchao::mx_fp4_bf16", &mx_fp4_bf16);
+}


 } // namespace torchao
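A note on the fp4 entry point above: because two e2m1 values share one byte, the packed A operand is stored as (M, K // 2) bytes, so the kernel recovers the logical contraction dimension as a.size(1) * 2, and each row carries one E8M0 scale per 32-element group along K. A minimal, purely illustrative sketch of that shape bookkeeping:

    # Packed-shape arithmetic assumed by mx_fp4_bf16 (illustrative only).
    M, K = 128, 256             # logical GEMM dimensions
    packed_K = K // 2           # bytes per row of the packed fp4 operand
    assert packed_K * 2 == K    # what the kernel computes from a.size(1) * 2
    scale_cols = K // 32        # one E8M0 scale per 32-element group along K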
30 changes: 30 additions & 0 deletions torchao/ops.py
@@ -23,6 +23,7 @@
     "s8s4_linear_cutlass(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor"
 )
 lib.define("mx_fp8_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor")
+lib.define("mx_fp4_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor")


 def register_custom_op(name):
@@ -644,3 +645,32 @@ def mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
 def meta_mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
     """Meta impl for mx_fp8_bf16"""
     return torch.empty((A.size(0), B.size(1)), dtype=torch.bfloat16, device=A.device)
+
+
+def mx_fp4_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
+    """Defines a matmul between two fp4 tensors w/ MX scales in E8M0 and returns a bf16 tensor.
+    This op is a prototype, subject to change.
+    Note: The MX scales are E8M0 tensors stored in uint8 tensors (for now).
+    The layout of the scales is very particular, see:
+    https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+    Args:
+        A: fp4 tensor (2 fp4 elements are packed into 1 byte -> elem0|elem1)
+        B: fp4 tensor (2 fp4 elements are packed into 1 byte -> elem0|elem1)
+        A_scale: E8M0 scale tensor for A with groupsize=32 in swizzled layout
+        B_scale: E8M0 scale tensor for B with groupsize=32 in swizzled layout
+    Returns:
+        MxN bf16 Tensor
+    """
+    return torch.ops.torchao.mx_fp4_bf16.default(A, B, A_scale, B_scale)
+
+
+@register_custom_op("torchao::mx_fp4_bf16")
+def meta_mx_fp4_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
+    """Meta impl for mx_fp4_bf16"""
+    # Assume the contraction happens in the K dim, so M and N are preserved post bit-pack
+    return torch.empty((A.size(0), B.size(1)), dtype=torch.bfloat16, device=A.device)
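Taken together with the test changes above, an end-to-end call of the new op looks roughly like the following. This is a hedged sketch that mirrors test_mx_mm.py; it assumes a CUDA device supported by the CUTLASS MX kernels and uses the private _data / _scale_e8m0 fields the test itself relies on:

    import torch
    from torchao.ops import mx_fp4_bf16
    from torchao.prototype.mx_formats.mx_tensor import DTYPE_FP4, MXTensor
    from torchao.prototype.mx_formats.utils import to_blocked

    M, K, N = 128, 256, 384
    a = torch.rand((M, K), dtype=torch.bfloat16, device="cuda")
    b = torch.rand((N, K), dtype=torch.bfloat16, device="cuda")

    # Quantize to MX fp4 with a group size of 32
    a_mx = MXTensor.to_mx(a, DTYPE_FP4, 32)
    b_mx = MXTensor.to_mx(b, DTYPE_FP4, 32)

    # Packed fp4 payloads; B is passed as a transposed view of its contiguous (N, K // 2) data
    a_data = a_mx._data
    b_data = b_mx._data.transpose(-1, -2)

    # E8M0 scales reshaped to (rows, K // 32) and swizzled into the blocked layout
    a_scale = to_blocked(a_mx._scale_e8m0.view(M, K // 32))
    b_scale = to_blocked(b_mx._scale_e8m0.view(N, K // 32))

    out = mx_fp4_bf16(a_data, b_data, a_scale, b_scale)  # (M, N) bf16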