Commit
stack-info: PR: #1637, branch: drisspg/stack/31
Showing 5 changed files with 442 additions and 5 deletions.
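
The new torchao::mx_fp8_bf16 op multiplies two float8_e4m3fn matrices carrying e8m0 block scales (one scale per 32 elements along K) and returns bfloat16. As a minimal sketch of the call pattern, assuming an SM100 GPU (names and shapes taken from the test in this diff):

import torch
from torchao.ops import mx_fp8_bf16
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.utils import to_blocked

M = K = N = 256
a_mx = MXTensor.to_mx(torch.rand(M, K, dtype=torch.bfloat16, device="cuda"),
                      torch.float8_e4m3fn, 32)
b_mx = MXTensor.to_mx(torch.rand(N, K, dtype=torch.bfloat16, device="cuda"),
                      torch.float8_e4m3fn, 32)
out = mx_fp8_bf16(
    a_mx._data,                                     # (M, K) fp8, row-major
    b_mx._data.transpose(-1, -2),                   # (K, N) fp8, column-major
    to_blocked(a_mx._scale_e8m0.view(M, K // 32)),  # blocked e8m0 scales
    to_blocked(b_mx._scale_e8m0.view(N, K // 32)),
)  # -> (M, N) bfloat16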
@@ -0,0 +1,102 @@
import pytest
import torch

from torchao.float8.float8_utils import compute_error
from torchao.ops import mx_fp8_bf16
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.utils import to_blocked
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_4,
    is_sm_at_least_100,
)

if not TORCH_VERSION_AT_LEAST_2_4:
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)

def run_matrix_test(M: int, K: int, N: int) -> float:
    """
    Run a matrix multiplication test with the given dimensions.

    Args:
        M, K, N: Matrix dimensions

    Returns:
        float: SQNR (Signal-to-Quantization-Noise Ratio) in dB
    """
    dtype = torch.bfloat16
    device = torch.device("cuda")

    # Initialize matrices
    a = torch.rand((M, K), dtype=dtype, device=device)
    b = torch.rand((N, K), dtype=dtype, device=device)

    # Convert to MX format: fp8 e4m3 data plus one e8m0 scale per
    # 32-element block along K
    a_mx = MXTensor.to_mx(a, torch.float8_e4m3fn, 32)
    b_mx = MXTensor.to_mx(b, torch.float8_e4m3fn, 32)

    a_fp8 = a_mx._data
    b_fp8 = b_mx._data
    assert b_fp8.is_contiguous()
    b_fp8 = b_fp8.transpose(-1, -2)

    # Get scales
    a_scale_e8 = a_mx._scale_e8m0.view(M, K // 32)
    b_scale_e8 = b_mx._scale_e8m0.view(N, K // 32)

    # Rearrange the scales into the blocked layout the kernel expects
    a_scale_block = to_blocked(a_scale_e8)
    b_scale_block = to_blocked(b_scale_e8)

    # Get reference output
    out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(
        -1, -2
    )

    # Run implementation
    out_e8_fp8 = mx_fp8_bf16(a_fp8, b_fp8, a_scale_block, b_scale_block)

    # Calculate metrics (compute_error returns SQNR in dB)
    sqnr = compute_error(out_hp, out_e8_fp8)

    return sqnr.item()


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
    not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for mxfloat8"
)
@pytest.mark.parametrize(
    "size",
    [
        # Small matrices
        (128, 128, 128),
        (256, 256, 256),
        (384, 384, 384),
        # Medium matrices
        (512, 512, 512),
        (640, 640, 640),
        (768, 768, 768),
        # Large matrices
        (896, 896, 896),
        (1024, 1024, 1024),
        # Very large matrices
        (8192, 8192, 8192),
        # Non-square matrices
        (128, 256, 384),
        (256, 384, 512),
        (384, 512, 640),
        # Non-aligned matrices
        (129, 256, 384),
        (256, 384, 536),
        (133, 512, 528),
    ],
    ids=lambda x: f"{x[0]}x{x[1]}x{x[2]}",
)
def test_matrix_multiplication(size):
    """
    Test matrix multiplication with various dimensions.

    Verifies that the SQNR meets a minimum quality threshold.
    """
    M, K, N = size
    sqnr = run_matrix_test(M, K, N)
    assert sqnr >= 80.0, f"SQNR {sqnr} below threshold for dims {M}x{K}x{N}"
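
For the non-aligned shapes above, note that the kernel's validate() (in the CUDA file below) expects the blocked scales to be padded to 128-row tiles, with one e8m0 scale per 32-element block along K. A small hypothetical helper mirroring that size check:

import math

def expected_scale_numel(rows: int, k: int) -> int:
    # Mirrors validate() in the kernel: 128 * ceil(rows / 128) * ceil(k / 32)
    return 128 * math.ceil(rows / 128) * math.ceil(k / 32)

# e.g. the (129, 256, 384) case: M=129 pads to 256 rows of scales for a
assert expected_scale_numel(129, 256) == 256 * 8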
@@ -0,0 +1,251 @@
#include <torch/library.h>

#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/cuda/CUDAUtils.h>
#include <c10/util/Exception.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>

#if defined(TORCHAO_USE_CUTLASS) && !defined(_WIN32) && \
    defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)
#define BUILD_MX_KERNELS_CUTLASS
#endif

#if defined(BUILD_MX_KERNELS_CUTLASS)

#include "cute/tensor.hpp"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/util/packed_stride.hpp"

#endif

namespace torchao {

#if defined(BUILD_MX_KERNELS_CUTLASS)
namespace {

using namespace cute;

// Alignment is expressed in elements: 32 fp4 elements and 16 fp8 elements
// are both 16 bytes.
template <typename Element>
constexpr int GetAlignment() {
  if constexpr (std::is_same_v<Element,
                               cutlass::nv_float4_t<cutlass::float_e2m1_t>>)
    return 32;
  return 16;
}

template <typename ElementA,
          typename ElementB,
          typename ElementD,
          typename MmaTileShape,
          typename ClusterShape,
          typename PerSmTileShape_MNK>
void run_gemm(at::Tensor& a, at::Tensor& b, at::Tensor& a_scale,
              at::Tensor& b_scale, at::Tensor& out) {
  int M = a.size(0);
  int K = a.size(1);
  int N = b.size(1);

  // A matrix configuration
  using LayoutATag = cutlass::layout::RowMajor;        // Layout type for A matrix operand
  constexpr int AlignmentA = GetAlignment<ElementA>(); // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)

  // B matrix configuration
  using LayoutBTag = cutlass::layout::ColumnMajor;     // Layout type for B matrix operand
  constexpr int AlignmentB = GetAlignment<ElementB>(); // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)

  // C/D matrix configuration
  using ElementC = cutlass::bfloat16_t;         // Element type for C matrix operand
  using LayoutCTag = cutlass::layout::RowMajor; // Layout type for C matrix operand
  using LayoutDTag = cutlass::layout::RowMajor; // Layout type for D matrix operand
  constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; // Memory access granularity/alignment of D matrix in units of elements (up to 16 bytes)
  constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)

  // Kernel functional config
  using ElementAccumulator = float;     // Element type for internal accumulation
  using ArchTag = cutlass::arch::Sm100; // Tag indicating the minimum SM that supports the intended feature
  using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; // Operator class tag

  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
      ArchTag, OperatorClass,
      PerSmTileShape_MNK, ClusterShape,
      cutlass::epilogue::collective::EpilogueTileAuto,
      ElementAccumulator, ElementAccumulator,
      ElementC, LayoutCTag, AlignmentC,
      ElementD, LayoutDTag, AlignmentD,
      cutlass::epilogue::collective::EpilogueScheduleAuto // Epilogue schedule policy
      >::CollectiveOp;

  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
      ArchTag, OperatorClass,
      ElementA, LayoutATag, AlignmentA,
      ElementB, LayoutBTag, AlignmentB,
      ElementAccumulator,
      MmaTileShape, ClusterShape,
      cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
          sizeof(typename CollectiveEpilogue::SharedStorage))>,
      cutlass::gemm::collective::KernelScheduleAuto // Kernel schedule policy. Auto or using targeted scheduling policy
      >::CollectiveOp;

  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>, // Indicates ProblemShape
      CollectiveMainloop,
      CollectiveEpilogue,
      void>;

  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

  // Stride and scale-factor layout types exposed by the kernel
  using StrideA = typename Gemm::GemmKernel::StrideA;
  using StrideB = typename Gemm::GemmKernel::StrideB;
  using StrideC = typename Gemm::GemmKernel::StrideC;
  using StrideD = typename Gemm::GemmKernel::StrideD;
  using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
  using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
  using Sm100BlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;

  // Initialize strides using packed stride configuration
  auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, make_shape(M, K, 1));
  auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, make_shape(N, K, 1));
  auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, make_shape(M, N, 1));

  // Initialize scale factor layouts using block scaled configuration
  auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(make_shape(M, N, K, 1));
  auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(make_shape(M, N, K, 1));

  using DtypeA = typename ElementA::DataType;
  using DtypeB = typename ElementB::DataType;
  using DtypeScaleA = typename ElementA::ScaleFactorType;
  using DtypeScaleB = typename ElementB::ScaleFactorType;
  using DtypeOut = ElementD;

  Gemm gemm;

  auto A_ptr = reinterpret_cast<DtypeA*>(a.data_ptr());
  auto B_ptr = reinterpret_cast<DtypeB*>(b.data_ptr());
  auto SFA_ptr = reinterpret_cast<DtypeScaleA*>(a_scale.data_ptr());
  auto SFB_ptr = reinterpret_cast<DtypeScaleB*>(b_scale.data_ptr());
  auto out_ptr = reinterpret_cast<DtypeOut*>(out.data_ptr());

  typename Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      {M, N, K, 1},
      { // Mainloop arguments
          A_ptr, stride_A,
          B_ptr, stride_B,
          SFA_ptr, layout_SFA,
          SFB_ptr, layout_SFB
      },
      { // Epilogue arguments
          {1.0, 0.0},
          nullptr, StrideC{}, // No bias for now
          out_ptr, stride_D
      }
  };

  // arguments.scheduler.max_swizzle_size = 8;

  // Check whether the problem size is supported
  cutlass::Status status = gemm.can_implement(arguments);
  TORCH_CHECK(status == cutlass::Status::kSuccess, "Cutlass cannot implement");

  // Allocate workspace memory
  size_t workspace_size = Gemm::get_workspace_size(arguments);
  auto workspace = a.new_empty(
      {static_cast<int64_t>(workspace_size)},
      at::TensorOptions().dtype(at::kByte));

  // Initialize CUTLASS kernel with arguments and workspace pointer
  status = gemm.initialize(arguments, workspace.data_ptr());
  TORCH_CHECK(status == cutlass::Status::kSuccess, "Cutlass cannot initialize");

  status = gemm.run(at::cuda::getCurrentCUDAStream());
  TORCH_CHECK(status == cutlass::Status::kSuccess, "Cutlass cannot run: ",
              cutlass::cutlassGetStatusString(status));

  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
} // namespace
#endif

void validate(at::Tensor a, at::Tensor b, at::Tensor a_scale, at::Tensor b_scale) {
  TORCH_CHECK(a.is_cuda(), "a must be CUDA tensor");
  TORCH_CHECK(b.is_cuda(), "b must be CUDA tensor");
  TORCH_CHECK(a_scale.is_cuda(), "a_scale must be CUDA tensor");
  TORCH_CHECK(b_scale.is_cuda(), "b_scale must be CUDA tensor");

  // Check matrix dimensions
  TORCH_CHECK(a.dim() == 2, "a must be a matrix");
  TORCH_CHECK(b.dim() == 2, "b must be a matrix");

  // Get dimensions
  auto M = a.size(0);
  auto K = a.size(1);
  auto N = b.size(1);

  TORCH_CHECK(b.size(0) == K,
      "Incompatible matrix dimensions: a is ", M, "x", K, " but b is ", b.size(0), "x", N);

  // Needed for TMA store
  TORCH_CHECK(N % 8 == 0, "N must be a multiple of 8 but got ", N);

  // Check 16-byte alignment for input tensors
  TORCH_CHECK(
      reinterpret_cast<std::uintptr_t>(a.data_ptr()) % 16 == 0,
      "Input tensor 'a' must be 16-byte aligned");
  TORCH_CHECK(
      reinterpret_cast<std::uintptr_t>(b.data_ptr()) % 16 == 0,
      "Input tensor 'b' must be 16-byte aligned");

  auto ceil_div = [](auto a, auto b) { return (a + b - 1) / b; };
  auto num_k_blocks = ceil_div(K, 32);

  // For a_scale, we expect 128 * ceil(M/128) * ceil(K/32) elements
  // (rows padded to a multiple of 128 for the blocked layout)
  auto expected_a_scale_size = 128 * ceil_div(M, 128) * num_k_blocks;
  TORCH_CHECK(a_scale.numel() == expected_a_scale_size, "Expected a_scale size to be ",
      expected_a_scale_size, " but got ", a_scale.numel());

  // For b_scale, we expect 128 * ceil(N/128) * ceil(K/32) elements
  auto expected_b_scale_size = 128 * ceil_div(N, 128) * num_k_blocks;
  TORCH_CHECK(b_scale.numel() == expected_b_scale_size, "Expected b_scale size to be ",
      expected_b_scale_size, " but got ", b_scale.numel());

  // Check tensor strides for optimal memory layout
  TORCH_CHECK(
      a.stride(1) == 1,
      "Input tensor 'a' must be contiguous in the K dimension (row-major)");
  TORCH_CHECK(
      b.stride(0) == 1,
      "Input tensor 'b' must be contiguous in the K dimension (column-major)");
}

at::Tensor mx_fp8_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
                       at::Tensor b_scale) {
#if defined(BUILD_MX_KERNELS_CUTLASS)
  validate(a, b, a_scale, b_scale);

  auto out =
      at::empty({a.size(0), b.size(1)}, a.options().dtype(at::kBFloat16));
  using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
  using ElementB = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
  using ElementD = cutlass::bfloat16_t;

  using MmaTileShape = Shape<_128, _128, _128>;
  using ClusterShape = Shape<_2, _1, _1>;
  using PerSmTileShape_MNK = Shape<_128, _128, _128>;

  run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape,
           PerSmTileShape_MNK>(a, b, a_scale, b_scale, out);
  return out;
#else
  TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
  return at::Tensor{};
#endif
}

TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
  m.impl("torchao::mx_fp8_bf16", &mx_fp8_bf16);
}

} // namespace torchao
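
Since the implementation is registered via TORCH_LIBRARY_IMPL, it is also reachable through the dispatcher rather than the torchao.ops wrapper; a one-line sketch, assuming inputs prepared as in the test file above:

import torch
# a_fp8, b_fp8, a_scale_block, b_scale_block prepared as in run_matrix_test
out = torch.ops.torchao.mx_fp8_bf16(a_fp8, b_fp8, a_scale_block, b_scale_block)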