
Commit ae51147 · committed Jan 29, 2025 · 1 parent cef8f5f

Add mx_fp8_bf16 kernel

stack-info: PR: #1637, branch: drisspg/stack/31

File tree: 3 files changed, +239 -1 lines

setup.py (+5 -1)
@@ -218,6 +218,7 @@ def get_extensions():
             "nvcc": [
                 "-O3" if not debug_mode else "-O0",
                 "-t=0",
+                "-std=c++20"
             ],
         }

@@ -243,13 +244,16 @@ def get_extensions():
     use_cutlass = False
     if use_cuda and not IS_WINDOWS:
         use_cutlass = True
+
+    if use_cutlass:
         cutlass_dir = os.path.join(third_party_path, "cutlass")
         cutlass_include_dir = os.path.join(cutlass_dir, "include")
-    if use_cutlass:
+        cutlass_tools_include_dir = os.path.join(cutlass_dir, "tools", "util", "include")
         extra_compile_args["nvcc"].extend(
             [
                 "-DTORCHAO_USE_CUTLASS",
                 "-I" + cutlass_include_dir,
+                "-I" + cutlass_tools_include_dir,
             ]
         )
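The new -std=c++20 flag and the extra tools/util include path back the SM100 block-scaled kernel added below (cutlass/util/packed_stride.hpp ships under cutlass/tools/util/include). A minimal availability probe, as a hedged sketch: mx_fp8_bf16_available is a hypothetical helper and assumes a CUDA build of torchao with the op registration from this commit; on builds where BUILD_MX_KERNELS_CUTLASS is undefined, the kernel falls back to TORCH_CHECK_NOT_IMPLEMENTED, which surfaces in Python as NotImplementedError.

import torch
import torchao  # noqa: F401  (loads the compiled extension, registering torchao::mx_fp8_bf16)


def mx_fp8_bf16_available() -> bool:
    """Hypothetical probe: True if the CUTLASS mx_fp8_bf16 kernel can run here."""
    # 128x128 operands: K/32 = 4 scale groups per row, so the swizzled
    # scale layout needs no padding; values are placeholders only.
    a = torch.randn(128, 128, device="cuda").to(torch.float8_e4m3fn)
    b = torch.randn(128, 128, device="cuda").to(torch.float8_e4m3fn)
    scales = torch.full((128, 4), 127, dtype=torch.uint8, device="cuda")  # E8M0 encoding of 1.0
    try:
        torch.ops.torchao.mx_fp8_bf16(a, b, scales, scales)
        return True
    except NotImplementedError:
        return False  # built without CUTLASS support or CUDA < 12.8
    except RuntimeError:
        return False  # built, but the SM100-only kernel rejects this GPU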

New file (+205 lines): CUTLASS-based CUDA implementation of mx_fp8_bf16

#include <torch/library.h>

#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAUtils.h>
#include <c10/util/Exception.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>

#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \
    defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)
#define BUILD_MX_KERNELS_CUTLASS
#endif

#if defined(BUILD_MX_KERNELS_CUTLASS)

#include "cute/tensor.hpp"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/util/packed_stride.hpp"

#endif

namespace torchao {

#if defined(BUILD_MX_KERNELS_CUTLASS)
namespace {

using namespace cute;

template<typename Element>
constexpr int GetAlignment() {
  if constexpr (std::is_same_v<Element, cutlass::nv_float4_t<cutlass::float_e2m1_t>>)
    return 32;
  return 16;
}

template <typename ElementA,
          typename ElementB,
          typename ElementD,
          typename MmaTileShape,
          typename ClusterShape,
          typename PerSmTileShape_MNK>
void run_gemm(at::Tensor& a, at::Tensor& b, at::Tensor& a_scale,
              at::Tensor& b_scale, at::Tensor& out) {
  int M = a.size(0);
  int K = a.size(1);
  int N = b.size(1);

  // A matrix configuration
  using LayoutATag = cutlass::layout::RowMajor;        // Layout type for A matrix operand
  constexpr int AlignmentA = GetAlignment<ElementA>(); // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)

  // B matrix configuration
  using LayoutBTag = cutlass::layout::ColumnMajor;     // Layout type for B matrix operand
  constexpr int AlignmentB = 128;                      // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)

  // C/D matrix configuration
  using ElementC = cutlass::bfloat16_t;                // Element type for C matrix operand
  using LayoutCTag = cutlass::layout::RowMajor;        // Layout type for C matrix operand
  using LayoutDTag = cutlass::layout::RowMajor;        // Layout type for D matrix operand
  constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)

  // Kernel functional config
  using ElementAccumulator = float;                    // Element type for internal accumulation
  using ArchTag = cutlass::arch::Sm100;                // Tag indicating the minimum SM that supports the intended feature
  using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag, OperatorClass,
      PerSmTileShape_MNK, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator, ElementAccumulator,
      ElementC, LayoutCTag, AlignmentC,
      ElementD, LayoutDTag, AlignmentD,
      cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy
      >::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag, OperatorClass,
      ElementA, LayoutATag, AlignmentA,
      ElementB, LayoutBTag, AlignmentB,
      ElementAccumulator,
      MmaTileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
      cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy
      >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int,int,int,int>, // Indicates ProblemShape
      CollectiveMainloop,
      CollectiveEpilogue,
      void>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  // Reference device GEMM implementation type
  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;
  using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
  using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
  using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;

  // Initialize strides using packed stride configuration
  auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, make_shape(M, K, 1));
  auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, make_shape(N, K, 1));
  auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, make_shape(M, N, 1));

  // Initialize scale factor layouts using block scaled configuration
  auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1));
  auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1));

  using DtypeA = ElementA::DataType;
  using DtypeB = ElementB::DataType;
  using DtypeScaleA = ElementA::ScaleFactorType;
  using DtypeScaleB = ElementB::ScaleFactorType;
  using DtypeOut = ElementD;

  Gemm gemm;

  auto A_ptr = reinterpret_cast<DtypeA*>(a.data_ptr());
  auto B_ptr = reinterpret_cast<DtypeB*>(b.data_ptr());
  auto SFA_ptr = reinterpret_cast<DtypeScaleA*>(a_scale.data_ptr());
  auto SFB_ptr = reinterpret_cast<DtypeScaleB*>(b_scale.data_ptr());
  auto out_ptr = reinterpret_cast<DtypeOut*>(out.data_ptr());

  typename Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      {M, N, K, 1},
      { // Mainloop arguments
          A_ptr, stride_A,
          B_ptr, stride_B,
          SFA_ptr, layout_SFA,
          SFB_ptr, layout_SFB
      },
      { // Epilogue arguments
          {1.0, 0.0},
          nullptr, StrideC{}, // No bias for now
          out_ptr, stride_D
      }
  };

  // arguments.scheduler.max_swizzle_size = 8;

  // Check the problem size is supported or not
  cutlass::Status status = gemm.can_implement(arguments);
  TORCH_CHECK(status == cutlass::Status::kSuccess, "Cutlass cannot implement");

  // Allocate workspace memory
  size_t workspace_size = Gemm::get_workspace_size(arguments);
  auto workspace = a.new_empty(
      {static_cast<int64_t>(workspace_size)},
      at::TensorOptions().dtype(at::kByte));

  // Initialize CUTLASS kernel with arguments and workspace pointer
  status = gemm.initialize(arguments, workspace.data_ptr());
  TORCH_CHECK(status == cutlass::Status::kSuccess, "Cutlass cannot initialize");

  status = gemm.run(at::cuda::getCurrentCUDAStream());
  TORCH_CHECK(status == cutlass::Status::kSuccess, "Cutlass cannot run", cutlass::cutlassGetStatusString(status));

  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
} // namespace
#endif

at::Tensor mx_fp8_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
                       at::Tensor b_scale) {
#if defined(BUILD_MX_KERNELS_CUTLASS)
  TORCH_CHECK(a.is_cuda(), "a must be CUDA tensor");
  TORCH_CHECK(b.is_cuda(), "b must be CUDA tensor");
  TORCH_CHECK(a_scale.is_cuda(), "a_scale must be CUDA tensor");
  TORCH_CHECK(b_scale.is_cuda(), "b_scale must be CUDA tensor");

  auto out =
      at::empty({a.size(0), b.size(1)}, a.options().dtype(at::kBFloat16));
  using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
  using ElementB = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
  using ElementD = cutlass::bfloat16_t;

  using MmaTileShape = Shape<_256,_256,_256>;
  using ClusterShape = Shape<_4,_4,_1>;
  using PerSmTileShape_MNK = Shape<_128,_256,_256>;

  run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape, PerSmTileShape_MNK>(a, b, a_scale, b_scale, out);
  return out;
#else
  TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
  return at::Tensor{};
#endif
}

TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
  m.impl("torchao::mx_fp8_bf16", &mx_fp8_bf16);
}

} // namespace torchao
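For intuition about what a_scale/b_scale hold: each uint8 is an E8M0 value, a bare biased exponent encoding 2^(e - 127), one per contiguous group of 32 elements along K. Below is a simplified sketch of producing such scales; e8m0_scales is a hypothetical helper, it rounds each group's amax up to a power of two (the real torchao MX recipe may differ), and it emits a plain row-major (M, K/32) tensor rather than the swizzled layout that Sm100BlkScaledConfig expects (see the cuBLAS block-scaling link in the docstring below).

import torch


def e8m0_scales(x: torch.Tensor, group_size: int = 32) -> torch.Tensor:
    """Hypothetical helper: un-swizzled E8M0 scales for an (M, K) tensor."""
    m, k = x.shape
    assert k % group_size == 0
    # Per-group amax along K, clamped away from zero so log2 stays finite.
    amax = x.float().abs().reshape(m, k // group_size, group_size).amax(dim=-1)
    amax = amax.clamp(min=2.0**-127)
    # Round up to a power of two and store the biased exponent in uint8.
    exponent = torch.ceil(torch.log2(amax))
    return (exponent + 127).clamp(0, 255).to(torch.uint8)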

torchao/ops.py (+29)
@@ -22,6 +22,9 @@
 lib.define(
     "s8s4_linear_cutlass(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor"
 )
+lib.define(
+    "mx_fp8_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor"
+)


 def register_custom_op(name):
@@ -615,3 +618,29 @@ def _(
         dtype=input_scale.dtype,
         device=input.device,
     )
+
+
+def mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
+    """Defines a matmul between two fp8 tensors with MX scales in E8M0, returning a bf16 tensor.
+
+    Note: The MX scales are E8M0 tensors stored in uint8 tensors (for now).
+    The layout of the scales is very particular, see:
+    https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+
+    Args:
+        A: fp8 tensor
+        B: fp8 tensor
+        A_scale: E8M0 scale tensor for A with groupsize=32 in swizzled layout
+        B_scale: E8M0 scale tensor for B with groupsize=32 in swizzled layout
+
+    Returns:
+        MxN bf16 Tensor
+    """
+    return torch.ops.torchao.mx_fp8_bf16.default(A, B, A_scale, B_scale)
+
+
+@register_custom_op("torchao::mx_fp8_bf16")
+def meta_mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
+    """Meta impl for mx_fp8_bf16"""
+    return torch.empty((A.size(0), B.size(1)), dtype=torch.bfloat16, device=A.device)
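Putting the pieces together, a hedged end-to-end sketch. Assumptions: an SM100 (Blackwell) GPU, a torchao build with the CUTLASS kernel enabled, sizes where the swizzled scale layout needs no padding, and dummy unit scales (E8M0 value 127 encodes 2^0 = 1.0) rather than a real quantization recipe.

import torch
from torchao.ops import mx_fp8_bf16

M, K, N = 256, 256, 256
A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
# The kernel treats B as column-major (K x N), so pass a transposed
# contiguous (N, K) weight; b.size(1) is then N.
B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()
# One E8M0 scale per 32-wide group along K, stored in uint8; real callers
# must arrange these in the swizzled layout from the cuBLAS docs above.
A_scale = torch.full((M, K // 32), 127, dtype=torch.uint8, device="cuda")
B_scale = torch.full((N, K // 32), 127, dtype=torch.uint8, device="cuda")

out = mx_fp8_bf16(A, B, A_scale, B_scale)
assert out.shape == (M, N) and out.dtype == torch.bfloat16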
