Commit d3306b2

Add mx_fp8_bf16 kernel (#1637)
* Add mx_fp8_bf16 kernel
  stack-info: PR: #1637, branch: drisspg/stack/31
* Add mx_fp4_kernel (#1661)
  stack-info: PR: #1661
1 parent aa51486 commit d3306b2
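
The commit wires two new CUDA ops, torchao::mx_fp8_bf16 and torchao::mx_fp4_bf16, into torchao: they take MX-quantized operand data plus block-swizzled e8m0 scales and return a bf16 result. A minimal usage sketch, pieced together from the test added in this commit (the 32-element scale block size and the shapes come from that test; this is illustrative, not the full API surface):

    import torch
    from torchao.ops import mx_fp8_bf16
    from torchao.prototype.mx_formats.mx_tensor import MXTensor
    from torchao.prototype.mx_formats.utils import to_blocked

    M, K, N = 256, 256, 256
    a = torch.rand(M, K, dtype=torch.bfloat16, device="cuda")
    b = torch.rand(N, K, dtype=torch.bfloat16, device="cuda")

    # Quantize to MX fp8: e4m3 data plus e8m0 scales, one scale per 32 elements.
    a_mx = MXTensor.to_mx(a, torch.float8_e4m3fn, 32)
    b_mx = MXTensor.to_mx(b, torch.float8_e4m3fn, 32)

    # The kernel expects the scales rearranged into its blocked layout.
    a_scale_block = to_blocked(a_mx._scale_e8m0.view(M, K // 32))
    b_scale_block = to_blocked(b_mx._scale_e8m0.view(N, K // 32))

    # B is passed as a column-major (transposed) view of the (N, K) fp8 data.
    out = mx_fp8_bf16(a_mx._data, b_mx._data.transpose(-1, -2),
                      a_scale_block, b_scale_block)  # (M, N) bfloat16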

5 files changed: +497 -4 lines changed

Diff for: setup.py (+5 -4)

@@ -215,10 +215,7 @@ def get_extensions():
     extra_link_args = []
     extra_compile_args = {
         "cxx": [f"-DPy_LIMITED_API={PY3_9_HEXCODE}"],
-        "nvcc": [
-            "-O3" if not debug_mode else "-O0",
-            "-t=0",
-        ],
+        "nvcc": ["-O3" if not debug_mode else "-O0", "-t=0", "-std=c++17"],
     }

     if not IS_WINDOWS:
@@ -257,12 +254,16 @@ def get_extensions():
     use_cutlass = True
     cutlass_dir = os.path.join(third_party_path, "cutlass")
     cutlass_include_dir = os.path.join(cutlass_dir, "include")
+    cutlass_tools_include_dir = os.path.join(
+        cutlass_dir, "tools", "util", "include"
+    )
     cutlass_extensions_include_dir = os.path.join(cwd, extensions_cuda_dir)
     if use_cutlass:
         extra_compile_args["nvcc"].extend(
             [
                 "-DTORCHAO_USE_CUTLASS",
                 "-I" + cutlass_include_dir,
+                "-I" + cutlass_tools_include_dir,
                 "-I" + cutlass_extensions_include_dir,
             ]
         )
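
The extra include path matters because the new kernel source includes "cutlass/util/packed_stride.hpp", which ships under CUTLASS's tools/util/include tree rather than the main include directory, and the CUTLASS 3.x headers need C++17, hence -std=c++17. As a rough picture (paths illustrative, assuming a non-debug, CUTLASS-enabled build), the resulting nvcc flags look like:

    -O3 -t=0 -std=c++17 -DTORCHAO_USE_CUTLASS
    -I third_party/cutlass/include
    -I third_party/cutlass/tools/util/include
    -I <extensions_cuda_dir>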

Diff for: test/prototype/mx_formats/test_mx_mm.py (new file, +74)

@@ -0,0 +1,74 @@
+import pytest
+import torch
+
+from torchao.float8.float8_utils import compute_error
+from torchao.ops import mx_fp4_bf16, mx_fp8_bf16
+from torchao.prototype.mx_formats.mx_tensor import DTYPE_FP4, MXTensor
+from torchao.prototype.mx_formats.utils import to_blocked
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_100
+
+if not TORCH_VERSION_AT_LEAST_2_4:
+    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+
+
+def run_matrix_test(M: int, K: int, N: int, format) -> float:
+    dtype = torch.bfloat16
+    device = torch.device("cuda")
+
+    a = torch.rand((M, K), dtype=dtype, device=device)
+    b = torch.rand((N, K), dtype=dtype, device=device)
+
+    fmt = torch.float8_e4m3fn if format == "fp8" else DTYPE_FP4
+    mx_func = mx_fp8_bf16 if format == "fp8" else mx_fp4_bf16
+
+    a_mx = MXTensor.to_mx(a, fmt, 32)
+    b_mx = MXTensor.to_mx(b, fmt, 32)
+
+    a_data = a_mx._data
+    b_data = b_mx._data
+    assert b_data.is_contiguous()
+    b_data = b_data.transpose(-1, -2)
+
+    a_scale = a_mx._scale_e8m0.view(M, K // 32)
+    b_scale = b_mx._scale_e8m0.view(N, K // 32)
+
+    a_scale_block = to_blocked(a_scale)
+    b_scale_block = to_blocked(b_scale)
+
+    out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(
+        -1, -2
+    )
+    out = mx_func(a_data, b_data, a_scale_block, b_scale_block)
+
+    return compute_error(out_hp, out).item()
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8"
+)
+@pytest.mark.parametrize(
+    "size",
+    [
+        (128, 128, 128),
+        (256, 256, 256),
+        (384, 384, 384),  # Small
+        (512, 512, 512),
+        (768, 768, 768),  # Medium
+        (1024, 1024, 1024),
+        (8192, 8192, 8192),  # Large
+        (128, 256, 384),
+        (256, 384, 512),  # Non-square
+        (129, 256, 384),
+        (133, 512, 528),  # Non-aligned
+    ],
+    ids=lambda x: f"{x[0]}x{x[1]}x{x[2]}",
+)
+@pytest.mark.parametrize("format", ["fp8", "fp4"])
+def test_matrix_multiplication(size, format):
+    M, K, N = size
+    sqnr = run_matrix_test(M, K, N, format)
+    threshold = 80.0
+    assert (
+        sqnr >= threshold
+    ), f"{format} SQNR {sqnr} below threshold for dims {M}x{K}x{N}"
Diff for: (new CUDA kernel source file, +285)

@@ -0,0 +1,285 @@
+#include <torch/library.h>
+
+#include <ATen/ATen.h>
+#include <ATen/core/Tensor.h>
+#include <ATen/cuda/CUDAUtils.h>
+#include <c10/util/Exception.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAException.h>
+
+#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \
+    defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)
+#define BUILD_MX_KERNELS_CUTLASS
+#endif
+
+#if defined(BUILD_MX_KERNELS_CUTLASS)
+
+#include "cute/tensor.hpp"
+#include "cutlass/detail/sm100_blockscaled_layout.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/epilogue/thread/linear_combination.h"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/util/packed_stride.hpp"
+
+
+#endif
+
+namespace torchao {
+
+#if defined(BUILD_MX_KERNELS_CUTLASS)
+namespace {
+
+using namespace cute;
+
+template<typename Element>
+constexpr int GetAlignment() {
+  if constexpr (std::is_same_v<Element, cutlass::mx_float4_t<cutlass::float_e2m1_t>>)
+    return 32;
+  return 16;
+}
+
+template <typename ElementA,
+          typename ElementB,
+          typename ElementD,
+          typename MmaTileShape,
+          typename ClusterShape,
+          typename PerSmTileShape_MNK>
+void run_gemm(at::Tensor& a, at::Tensor& b, at::Tensor& a_scale,
+              at::Tensor& b_scale, at::Tensor& out, int M, int K, int N) {
+  // A matrix configuration
+  using LayoutATag = cutlass::layout::RowMajor;         // Layout type for A matrix operand
+  constexpr int AlignmentA = GetAlignment<ElementA>();  // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
+
+  // B matrix configuration
+  using LayoutBTag = cutlass::layout::ColumnMajor;      // Layout type for B matrix operand
+  constexpr int AlignmentB = GetAlignment<ElementB>();  // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)
+
+  // C/D matrix configuration
+  using ElementC = cutlass::bfloat16_t;                 // Element type for C matrix operand
+  using LayoutCTag = cutlass::layout::RowMajor;         // Layout type for C matrix operand
+  using LayoutDTag = cutlass::layout::RowMajor;         // Layout type for D matrix operand
+  constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;  // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
+  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;  // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
+  // Kernel functional config
+  using ElementAccumulator = float;                     // Element type for internal accumulation
+  using ArchTag = cutlass::arch::Sm100;                 // Tag indicating the minimum SM that supports the intended feature
+  using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;  // Operator class tag
+
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      ArchTag, OperatorClass,
+      PerSmTileShape_MNK, ClusterShape,
+      cutlass::epilogue::collective::EpilogueTileAuto,
+      ElementAccumulator, ElementAccumulator,
+      ElementC, LayoutCTag, AlignmentC,
+      ElementD, LayoutDTag, AlignmentD,
+      cutlass::epilogue::collective::EpilogueScheduleAuto  // Epilogue schedule policy
+      >::CollectiveOp;
+
+  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+      ArchTag, OperatorClass,
+      ElementA, LayoutATag, AlignmentA,
+      ElementB, LayoutBTag, AlignmentB,
+      ElementAccumulator,
+      MmaTileShape, ClusterShape,
+      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+      cutlass::gemm::collective::KernelScheduleAuto  // Kernel schedule policy. Auto or using targeted scheduling policy
+      >::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int,int,int,int>,  // Indicates ProblemShape
+      CollectiveMainloop,
+      CollectiveEpilogue,
+      void>;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+
+  // Reference device GEMM implementation type
+  using StrideA = typename Gemm::GemmKernel::StrideA;
+  using StrideB = typename Gemm::GemmKernel::StrideB;
+  using StrideC = typename Gemm::GemmKernel::StrideC;
+  using StrideD = typename Gemm::GemmKernel::StrideD;
+  using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
+  using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
+  using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
+
+  // Initialize strides using packed stride configuration
+  auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, make_shape(M, K, 1));
+  auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, make_shape(N, K, 1));
+  auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, make_shape(M, N, 1));
+
+  // Initialize scale factor layouts using block scaled configuration
+  auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1));
+  auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1));
+
+  using DtypeA = typename ElementA::DataType;
+  using DtypeB = typename ElementB::DataType;
+  using DtypeScaleA = typename ElementA::ScaleFactorType;
+  using DtypeScaleB = typename ElementB::ScaleFactorType;
+  using DtypeOut = ElementD;
+
+  Gemm gemm;
+
+  auto A_ptr = reinterpret_cast<DtypeA*>(a.data_ptr());
+  auto B_ptr = reinterpret_cast<DtypeB*>(b.data_ptr());
+  auto SFA_ptr = reinterpret_cast<DtypeScaleA*>(a_scale.data_ptr());
+  auto SFB_ptr = reinterpret_cast<DtypeScaleB*>(b_scale.data_ptr());
+  auto out_ptr = reinterpret_cast<DtypeOut*>(out.data_ptr());
+
+  typename Gemm::Arguments arguments{
+      cutlass::gemm::GemmUniversalMode::kGemm,
+      {M, N, K, 1},
+      {  // Mainloop arguments
+          A_ptr, stride_A,
+          B_ptr, stride_B,
+          SFA_ptr, layout_SFA,
+          SFB_ptr, layout_SFB
+      },
+      {  // Epilogue arguments
+          {1.0, 0.0},
+          nullptr, StrideC{},  // No bias for now
+          out_ptr, stride_D
+      }
+  };
+
+  // arguments.scheduler.max_swizzle_size = 8;
+
+  // Check the problem size is supported or not
+  cutlass::Status status = gemm.can_implement(arguments);
+  TORCH_CHECK(status == cutlass::Status::kSuccess, "Cutlass cannot implement");
+  // Allocate workspace memory
+  size_t workspace_size = Gemm::get_workspace_size(arguments);
+  auto workspace = a.new_empty(
+      {static_cast<int64_t>(workspace_size)},
+      at::TensorOptions().dtype(at::kByte));
+
+
+  // Initialize CUTLASS kernel with arguments and workspace pointer
+  status = gemm.initialize(arguments, workspace.data_ptr());
+  TORCH_CHECK(status == cutlass::Status::kSuccess, "Cutlass cannot initialize");
+
+  status = gemm.run(at::cuda::getCurrentCUDAStream());
+  TORCH_CHECK(status == cutlass::Status::kSuccess, "Cutlass cannot run", cutlass::cutlassGetStatusString(status));
+
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+}
+}
+#endif
+
+void validate(at::Tensor a, at::Tensor b, at::Tensor a_scale, at::Tensor b_scale){
+  TORCH_CHECK(a.is_cuda(), "a must be CUDA tensor");
+  TORCH_CHECK(b.is_cuda(), "b must be CUDA tensor");
+  TORCH_CHECK(a_scale.is_cuda(), "a_scale must be CUDA tensor");
+  TORCH_CHECK(b_scale.is_cuda(), "b_scale must be CUDA tensor");
+
+  // Check matrix dimensions
+  TORCH_CHECK(a.dim() == 2, "a must be a matrix");
+  TORCH_CHECK(b.dim() == 2, "b must be a matrix");
+
+  // Get dimensions
+  auto M = a.size(0);
+  auto K = a.size(1);
+  auto N = b.size(1);
+
+  TORCH_CHECK(b.size(0) == K,
+      "Incompatible matrix dimensions: a is ", M, "x", K, " but b is ", b.size(0), "x", N);
+
+  // Needed for TMA store
+  TORCH_CHECK(N % 8 == 0, "N must be a multiple of 8 but got ", N);
+
+  // Check 16-byte alignment for input tensors
+  TORCH_CHECK(
+      reinterpret_cast<std::uintptr_t>(a.data_ptr()) % 16 == 0,
+      "Input tensor 'a' must be 16-byte aligned");
+  TORCH_CHECK(
+      reinterpret_cast<std::uintptr_t>(b.data_ptr()) % 16 == 0,
+      "Input tensor 'b' must be 16-byte aligned");
+
+  auto ceil_div = [](auto a, auto b) { return (a + b - 1) / b; };
+  auto num_k_blocks = ceil_div(K, 32);
+  // For a_scale, we expect 128 * ceil(M/128) * ceil(K/32) elements (blocked scale layout)
+  auto expected_a_scale_size = 128 * ceil_div(M, 128) * num_k_blocks;
+  TORCH_CHECK(a_scale.numel() == expected_a_scale_size, "Expected a_scale size to be ", expected_a_scale_size, " but got ", a_scale.numel());
+
+  // For b_scale, we expect 128 * ceil(N/128) * ceil(K/32) elements (blocked scale layout)
+  auto expected_b_scale_size = 128 * ceil_div(N, 128) * num_k_blocks;
+  TORCH_CHECK(b_scale.numel() == expected_b_scale_size, "Expected b_scale size to be ", expected_b_scale_size, " but got ", b_scale.numel());
+
+  // Check tensor strides for optimal memory layout
+  TORCH_CHECK(
+      a.stride(1) == 1,
+      "Input tensor 'a' must be contiguous in the K dimension (row-major)");
+  TORCH_CHECK(
+      b.stride(0) == 1,
+      "Input tensor 'b' must be contiguous in the K dimension (column-major)");
+}
+
+
+at::Tensor mx_fp8_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
+                       at::Tensor b_scale) {
+#if defined(BUILD_MX_KERNELS_CUTLASS)
+  validate(a, b, a_scale, b_scale);
+  auto M = a.size(0);
+  auto K = a.size(1);
+  auto N = b.size(1);
+
+  auto out =
+      at::empty({M, N}, a.options().dtype(at::kBFloat16));
+  using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
+  using ElementB = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
+  using ElementD = cutlass::bfloat16_t;
+
+  using MmaTileShape = Shape<_128,_128,_128>;
+  using ClusterShape = Shape<_2,_1,_1>;
+  using PerSmTileShape_MNK = Shape<_128,_128,_128>;
+
+  run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape, PerSmTileShape_MNK>(a, b, a_scale, b_scale, out, M, K, N);
+  return out;
+#else
+  TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
+  return at::Tensor{};
+#endif
+}
+
+at::Tensor mx_fp4_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
+                       at::Tensor b_scale) {
+#if defined(BUILD_MX_KERNELS_CUTLASS)
+  TORCH_CHECK(a.is_cuda(), "a must be CUDA tensor");
+  TORCH_CHECK(b.is_cuda(), "b must be CUDA tensor");
+  TORCH_CHECK(a_scale.is_cuda(), "a_scale must be CUDA tensor");
+  TORCH_CHECK(b_scale.is_cuda(), "b_scale must be CUDA tensor");
+
+  auto M = a.size(0);
+  auto K = a.size(1) * 2;  // fp4 packs two elements per byte of storage
+  auto N = b.size(1);
+
+  auto out =
+      at::empty({M, N}, a.options().dtype(at::kBFloat16));
+  using ElementA = cutlass::mx_float4_t<cutlass::float_e2m1_t>;
+  using ElementB = cutlass::mx_float4_t<cutlass::float_e2m1_t>;
+  using ElementD = cutlass::bfloat16_t;
+
+  using MmaTileShape = Shape<_128,_128,_128>;
+  using ClusterShape = Shape<_2,_1,_1>;
+  using PerSmTileShape_MNK = Shape<_128,_128,_128>;
+
+  run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape, PerSmTileShape_MNK>(a, b, a_scale, b_scale, out, M, K, N);
+  return out;
+#else
+  TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
+  return at::Tensor{};
+#endif
+}
+
+TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
+  m.impl("torchao::mx_fp8_bf16", &mx_fp8_bf16);
+}
+TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
+  m.impl("torchao::mx_fp4_bf16", &mx_fp4_bf16);
+}
+
+
+
+} // namespace torchao
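
The validate() checks above pin down the scale-layout contract for mx_fp8_bf16: each operand's e8m0 scales must arrive already padded and swizzled so that the row dimension is a multiple of 128, which is what to_blocked produces in the test. A small helper mirroring those checks (the function name is illustrative, not part of the commit):

    def expected_blocked_scale_numel(rows: int, K: int, block: int = 32, row_tile: int = 128) -> int:
        # Matches validate(): 128 * ceil(rows / 128) * ceil(K / 32) e8m0 scale elements per operand
        ceil_div = lambda a, b: (a + b - 1) // b
        return row_tile * ceil_div(rows, row_tile) * ceil_div(K, block)

For the non-aligned 133x512x528 test case, for example, a_scale must hold 128 * ceil(133/128) * ceil(512/32) = 4096 elements.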
