
Commit eb5a573

Add mx_fp8_bf16 kernel
stack-info: PR: #1637, branch: drisspg/stack/31
1 parent affe31f commit eb5a573
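
For context, the new op takes a row-major fp8 (e4m3) A, a column-major fp8 B, and per-32-element e8m0 scales rearranged into the blocked layout produced by to_blocked, and returns a bf16 output. A minimal usage sketch, mirroring the test added in this commit (assumes a CUDA device with compute capability 10.0+ and the prototype MX utilities shown below):

import torch

from torchao.ops import mx_fp8_bf16
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.utils import to_blocked

M, K, N = 256, 256, 256
a = torch.rand((M, K), dtype=torch.bfloat16, device="cuda")
b = torch.rand((N, K), dtype=torch.bfloat16, device="cuda")

# Quantize to MX fp8 (e4m3) with a block size of 32 along K
a_mx = MXTensor.to_mx(a, torch.float8_e4m3fn, 32)
b_mx = MXTensor.to_mx(b, torch.float8_e4m3fn, 32)

# Rearrange the per-block e8m0 scales into the blocked layout the kernel expects
a_scale = to_blocked(a_mx._scale_e8m0.view(M, K // 32))
b_scale = to_blocked(b_mx._scale_e8m0.view(N, K // 32))

# B is passed K-major (column-major), i.e. the transpose of the (N, K) fp8 data
out = mx_fp8_bf16(a_mx._data, b_mx._data.transpose(-1, -2), a_scale, b_scale)  # (M, N) bf16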

File tree

5 files changed (+440 −4 lines)


Diff for: setup.py

+5 −4

@@ -215,10 +215,7 @@ def get_extensions():
     extra_link_args = []
     extra_compile_args = {
         "cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"],
-        "nvcc": [
-            "-O3" if not debug_mode else "-O0",
-            "-t=0",
-        ],
+        "nvcc": ["-O3" if not debug_mode else "-O0", "-t=0", "-std=c++17"],
     }

     if not IS_WINDOWS:

@@ -257,12 +254,16 @@ def get_extensions():
         use_cutlass = True
         cutlass_dir = os.path.join(third_party_path, "cutlass")
         cutlass_include_dir = os.path.join(cutlass_dir, "include")
+        cutlass_tools_include_dir = os.path.join(
+            cutlass_dir, "tools", "util", "include"
+        )
         cutlass_extensions_include_dir = os.path.join(cwd, extensions_cuda_dir)
     if use_cutlass:
         extra_compile_args["nvcc"].extend(
             [
                 "-DTORCHAO_USE_CUTLASS",
                 "-I" + cutlass_include_dir,
+                "-I" + cutlass_tools_include_dir,
                 "-I" + cutlass_extensions_include_dir,
             ]
         )

Diff for: test/prototype/mx_formats/test_mx_mm.py

+102

@@ -0,0 +1,102 @@
import pytest
import torch

from torchao.float8.float8_utils import compute_error
from torchao.ops import mx_fp8_bf16
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.utils import to_blocked
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_4,
    is_sm_at_least_100,
)

if not TORCH_VERSION_AT_LEAST_2_4:
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)


def run_matrix_test(M: int, K: int, N: int) -> float:
    """
    Run matrix multiplication test with given dimensions.

    Args:
        M, K, N: Matrix dimensions

    Returns:
        float: SQNR (Signal-to-Quantization-Noise Ratio) value
    """
    dtype = torch.bfloat16
    device = torch.device("cuda")

    # Initialize matrices
    a = torch.rand((M, K), dtype=dtype, device=device)
    b = torch.rand((N, K), dtype=dtype, device=device)

    # Convert to MX format
    a_mx = MXTensor.to_mx(a, torch.float8_e4m3fn, 32)
    b_mx = MXTensor.to_mx(b, torch.float8_e4m3fn, 32)

    a_fp8 = a_mx._data
    b_fp8 = b_mx._data
    assert b_fp8.is_contiguous()
    b_fp8 = b_fp8.transpose(-1, -2)

    # Get scales
    a_scale_e8 = a_mx._scale_e8m0.view(M, K // 32)
    b_scale_e8 = b_mx._scale_e8m0.view(N, K // 32)

    a_scale_block = to_blocked(a_scale_e8)
    b_scale_block = to_blocked(b_scale_e8)

    # Get reference output
    out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(
        -1, -2
    )

    # Run implementation
    out_e8_fp8 = mx_fp8_bf16(a_fp8, b_fp8, a_scale_block, b_scale_block)

    # Calculate metrics
    sqnr = compute_error(out_hp, out_e8_fp8)

    return sqnr.item()


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
    not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8"
)
@pytest.mark.parametrize(
    "size",
    [
        # Small matrices
        (128, 128, 128),
        (256, 256, 256),
        (384, 384, 384),
        # Medium matrices
        (512, 512, 512),
        (640, 640, 640),
        (768, 768, 768),
        # Large matrices
        (896, 896, 896),
        (1024, 1024, 1024),
        # Very large matrices
        (8192, 8192, 8192),
        # Non-square matrices
        (128, 256, 384),
        (256, 384, 512),
        (384, 512, 640),
        # Non-aligned matrices
        (129, 256, 384),
        (256, 384, 536),
        (133, 512, 528),
    ],
    ids=lambda x: f"{x[0]}x{x[1]}x{x[2]}",
)
def test_matrix_multiplication(size):
    """
    Test matrix multiplication with various dimensions.
    Verifies that the SQNR meets minimum quality threshold.
    """
    M, K, N = size
    sqnr = run_matrix_test(M, K, N)
    assert sqnr >= 80.0, f"SQNR {sqnr} below threshold for dims {M}x{K}x{N}"
Diff for: mx_fp8_bf16 CUDA kernel (CUTLASS, new file)

+251
@@ -0,0 +1,251 @@
#include <torch/library.h>

#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAUtils.h>
#include <c10/util/Exception.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>

#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \
    defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)
#define BUILD_MX_KERNELS_CUTLASS
#endif

#if defined(BUILD_MX_KERNELS_CUTLASS)

#include "cute/tensor.hpp"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/util/packed_stride.hpp"

#endif

namespace torchao {

#if defined(BUILD_MX_KERNELS_CUTLASS)
namespace {

using namespace cute;

template<typename Element>
constexpr int GetAlignment() {
  if constexpr (std::is_same_v<Element, cutlass::nv_float4_t<cutlass::float_e2m1_t>>)
    return 32;
  return 16;
}

template <typename ElementA,
          typename ElementB,
          typename ElementD,
          typename MmaTileShape,
          typename ClusterShape,
          typename PerSmTileShape_MNK>
void run_gemm(at::Tensor& a, at::Tensor& b, at::Tensor& a_scale,
              at::Tensor& b_scale, at::Tensor& out) {
  int M = a.size(0);
  int K = a.size(1);
  int N = b.size(1);

  // A matrix configuration
  using LayoutATag = cutlass::layout::RowMajor;     // Layout type for A matrix operand
  constexpr int AlignmentA = GetAlignment<ElementA>();  // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)

  // B matrix configuration
  using LayoutBTag = cutlass::layout::ColumnMajor;  // Layout type for B matrix operand
  constexpr int AlignmentB = GetAlignment<ElementB>();  // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)

  // C/D matrix configuration
  using ElementC = cutlass::bfloat16_t;             // Element type for C matrix operand
  using LayoutCTag = cutlass::layout::RowMajor;     // Layout type for C matrix operand
  using LayoutDTag = cutlass::layout::RowMajor;     // Layout type for D matrix operand
  constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;  // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;  // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
  // Kernel functional config
  using ElementAccumulator = float;                 // Element type for internal accumulation
  using ArchTag = cutlass::arch::Sm100;             // Tag indicating the minimum SM that supports the intended feature
  using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;  // Operator class tag

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag, OperatorClass,
      PerSmTileShape_MNK, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator, ElementAccumulator,
      ElementC, LayoutCTag, AlignmentC,
      ElementD, LayoutDTag, AlignmentD,
      cutlass::epilogue::collective::EpilogueScheduleAuto  // Epilogue schedule policy
      >::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag, OperatorClass,
      ElementA, LayoutATag, AlignmentA,
      ElementB, LayoutBTag, AlignmentB,
      ElementAccumulator,
      MmaTileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
      cutlass::gemm::collective::KernelScheduleAuto  // Kernel schedule policy. Auto or using targeted scheduling policy
      >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>,  // Indicates ProblemShape
      CollectiveMainloop,
      CollectiveEpilogue,
      void>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  // Reference device GEMM implementation type
  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;
  using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
  using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
  using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;

  // Initialize strides using packed stride configuration
  auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, make_shape(M, K, 1));
  auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, make_shape(N, K, 1));
  auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, make_shape(M, N, 1));

  // Initialize scale factor layouts using block scaled configuration
  auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1));
  auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1));

  using DtypeA = typename ElementA::DataType;
  using DtypeB = typename ElementB::DataType;
  using DtypeScaleA = typename ElementA::ScaleFactorType;
  using DtypeScaleB = typename ElementB::ScaleFactorType;
  using DtypeOut = ElementD;

  Gemm gemm;

  auto A_ptr = reinterpret_cast<DtypeA*>(a.data_ptr());
  auto B_ptr = reinterpret_cast<DtypeB*>(b.data_ptr());
  auto SFA_ptr = reinterpret_cast<DtypeScaleA*>(a_scale.data_ptr());
  auto SFB_ptr = reinterpret_cast<DtypeScaleB*>(b_scale.data_ptr());
  auto out_ptr = reinterpret_cast<DtypeOut*>(out.data_ptr());

  typename Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      {M, N, K, 1},
      {  // Mainloop arguments
          A_ptr, stride_A,
          B_ptr, stride_B,
          SFA_ptr, layout_SFA,
          SFB_ptr, layout_SFB
      },
      {  // Epilogue arguments
          {1.0, 0.0},
          nullptr, StrideC{},  // No bias for now
          out_ptr, stride_D
      }
  };

  // arguments.scheduler.max_swizzle_size = 8;

  // Check the problem size is supported or not
  cutlass::Status status = gemm.can_implement(arguments);
  TORCH_CHECK(status == cutlass::Status::kSuccess, "Cutlass cannot implement");

  // Allocate workspace memory
  size_t workspace_size = Gemm::get_workspace_size(arguments);
  auto workspace = a.new_empty(
      {static_cast<int64_t>(workspace_size)},
      at::TensorOptions().dtype(at::kByte));

  // Initialize CUTLASS kernel with arguments and workspace pointer
  status = gemm.initialize(arguments, workspace.data_ptr());
  TORCH_CHECK(status == cutlass::Status::kSuccess, "Cutlass cannot initialize");

  status = gemm.run(at::cuda::getCurrentCUDAStream());
  TORCH_CHECK(status == cutlass::Status::kSuccess, "Cutlass cannot run", cutlass::cutlassGetStatusString(status));

  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
#endif

void validate(at::Tensor a, at::Tensor b, at::Tensor a_scale, at::Tensor b_scale) {
  TORCH_CHECK(a.is_cuda(), "a must be CUDA tensor");
  TORCH_CHECK(b.is_cuda(), "b must be CUDA tensor");
  TORCH_CHECK(a_scale.is_cuda(), "a_scale must be CUDA tensor");
  TORCH_CHECK(b_scale.is_cuda(), "b_scale must be CUDA tensor");

  // Check matrix dimensions
  TORCH_CHECK(a.dim() == 2, "a must be a matrix");
  TORCH_CHECK(b.dim() == 2, "b must be a matrix");

  // Get dimensions
  auto M = a.size(0);
  auto K = a.size(1);
  auto N = b.size(1);

  TORCH_CHECK(b.size(0) == K,
      "Incompatible matrix dimensions: a is ", M, "x", K, " but b is ", b.size(0), "x", N);

  // Needed for TMA store
  TORCH_CHECK(N % 8 == 0, "N must be a multiple of 8 but got ", N);

  // Check 16-byte alignment for input tensors
  TORCH_CHECK(
      reinterpret_cast<std::uintptr_t>(a.data_ptr()) % 16 == 0,
      "Input tensor 'a' must be 16-byte aligned");
  TORCH_CHECK(
      reinterpret_cast<std::uintptr_t>(b.data_ptr()) % 16 == 0,
      "Input tensor 'b' must be 16-byte aligned");

  auto ceil_div = [](auto a, auto b) { return (a + b - 1) / b; };
  auto num_k_blocks = ceil_div(K, 32);
  // For a_scale, we expect 128 * ceil(M / 128) * ceil(K / 32) elements (rows padded to a multiple of 128 by the blocked layout)
  auto expected_a_scale_size = 128 * ceil_div(M, 128) * num_k_blocks;
  TORCH_CHECK(a_scale.numel() == expected_a_scale_size, "Expected a_scale.numel() to be ", expected_a_scale_size, " but got ", a_scale.numel());

  // For b_scale, we expect 128 * ceil(N / 128) * ceil(K / 32) elements
  auto expected_b_scale_size = 128 * ceil_div(N, 128) * num_k_blocks;
  TORCH_CHECK(b_scale.numel() == expected_b_scale_size, "Expected b_scale.numel() to be ", expected_b_scale_size, " but got ", b_scale.numel());

  // Check tensor strides for optimal memory layout
  TORCH_CHECK(
      a.stride(1) == 1,
      "Input tensor 'a' must be contiguous in the K dimension (row-major)");
  TORCH_CHECK(
      b.stride(0) == 1,
      "Input tensor 'b' must be contiguous in the K dimension (column-major)");
}


at::Tensor mx_fp8_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
                       at::Tensor b_scale) {
#if defined(BUILD_MX_KERNELS_CUTLASS)
  validate(a, b, a_scale, b_scale);

  auto out =
      at::empty({a.size(0), b.size(1)}, a.options().dtype(at::kBFloat16));
  using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
  using ElementB = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
  using ElementD = cutlass::bfloat16_t;

  using MmaTileShape = Shape<_128,_128,_128>;
  using ClusterShape = Shape<_2,_1,_1>;
  using PerSmTileShape_MNK = Shape<_128,_128,_128>;

  run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape, PerSmTileShape_MNK>(a, b, a_scale, b_scale, out);
  return out;
#else
  TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
  return at::Tensor{};
#endif
}

TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
  m.impl("torchao::mx_fp8_bf16", &mx_fp8_bf16);
}

} // namespace torchao
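
For reference, the size checks in validate() above expect one e8m0 scale per 32-wide block along K, with the row dimension padded up to a multiple of 128 by the blocked scale layout. A small sketch of the same arithmetic on the Python side (the helper name below is illustrative, not a library API):

import math

def expected_blocked_scale_numel(rows: int, k: int, block_size: int = 32) -> int:
    # One e8m0 scale per `block_size` elements along K ...
    k_blocks = math.ceil(k / block_size)
    # ... and rows padded to a multiple of 128, matching
    # `128 * ceil_div(rows, 128) * num_k_blocks` in validate().
    padded_rows = 128 * math.ceil(rows / 128)
    return padded_rows * k_blocks

# One of the non-aligned test cases above: M = 129, K = 256
assert expected_blocked_scale_numel(129, 256) == 2048  # rows pad to 256, 256 // 32 = 8 K-blocks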
