integrate mx dim1 triton kernel into MXLinear #1943

Merged (16 commits) on Mar 27, 2025

88 changes: 68 additions & 20 deletions test/prototype/mx_formats/test_mx_linear.py
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.

import copy
import itertools

import pytest
import torch
@@ -16,7 +15,12 @@
MXLinearConfig,
MXLinearRecipeName,
)
from torchao.prototype.mx_formats.constants import DTYPE_FP4, SUPPORTED_ELEM_DTYPES
from torchao.prototype.mx_formats.constants import (
DTYPE_FP4,
DTYPE_FP6_E2M3,
DTYPE_FP6_E3M2,
SUPPORTED_ELEM_DTYPES,
)
from torchao.prototype.mx_formats.mx_linear import (
MXInferenceLinear,
MXLinear,
@@ -48,38 +52,65 @@ def run_around_tests():

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize(
"elem_dtype", itertools.product(SUPPORTED_ELEM_DTYPES, repeat=3)
"elem_dtype",
(
# test each dtype
(torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e4m3fn),
(DTYPE_FP6_E3M2, DTYPE_FP6_E3M2, DTYPE_FP6_E3M2),
(DTYPE_FP6_E2M3, DTYPE_FP6_E2M3, DTYPE_FP6_E2M3),
(DTYPE_FP4, DTYPE_FP4, DTYPE_FP4),
# only test one type of mixed-dtype overrides, to save testing time
(torch.float8_e4m3fn, DTYPE_FP4, DTYPE_FP4),
),
)
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("input_shape", [(4, 8), (1, 4, 8), (1, 1, 4, 8)])
def test_linear_eager(elem_dtype, bias, input_shape):
@pytest.mark.parametrize("input_shape", [(128, 256), (1, 128, 256), (1, 1, 128, 256)])
@pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True])
def test_linear_eager_vs_hp(
elem_dtype, bias, input_shape, use_fp8_dim1_cast_triton_kernel
):
"""
Smoke test for training linear module with mx weight, compares the following:
* baseline: float32
* experiment: emulated MX
"""
if use_fp8_dim1_cast_triton_kernel:
if elem_dtype != (
torch.float8_e4m3fn,
torch.float8_e4m3fn,
torch.float8_e4m3fn,
):
pytest.skip("unsupported configuration")
elif not is_sm_at_least_89():
pytest.skip("CUDA capability >= 8.9 required for float8 in triton")

# elem_dtype is a tuple of (input, weight, gradient) dtypes.
grad_shape = list(input_shape)
grad_shape[-1] = 8
grad_shape[-1] = 256

m = nn.Sequential(
nn.Linear(8, 8, bias=bias, device="cuda"),
nn.Linear(256, 256, bias=bias, device="cuda", dtype=torch.bfloat16),
)
m_mx = copy.deepcopy(m)
config = MXLinearConfig(
block_size=4,
elem_dtype=elem_dtype[0],
elem_dtype_weight_override=elem_dtype[1],
elem_dtype_grad_output_override=elem_dtype[2],
use_fp8_dim1_cast_triton_kernel=use_fp8_dim1_cast_triton_kernel,
)
swap_linear_with_mx_linear(m_mx, config=config)

x_ref = torch.randn(*input_shape, device="cuda").requires_grad_()
x_ref = torch.randn(
*input_shape, device="cuda", dtype=torch.bfloat16
).requires_grad_()
x = copy.deepcopy(x_ref)
g = torch.randn(*grad_shape, device="cuda")
with torch.autocast("cuda", dtype=torch.bfloat16):
y_ref = m(x_ref)
y_mx = m_mx(x)

y_ref = m(x_ref)
y_mx = m_mx(x)

assert y_mx.dtype == x.dtype

y_ref.backward(g)
y_mx.backward(g)
@@ -112,7 +143,6 @@ def test_linear_eager(elem_dtype, bias, input_shape):
)
@pytest.mark.parametrize("mkn", [(128, 256, 512), (256, 512, 128), (512, 128, 256)])
def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn):
M, K, N = 128, 128, 128
M, K, N = mkn

x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda").requires_grad_()
@@ -143,9 +173,9 @@ def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn):
y_sqnr = compute_error(y_real, y_emulated)
w_sqnr = compute_error(m_real[0].weight.grad, m_emulated[0].weight.grad)
g_sqnr = compute_error(x_copy.grad, x.grad)
assert y_sqnr > 100.0, f"y_sqnr {y_sqnr} too low!"
assert w_sqnr > 100.0, f"w_sqnr {w_sqnr} too low!"
assert g_sqnr > 100.0, f"g_sqnr {g_sqnr} too low!"
assert y_sqnr > 90.0, f"y_sqnr {y_sqnr} too low!"
assert w_sqnr > 90.0, f"w_sqnr {w_sqnr} too low!"
assert g_sqnr > 90.0, f"g_sqnr {g_sqnr} too low!"


# TODO(future): enable compile support
@@ -169,6 +199,7 @@ def test_activation_checkpointing():


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize(
"recipe_name",
[
@@ -182,7 +213,8 @@ def test_activation_checkpointing():
@pytest.mark.parametrize("bias", [False, True])
# TODO(future PR): figure out why torch.compile does not match eager when
# autocast is on
def test_linear_compile(recipe_name, bias):
@pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True])
def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_kernel):
"""
Verify that compile does not change numerics of MX linear fw + bw
"""
@@ -198,20 +230,36 @@ def test_linear_compile(recipe_name, bias):
# TODO(future PR): fix this, things are clearly broken with bias=True
pytest.skip("this test is broken for non-emulated recipes with bias=True")

if use_fp8_dim1_cast_triton_kernel:
if recipe_name not in ("mxfp8_emulated", "mxfp8_cublas", "mxfp8_cutlass"):
pytest.skip("unsupported configuration")
if not is_sm_at_least_89():
pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
if hp_dtype != torch.bfloat16:
pytest.skip("unsupported configuration")

if hp_dtype == torch.bfloat16 and recipe_name != "mxfp8_cublas":
# TODO(future PR): properly enable float32 + bfloat16 for every
# recipe, this needs a cleanup of out_dtype (needs to match in-hp-dtype, even
# if the underlying gemm kernel only supports bf16 output)
pytest.skip("unsupported configuration")

M, K, N = 128, 256, 512
input_shape = (M, K)
grad_shape = (M, N)
m_mx = nn.Sequential(
nn.Linear(K, N, bias=bias, device="cuda"),
nn.Linear(K, N, bias=bias, device="cuda", dtype=hp_dtype),
)
config = MXLinearConfig.from_recipe_name(recipe_name)
config.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel

swap_linear_with_mx_linear(m_mx, config=config)
m_mx_c = copy.deepcopy(m_mx)
m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")

x_ref = torch.randn(*input_shape, device="cuda").requires_grad_()
x_ref = torch.randn(*input_shape, device="cuda", dtype=hp_dtype).requires_grad_()
x = copy.deepcopy(x_ref)
g = torch.randn(*grad_shape, device="cuda")
g = torch.randn(*grad_shape, device="cuda", dtype=hp_dtype)

y_ref = m_mx(x_ref)
y = m_mx_c(x)
@@ -283,7 +331,7 @@ def test_inference_compile_simple(elem_dtype):
if elem_dtype is torch.float8_e4m3fn:
assert sqnr >= 20.0
else:
assert sqnr >= 13.5
assert sqnr >= 11.5


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
5 changes: 5 additions & 0 deletions torchao/prototype/mx_formats/config.py
@@ -58,6 +58,11 @@ class MXLinearConfig:
# on the given hardware an exception will be thrown
gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED

# If True, uses a custom triton kernel for cast to mxfp8 across dim1
# TODO(1945): remove this config option once torch.compile gives us
# a fast kernel
use_fp8_dim1_cast_triton_kernel: bool = False

# If True, uses a custom triton kernel for fp4 dequantize
use_fp4_custom_triton_dequant_kernel: bool = False

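
For context, here is a minimal usage sketch of the new flag. It mirrors the constraints exercised in `test_linear_compile` above (an mxfp8 recipe, bfloat16 inputs, CUDA capability >= 8.9); the import paths are assumed from the files touched in this PR.

```python
# Minimal sketch, assuming the import paths below and an SM 8.9+ GPU.
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

m = nn.Sequential(
    nn.Linear(256, 512, bias=False, device="cuda", dtype=torch.bfloat16),
)

# "mxfp8_cublas" is one of the recipe names exercised in test_linear_compile.
config = MXLinearConfig.from_recipe_name("mxfp8_cublas")
config.use_fp8_dim1_cast_triton_kernel = True  # opt in to the dim1 triton cast

swap_linear_with_mx_linear(m, config=config)

x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = m(x)
y.backward(torch.randn_like(y))  # backward is where the dim1 casts happen
```

The flag defaults to False, so existing users are unaffected until they opt in.
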
6 changes: 4 additions & 2 deletions torchao/prototype/mx_formats/custom_cast.py
@@ -1087,6 +1087,7 @@ def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor:
if TORCH_VERSION_AT_LEAST_2_8 and has_triton():
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton

@triton.jit
def _triton_calculate_scale(x, axis):
@@ -1298,8 +1299,9 @@ def to_mxfp8_dim1_kernel(
# TODO(future): mask this store
tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)

@triton_op("torchao::triton_to_mxfp8_dim1", mutates_args={})
def triton_to_mxfp8_dim1(
x, inner_block_size=32
x: torch.Tensor, inner_block_size: int = 32
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Input:
@@ -1343,7 +1345,7 @@ def triton_to_mxfp8_dim1(
)

# Launch the kernel
to_mxfp8_dim1_kernel[grid](
wrap_triton(to_mxfp8_dim1_kernel)[grid](
x_ptr=x,
output_col_major_ptr=output_col_major,
col_scale_ptr=col_scale,
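
For reference, a sketch of calling the new op directly. The meaning of the two return values is inferred from how `mx_linear.py` consumes them in this PR, and the op is only defined when PyTorch is at least 2.8 and triton is available, per the guard above.

```python
# Sketch under the assumptions stated above; requires a float8-capable GPU.
import torch

from torchao.prototype.mx_formats.custom_cast import triton_to_mxfp8_dim1

x = torch.randn(256, 512, device="cuda", dtype=torch.bfloat16)

# Cast across dim1 in blocks of 32 elements. Based on the usage in
# mx_linear.py, `data` is laid out so that data.t() gives the dim1-cast
# mxfp8 values, and `scale` holds the per-block e8m0 scales (viewed as
# uint8 before being handed to MXTensor).
data, scale = triton_to_mxfp8_dim1(x, inner_block_size=32)

# mx_linear.py then wraps the outputs roughly as:
#   MXTensor(scale.view(torch.uint8).reshape(-1), data.t(), ...)
```

Registering the kernel launch through `triton_op` and `wrap_triton` is what allows the op to compose with `torch.compile`, which `test_linear_compile` exercises with the flag enabled.
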
96 changes: 75 additions & 21 deletions torchao/prototype/mx_formats/mx_linear.py
@@ -14,6 +14,7 @@
import torch.nn.functional as F

from torchao.prototype.mx_formats.config import MXGemmKernelChoice, MXLinearConfig
from torchao.prototype.mx_formats.custom_cast import triton_to_mxfp8_dim1
from torchao.prototype.mx_formats.mx_tensor import MXTensor


@@ -37,13 +38,15 @@ def forward(
grad_elem_dtype: Any,
block_size: int,
gemm_kernel_choice: MXGemmKernelChoice,
use_fp8_dim1_cast_triton_kernel: bool,
):
ctx.save_for_backward(input_hp, weight_hp)
ctx.in_elem_dtype = in_elem_dtype
ctx.w_elem_dtype = w_elem_dtype
ctx.grad_elem_dtype = grad_elem_dtype
ctx.block_size = block_size
ctx.gemm_kernel_choice = gemm_kernel_choice
ctx.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel

# input @ weight_t = output
input_orig_shape = input_hp.shape
@@ -63,12 +66,12 @@
@staticmethod
def backward(ctx, grad_output_hp: torch.Tensor):
input_hp, weight_hp = ctx.saved_tensors
weight_hp_t_c = weight_hp.t().contiguous()
in_elem_dtype = ctx.in_elem_dtype
w_elem_dtype = ctx.w_elem_dtype
grad_elem_dtype = ctx.grad_elem_dtype
block_size = ctx.block_size
gemm_kernel_choice = ctx.gemm_kernel_choice
use_fp8_dim1_cast_triton_kernel = ctx.use_fp8_dim1_cast_triton_kernel

grad_output_orig_shape = grad_output_hp.shape
grad_output_hp_r = grad_output_hp.reshape(-1, grad_output_orig_shape[-1])
@@ -83,34 +86,84 @@ def backward(ctx, grad_output_hp: torch.Tensor):
block_size,
gemm_kernel_choice=gemm_kernel_choice,
)
weight_mx_dim1 = MXTensor.to_mx(
weight_hp_t_c,
w_elem_dtype,
block_size,
gemm_kernel_choice=gemm_kernel_choice,
)

if use_fp8_dim1_cast_triton_kernel:
weight_mx_dim1_data, weight_mx_dim1_scale = triton_to_mxfp8_dim1(
weight_hp, block_size
)
weight_mx_dim1 = MXTensor(
weight_mx_dim1_scale.view(torch.uint8).reshape(-1),
weight_mx_dim1_data.t(),
w_elem_dtype,
block_size,
weight_hp.dtype,
False,
gemm_kernel_choice,
False,
)

else:
weight_hp_t_c = weight_hp.t().contiguous()
weight_mx_dim1 = MXTensor.to_mx(
weight_hp_t_c,
w_elem_dtype,
block_size,
gemm_kernel_choice=gemm_kernel_choice,
)
grad_input = torch.mm(grad_output_mx_dim0, weight_mx_dim1.t())
grad_input = grad_input.reshape(
*grad_output_orig_shape[:-1], grad_input.shape[-1]
)

# input_t @ grad_output = grad_weight
grad_output_mx_dim1 = MXTensor.to_mx(
grad_output_hp_r.t().contiguous(),
grad_elem_dtype,
block_size,
gemm_kernel_choice=gemm_kernel_choice,
)
input_t_mx_dim0_tmp = MXTensor.to_mx(
input_hp_r.t().contiguous(),
in_elem_dtype,
block_size,
gemm_kernel_choice=gemm_kernel_choice,
)
input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
if use_fp8_dim1_cast_triton_kernel:
grad_output_mx_dim1_data, grad_output_mx_dim1_scale = triton_to_mxfp8_dim1(
grad_output_hp_r, block_size
)
grad_output_mx_dim1 = MXTensor(
grad_output_mx_dim1_scale.view(torch.uint8).reshape(-1),
grad_output_mx_dim1_data.t(),
grad_elem_dtype,
block_size,
grad_output_hp_r.dtype,
False,
gemm_kernel_choice,
False,
)
else:
grad_output_mx_dim1 = MXTensor.to_mx(
grad_output_hp_r.t().contiguous(),
grad_elem_dtype,
block_size,
gemm_kernel_choice=gemm_kernel_choice,
)

if use_fp8_dim1_cast_triton_kernel:
input_t_mx_dim0_tmp_data, input_t_mx_dim0_tmp_scale = triton_to_mxfp8_dim1(
input_hp_r, block_size
)
input_t_mx_dim0_tmp = MXTensor(
input_t_mx_dim0_tmp_scale.view(torch.uint8).reshape(-1),
input_t_mx_dim0_tmp_data.t(),
in_elem_dtype,
block_size,
input_hp_r.dtype,
False,
gemm_kernel_choice,
False,
)
input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
else:
input_t_mx_dim0_tmp = MXTensor.to_mx(
input_hp_r.t().contiguous(),
in_elem_dtype,
block_size,
gemm_kernel_choice=gemm_kernel_choice,
)
input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
grad_weight = torch.mm(grad_output_mx_dim1, input_t_mx_dim0)

return grad_input, grad_weight, None, None, None, None, None
return grad_input, grad_weight, None, None, None, None, None, None


class MXLinear(torch.nn.Linear):
@@ -154,6 +207,7 @@ def forward(self, x):
config.elem_dtype_grad_output_override or config.elem_dtype,
config.block_size,
config.gemm_kernel_choice,
config.use_fp8_dim1_cast_triton_kernel,
)
if self.bias is not None:
y = y + self.bias