#include <torch/library.h>

#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAUtils.h>
#include <c10/cuda/CUDAException.h>
#include <c10/util/Exception.h>

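// The kernels below use CUTLASS block-scaled tensor-core GEMMs that target
// SM100 (Blackwell) and require CUDA 12.8 or newer, so they are compiled only
// when the guard below is satisfied.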
#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \
    defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)
#define BUILD_MX_KERNELS_CUTLASS
#endif

#if defined(BUILD_MX_KERNELS_CUTLASS)

#include "cute/tensor.hpp"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/util/packed_stride.hpp"

#endif

namespace torchao {

#if defined(BUILD_MX_KERNELS_CUTLASS)
namespace {

using namespace cute;

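// Alignment is expressed in elements. A 16-byte access holds 32 FP4 (e2m1)
// values but only 16 FP8 values, so the packed nv_float4_t operand needs the
// wider 32-element alignment.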
template <typename Element>
constexpr int GetAlignment() {
  if constexpr (std::is_same_v<Element, cutlass::nv_float4_t<cutlass::float_e2m1_t>>) {
    return 32;
  }
  return 16;
}

template <typename ElementA,
          typename ElementB,
          typename ElementD,
          typename MmaTileShape,
          typename ClusterShape,
          typename PerSmTileShape_MNK>
void run_gemm(at::Tensor& a, at::Tensor& b, at::Tensor& a_scale,
              at::Tensor& b_scale, at::Tensor& out) {
  int M = a.size(0);
  int K = a.size(1);
  int N = b.size(1);

  // A matrix configuration
  using LayoutATag = cutlass::layout::RowMajor;         // Layout type for A matrix operand
  constexpr int AlignmentA = GetAlignment<ElementA>();  // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)

  // B matrix configuration
  using LayoutBTag = cutlass::layout::ColumnMajor;      // Layout type for B matrix operand
  constexpr int AlignmentB = GetAlignment<ElementB>();  // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)

  // C/D matrix configuration
  using ElementC = cutlass::bfloat16_t;                 // Element type for C matrix operand
  using LayoutCTag = cutlass::layout::RowMajor;         // Layout type for C matrix operand
  using LayoutDTag = cutlass::layout::RowMajor;         // Layout type for D matrix operand
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;  // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)
  constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;  // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)

  // Kernel functional config
  using ElementAccumulator = float;      // Element type for internal accumulation
  using ArchTag = cutlass::arch::Sm100;  // Tag indicating the minimum SM that supports the intended feature
  using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;  // Operator class tag

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag, OperatorClass,
      PerSmTileShape_MNK, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator, ElementAccumulator,
      ElementC, LayoutCTag, AlignmentC,
      ElementD, LayoutDTag, AlignmentD,
      cutlass::epilogue::collective::EpilogueScheduleAuto  // Epilogue schedule policy
      >::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag, OperatorClass,
      ElementA, LayoutATag, AlignmentA,
      ElementB, LayoutBTag, AlignmentB,
      ElementAccumulator,
      MmaTileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
      cutlass::gemm::collective::KernelScheduleAuto  // Kernel schedule policy (Auto or a targeted scheduling policy)
      >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>,  // ProblemShape
      CollectiveMainloop,
      CollectiveEpilogue,
      void>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  // Stride and scale-factor layout types deduced from the kernel
  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;
  using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
  using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
  using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;

  // Initialize strides using packed stride configuration
  auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, make_shape(M, K, 1));
  auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, make_shape(N, K, 1));
  auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, make_shape(M, N, 1));

  // Initialize scale factor layouts using block scaled configuration
  auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1));
  auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1));
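  // Scale factors are one value per 32-element block along K, tiled into
  // 128-row atoms by the SM100 block-scaled layout; this matches the
  // 128 * ceil(M/128) * ceil(K/32) (and N analogue) element counts checked in
  // validate() below.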

  using DtypeA = typename ElementA::DataType;
  using DtypeB = typename ElementB::DataType;
  using DtypeScaleA = typename ElementA::ScaleFactorType;
  using DtypeScaleB = typename ElementB::ScaleFactorType;
  using DtypeOut = ElementD;

  Gemm gemm;

  auto A_ptr = reinterpret_cast<DtypeA*>(a.data_ptr());
  auto B_ptr = reinterpret_cast<DtypeB*>(b.data_ptr());
  auto SFA_ptr = reinterpret_cast<DtypeScaleA*>(a_scale.data_ptr());
  auto SFB_ptr = reinterpret_cast<DtypeScaleB*>(b_scale.data_ptr());
  auto out_ptr = reinterpret_cast<DtypeOut*>(out.data_ptr());

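  // Problem shape is {M, N, K, batch = 1}. The mainloop takes the A/B data
  // pointers with their strides plus the block scale-factor pointers with
  // their layouts; the epilogue uses alpha = 1.0 and beta = 0.0, so no source
  // C tensor is read (hence the nullptr below).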
  typename Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      {M, N, K, 1},
      {  // Mainloop arguments
          A_ptr, stride_A,
          B_ptr, stride_B,
          SFA_ptr, layout_SFA,
          SFB_ptr, layout_SFB
      },
      {  // Epilogue arguments
          {1.0, 0.0},
          nullptr, StrideC{},  // No bias for now
          out_ptr, stride_D
      }
  };

  // arguments.scheduler.max_swizzle_size = 8;

  // Check whether the problem size is supported
  cutlass::Status status = gemm.can_implement(arguments);
  TORCH_CHECK(status == cutlass::Status::kSuccess,
              "CUTLASS cannot implement this GEMM: ",
              cutlass::cutlassGetStatusString(status));
  // Allocate workspace memory
  size_t workspace_size = Gemm::get_workspace_size(arguments);
  auto workspace = a.new_empty(
      {static_cast<int64_t>(workspace_size)},
      at::TensorOptions().dtype(at::kByte));

  // Initialize CUTLASS kernel with arguments and workspace pointer
  status = gemm.initialize(arguments, workspace.data_ptr());
  TORCH_CHECK(status == cutlass::Status::kSuccess,
              "CUTLASS cannot initialize: ",
              cutlass::cutlassGetStatusString(status));

  status = gemm.run(at::cuda::getCurrentCUDAStream());
  TORCH_CHECK(status == cutlass::Status::kSuccess,
              "CUTLASS kernel run failed: ",
              cutlass::cutlassGetStatusString(status));

  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

}  // namespace
#endif

void validate(at::Tensor a, at::Tensor b, at::Tensor a_scale, at::Tensor b_scale) {
  TORCH_CHECK(a.is_cuda(), "a must be a CUDA tensor");
  TORCH_CHECK(b.is_cuda(), "b must be a CUDA tensor");
  TORCH_CHECK(a_scale.is_cuda(), "a_scale must be a CUDA tensor");
  TORCH_CHECK(b_scale.is_cuda(), "b_scale must be a CUDA tensor");

  // Check matrix dimensions
  TORCH_CHECK(a.dim() == 2, "a must be a matrix");
  TORCH_CHECK(b.dim() == 2, "b must be a matrix");

  // Get dimensions
  auto M = a.size(0);
  auto K = a.size(1);
  auto N = b.size(1);

  TORCH_CHECK(b.size(0) == K,
      "Incompatible matrix dimensions: a is ", M, "x", K, " but b is ", b.size(0), "x", N);

  // Needed for TMA store
  TORCH_CHECK(N % 8 == 0, "N must be a multiple of 8 but got ", N);

  // Check 16-byte alignment for input tensors
  TORCH_CHECK(
      reinterpret_cast<std::uintptr_t>(a.data_ptr()) % 16 == 0,
      "Input tensor 'a' must be 16-byte aligned");
  TORCH_CHECK(
      reinterpret_cast<std::uintptr_t>(b.data_ptr()) % 16 == 0,
      "Input tensor 'b' must be 16-byte aligned");

  auto ceil_div = [](auto x, auto y) { return (x + y - 1) / y; };
  auto num_k_blocks = ceil_div(K, 32);
  // For a_scale, we expect 128 * ceil(M/128) * ceil(K/32) elements
  // (one scale per 32-element K block, with rows padded to a multiple of 128).
  auto expected_a_scale_size = 128 * ceil_div(M, 128) * num_k_blocks;
  TORCH_CHECK(a_scale.numel() == expected_a_scale_size, "Expected a_scale.numel() to be ", expected_a_scale_size, " but got ", a_scale.numel());

  // For b_scale, we expect 128 * ceil(N/128) * ceil(K/32) elements
  auto expected_b_scale_size = 128 * ceil_div(N, 128) * num_k_blocks;
  TORCH_CHECK(b_scale.numel() == expected_b_scale_size, "Expected b_scale.numel() to be ", expected_b_scale_size, " but got ", b_scale.numel());

  // Check tensor strides match the layouts the kernel expects
  TORCH_CHECK(
      a.stride(1) == 1,
      "Input tensor 'a' must be contiguous in the K dimension (row-major)");
  TORCH_CHECK(
      b.stride(0) == 1,
      "Input tensor 'b' must be contiguous in the K dimension (column-major)");
}
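// mx_fp8_bf16: block-scaled (MX) FP8 GEMM with a bf16 output.
//   a       : (M, K) FP8 e4m3 values, row-major (K-contiguous)
//   b       : (K, N) FP8 e4m3 values, column-major (K-contiguous)
//   a_scale : block scale factors for a, 128 * ceil(M/128) * ceil(K/32) elements
//   b_scale : block scale factors for b, 128 * ceil(N/128) * ceil(K/32) elements
// Returns an (M, N) bf16 tensor. The "torchao::mx_fp8_bf16" schema is assumed
// to be declared elsewhere; only the CUDA implementation is registered below.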
at::Tensor mx_fp8_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
                       at::Tensor b_scale) {
#if defined(BUILD_MX_KERNELS_CUTLASS)
  validate(a, b, a_scale, b_scale);

  auto out =
      at::empty({a.size(0), b.size(1)}, a.options().dtype(at::kBFloat16));
  using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
  using ElementB = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
  using ElementD = cutlass::bfloat16_t;

  using MmaTileShape = Shape<_128,_128,_128>;
  using ClusterShape = Shape<_2,_1,_1>;
  using PerSmTileShape_MNK = Shape<_128,_128,_128>;
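  // Tile configuration used here: a 128x128x128 MMA tile with a 2x1x1 cluster;
  // PerSmTileShape_MNK is the per-SM output tile handed to the epilogue builder.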

  run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape, PerSmTileShape_MNK>(a, b, a_scale, b_scale, out);
  return out;
#else
  TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
  return at::Tensor{};
#endif
}

TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
  m.impl("torchao::mx_fp8_bf16", &mx_fp8_bf16);
}

}  // namespace torchao