
Commit fc42ecb

integrate mx dim1 triton kernel into MXLinear (#1943)
* Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned]
1 parent d32afef commit fc42ecb

File tree

4 files changed: +152 -43 lines changed


test/prototype/mx_formats/test_mx_linear.py

+68 -20

@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 
 import copy
-import itertools
 
 import pytest
 import torch
@@ -16,7 +15,12 @@
     MXLinearConfig,
     MXLinearRecipeName,
 )
-from torchao.prototype.mx_formats.constants import DTYPE_FP4, SUPPORTED_ELEM_DTYPES
+from torchao.prototype.mx_formats.constants import (
+    DTYPE_FP4,
+    DTYPE_FP6_E2M3,
+    DTYPE_FP6_E3M2,
+    SUPPORTED_ELEM_DTYPES,
+)
 from torchao.prototype.mx_formats.mx_linear import (
     MXInferenceLinear,
     MXLinear,
@@ -48,38 +52,65 @@ def run_around_tests():
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize(
-    "elem_dtype", itertools.product(SUPPORTED_ELEM_DTYPES, repeat=3)
+    "elem_dtype",
+    (
+        # test each dtype
+        (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e4m3fn),
+        (DTYPE_FP6_E3M2, DTYPE_FP6_E3M2, DTYPE_FP6_E3M2),
+        (DTYPE_FP6_E2M3, DTYPE_FP6_E2M3, DTYPE_FP6_E2M3),
+        (DTYPE_FP4, DTYPE_FP4, DTYPE_FP4),
+        # only test one type of mixed-dtype overrides, to save testing time
+        (torch.float8_e4m3fn, DTYPE_FP4, DTYPE_FP4),
+    ),
 )
 @pytest.mark.parametrize("bias", [True, False])
-@pytest.mark.parametrize("input_shape", [(4, 8), (1, 4, 8), (1, 1, 4, 8)])
-def test_linear_eager(elem_dtype, bias, input_shape):
+@pytest.mark.parametrize("input_shape", [(128, 256), (1, 128, 256), (1, 1, 128, 256)])
+@pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True])
+def test_linear_eager_vs_hp(
+    elem_dtype, bias, input_shape, use_fp8_dim1_cast_triton_kernel
+):
     """
     Smoke test for training linear module with mx weight, compares the following:
     * baseline: float32
     * experiment: emulated MX
     """
+    if use_fp8_dim1_cast_triton_kernel:
+        if elem_dtype != (
+            torch.float8_e4m3fn,
+            torch.float8_e4m3fn,
+            torch.float8_e4m3fn,
+        ):
+            pytest.skip("unsupported configuration")
+        elif not is_sm_at_least_89():
+            pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
+
     # elem_dtype is a tuple of (input, weight, gradient) dtypes.
     grad_shape = list(input_shape)
-    grad_shape[-1] = 8
+    grad_shape[-1] = 256
 
     m = nn.Sequential(
-        nn.Linear(8, 8, bias=bias, device="cuda"),
+        nn.Linear(256, 256, bias=bias, device="cuda", dtype=torch.bfloat16),
     )
     m_mx = copy.deepcopy(m)
     config = MXLinearConfig(
         block_size=4,
         elem_dtype=elem_dtype[0],
         elem_dtype_weight_override=elem_dtype[1],
         elem_dtype_grad_output_override=elem_dtype[2],
+        use_fp8_dim1_cast_triton_kernel=use_fp8_dim1_cast_triton_kernel,
     )
     swap_linear_with_mx_linear(m_mx, config=config)
 
-    x_ref = torch.randn(*input_shape, device="cuda").requires_grad_()
+    x_ref = torch.randn(
+        *input_shape, device="cuda", dtype=torch.bfloat16
+    ).requires_grad_()
     x = copy.deepcopy(x_ref)
     g = torch.randn(*grad_shape, device="cuda")
-    with torch.autocast("cuda", dtype=torch.bfloat16):
-        y_ref = m(x_ref)
-        y_mx = m_mx(x)
+
+    y_ref = m(x_ref)
+    y_mx = m_mx(x)
+
+    assert y_mx.dtype == x.dtype
 
     y_ref.backward(g)
     y_mx.backward(g)
@@ -112,7 +143,6 @@ def test_linear_eager(elem_dtype, bias, input_shape):
 )
 @pytest.mark.parametrize("mkn", [(128, 256, 512), (256, 512, 128), (512, 128, 256)])
 def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn):
-    M, K, N = 128, 128, 128
     M, K, N = mkn
 
     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda").requires_grad_()
@@ -143,9 +173,9 @@ def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn):
     y_sqnr = compute_error(y_real, y_emulated)
     w_sqnr = compute_error(m_real[0].weight.grad, m_emulated[0].weight.grad)
     g_sqnr = compute_error(x_copy.grad, x.grad)
-    assert y_sqnr > 100.0, f"y_sqnr {y_sqnr} too low!"
-    assert w_sqnr > 100.0, f"w_sqnr {w_sqnr} too low!"
-    assert g_sqnr > 100.0, f"g_sqnr {g_sqnr} too low!"
+    assert y_sqnr > 90.0, f"y_sqnr {y_sqnr} too low!"
+    assert w_sqnr > 90.0, f"w_sqnr {w_sqnr} too low!"
+    assert g_sqnr > 90.0, f"g_sqnr {g_sqnr} too low!"
 
 
 # TODO(future): enable compile support
@@ -169,6 +199,7 @@ def test_activation_checkpointing():
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
 @pytest.mark.parametrize(
     "recipe_name",
     [
@@ -182,7 +213,8 @@ def test_activation_checkpointing():
 @pytest.mark.parametrize("bias", [False, True])
 # TODO(future PR): figure out why torch.compile does not match eager when
 # autocast is on
-def test_linear_compile(recipe_name, bias):
+@pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True])
+def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_kernel):
     """
     Verify that compile does not change numerics of MX linear fw + bw
     """
@@ -198,20 +230,36 @@ def test_linear_compile(recipe_name, bias):
         # TODO(future PR): fix this, things are clearly broken with bias=True
         pytest.skip("this test is broken for non-emulated recipes with bias=True")
 
+    if use_fp8_dim1_cast_triton_kernel:
+        if recipe_name not in ("mxfp8_emulated", "mxfp8_cublas", "mxfp8_cutlass"):
+            pytest.skip("unsupported configuration")
+        if not is_sm_at_least_89():
+            pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
+        if hp_dtype != torch.bfloat16:
+            pytest.skip("unsupported configuration")
+
+    if hp_dtype == torch.bfloat16 and recipe_name != "mxfp8_cublas":
+        # TODO(future PR): properly enable float32 + bfloat16 for every
+        # recipe, this needs a cleanup of out_dtype (needs to match in-hp-dtype, even
+        # if the underlying gemm kernel only supports bf16 output)
+        pytest.skip("unsupported configuration")
+
     M, K, N = 128, 256, 512
     input_shape = (M, K)
     grad_shape = (M, N)
     m_mx = nn.Sequential(
-        nn.Linear(K, N, bias=bias, device="cuda"),
+        nn.Linear(K, N, bias=bias, device="cuda", dtype=hp_dtype),
    )
     config = MXLinearConfig.from_recipe_name(recipe_name)
+    config.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel
+
     swap_linear_with_mx_linear(m_mx, config=config)
     m_mx_c = copy.deepcopy(m_mx)
     m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")
 
-    x_ref = torch.randn(*input_shape, device="cuda").requires_grad_()
+    x_ref = torch.randn(*input_shape, device="cuda", dtype=hp_dtype).requires_grad_()
     x = copy.deepcopy(x_ref)
-    g = torch.randn(*grad_shape, device="cuda")
+    g = torch.randn(*grad_shape, device="cuda", dtype=hp_dtype)
 
     y_ref = m_mx(x_ref)
     y = m_mx_c(x)
@@ -283,7 +331,7 @@ def test_inference_compile_simple(elem_dtype):
     if elem_dtype is torch.float8_e4m3fn:
         assert sqnr >= 20.0
     else:
-        assert sqnr >= 13.5
+        assert sqnr >= 11.5
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")

torchao/prototype/mx_formats/config.py

+5

@@ -58,6 +58,11 @@ class MXLinearConfig:
     # on the given hardware an exception will be thrown
     gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED
 
+    # If True, uses a custom triton kernel for cast to mxfp8 across dim1
+    # TODO(1945): remove this config option once torch.compile gives us
+    # a fast kernel
+    use_fp8_dim1_cast_triton_kernel: bool = False
+
     # If True, uses a custom triton kernel for fp4 dequantize
     use_fp4_custom_triton_dequant_kernel: bool = False
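To put the new flag in context: it is opt-in and is set on MXLinearConfig before swapping modules. The snippet below is a minimal usage sketch, not code from this commit; it assumes a CUDA GPU with compute capability >= 8.9, assumes swap_linear_with_mx_linear lives in torchao.prototype.mx_formats.mx_linear as in the test file, and the model, tensor names, and explicit block_size/elem_dtype values are illustrative choices.

import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

# hypothetical model; any module tree containing nn.Linear can be swapped in place
m = nn.Sequential(nn.Linear(256, 256, bias=False, device="cuda", dtype=torch.bfloat16))

config = MXLinearConfig(
    elem_dtype=torch.float8_e4m3fn,  # the triton dim1 cast path is mxfp8-only
    block_size=32,
    use_fp8_dim1_cast_triton_kernel=True,  # flag added in this commit
)
swap_linear_with_mx_linear(m, config=config)

x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = m(x)
y.sum().backward()  # the backward pass is where the dim1 casts are performed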

torchao/prototype/mx_formats/custom_cast.py

+4 -2

@@ -1087,6 +1087,7 @@ def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor:
 if TORCH_VERSION_AT_LEAST_2_8 and has_triton():
     import triton
     import triton.language as tl
+    from torch.library import triton_op, wrap_triton
 
     @triton.jit
     def _triton_calculate_scale(x, axis):
@@ -1298,8 +1299,9 @@ def to_mxfp8_dim1_kernel(
         # TODO(future): mask this store
         tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)
 
+    @triton_op("torchao::triton_to_mxfp8_dim1", mutates_args={})
     def triton_to_mxfp8_dim1(
-        x, inner_block_size=32
+        x: torch.Tensor, inner_block_size: int = 32
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Input:
@@ -1343,7 +1345,7 @@ def triton_to_mxfp8_dim1(
         )
 
         # Launch the kernel
-        to_mxfp8_dim1_kernel[grid](
+        wrap_triton(to_mxfp8_dim1_kernel)[grid](
            x_ptr=x,
            output_col_major_ptr=output_col_major,
            col_scale_ptr=col_scale,
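Wrapping the launch with torch.library.triton_op / wrap_triton registers the cast as a custom op (torchao::triton_to_mxfp8_dim1), which is what lets torch.compile trace through it. A rough standalone sketch of calling it follows; it assumes torch >= 2.8 and a float8-capable GPU, the variable names are illustrative, and the exact shapes and layouts of the outputs are documented in the function body, which this diff only shows in part.

import torch

from torchao.prototype.mx_formats.custom_cast import triton_to_mxfp8_dim1

x = torch.randn(256, 256, device="cuda", dtype=torch.bfloat16)

# cast across dim1: returns the mxfp8 data and the e8m0 block scales
data, scale = triton_to_mxfp8_dim1(x, inner_block_size=32)

# because the function is now a registered custom op, the same kernel should
# also be reachable through the torch.ops namespace (and inside compiled code)
data2, scale2 = torch.ops.torchao.triton_to_mxfp8_dim1(x, 32)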

torchao/prototype/mx_formats/mx_linear.py

+75 -21

@@ -14,6 +14,7 @@
 import torch.nn.functional as F
 
 from torchao.prototype.mx_formats.config import MXGemmKernelChoice, MXLinearConfig
+from torchao.prototype.mx_formats.custom_cast import triton_to_mxfp8_dim1
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 
 
@@ -37,13 +38,15 @@ def forward(
         grad_elem_dtype: Any,
         block_size: int,
         gemm_kernel_choice: MXGemmKernelChoice,
+        use_fp8_dim1_cast_triton_kernel: bool,
     ):
         ctx.save_for_backward(input_hp, weight_hp)
         ctx.in_elem_dtype = in_elem_dtype
         ctx.w_elem_dtype = w_elem_dtype
         ctx.grad_elem_dtype = grad_elem_dtype
         ctx.block_size = block_size
         ctx.gemm_kernel_choice = gemm_kernel_choice
+        ctx.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel
 
         # input @ weight_t = output
         input_orig_shape = input_hp.shape
@@ -63,12 +66,12 @@
     @staticmethod
     def backward(ctx, grad_output_hp: torch.Tensor):
         input_hp, weight_hp = ctx.saved_tensors
-        weight_hp_t_c = weight_hp.t().contiguous()
         in_elem_dtype = ctx.in_elem_dtype
         w_elem_dtype = ctx.w_elem_dtype
         grad_elem_dtype = ctx.grad_elem_dtype
         block_size = ctx.block_size
         gemm_kernel_choice = ctx.gemm_kernel_choice
+        use_fp8_dim1_cast_triton_kernel = ctx.use_fp8_dim1_cast_triton_kernel
 
         grad_output_orig_shape = grad_output_hp.shape
         grad_output_hp_r = grad_output_hp.reshape(-1, grad_output_orig_shape[-1])
@@ -83,34 +86,84 @@ def backward(ctx, grad_output_hp: torch.Tensor):
             block_size,
             gemm_kernel_choice=gemm_kernel_choice,
         )
-        weight_mx_dim1 = MXTensor.to_mx(
-            weight_hp_t_c,
-            w_elem_dtype,
-            block_size,
-            gemm_kernel_choice=gemm_kernel_choice,
-        )
+
+        if use_fp8_dim1_cast_triton_kernel:
+            weight_mx_dim1_data, weight_mx_dim1_scale = triton_to_mxfp8_dim1(
+                weight_hp, block_size
+            )
+            weight_mx_dim1 = MXTensor(
+                weight_mx_dim1_scale.view(torch.uint8).reshape(-1),
+                weight_mx_dim1_data.t(),
+                w_elem_dtype,
+                block_size,
+                weight_hp.dtype,
+                False,
+                gemm_kernel_choice,
+                False,
+            )
+
+        else:
+            weight_hp_t_c = weight_hp.t().contiguous()
+            weight_mx_dim1 = MXTensor.to_mx(
+                weight_hp_t_c,
+                w_elem_dtype,
+                block_size,
+                gemm_kernel_choice=gemm_kernel_choice,
+            )
         grad_input = torch.mm(grad_output_mx_dim0, weight_mx_dim1.t())
         grad_input = grad_input.reshape(
             *grad_output_orig_shape[:-1], grad_input.shape[-1]
         )
 
         # input_t @ grad_output = grad_weight
-        grad_output_mx_dim1 = MXTensor.to_mx(
-            grad_output_hp_r.t().contiguous(),
-            grad_elem_dtype,
-            block_size,
-            gemm_kernel_choice=gemm_kernel_choice,
-        )
-        input_t_mx_dim0_tmp = MXTensor.to_mx(
-            input_hp_r.t().contiguous(),
-            in_elem_dtype,
-            block_size,
-            gemm_kernel_choice=gemm_kernel_choice,
-        )
-        input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
+        if use_fp8_dim1_cast_triton_kernel:
+            grad_output_mx_dim1_data, grad_output_mx_dim1_scale = triton_to_mxfp8_dim1(
+                grad_output_hp_r, block_size
+            )
+            grad_output_mx_dim1 = MXTensor(
+                grad_output_mx_dim1_scale.view(torch.uint8).reshape(-1),
+                grad_output_mx_dim1_data.t(),
+                grad_elem_dtype,
+                block_size,
+                grad_output_hp_r.dtype,
+                False,
+                gemm_kernel_choice,
+                False,
+            )
+        else:
+            grad_output_mx_dim1 = MXTensor.to_mx(
+                grad_output_hp_r.t().contiguous(),
+                grad_elem_dtype,
+                block_size,
+                gemm_kernel_choice=gemm_kernel_choice,
+            )
+
+        if use_fp8_dim1_cast_triton_kernel:
+            input_t_mx_dim0_tmp_data, input_t_mx_dim0_tmp_scale = triton_to_mxfp8_dim1(
+                input_hp_r, block_size
+            )
+            input_t_mx_dim0_tmp = MXTensor(
+                input_t_mx_dim0_tmp_scale.view(torch.uint8).reshape(-1),
+                input_t_mx_dim0_tmp_data.t(),
+                in_elem_dtype,
+                block_size,
+                input_hp_r.dtype,
+                False,
+                gemm_kernel_choice,
+                False,
+            )
+            input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
+        else:
+            input_t_mx_dim0_tmp = MXTensor.to_mx(
+                input_hp_r.t().contiguous(),
+                in_elem_dtype,
+                block_size,
+                gemm_kernel_choice=gemm_kernel_choice,
+            )
+            input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
         grad_weight = torch.mm(grad_output_mx_dim1, input_t_mx_dim0)
 
-        return grad_input, grad_weight, None, None, None, None, None
+        return grad_input, grad_weight, None, None, None, None, None, None
 
 
 class MXLinear(torch.nn.Linear):
@@ -154,6 +207,7 @@ def forward(self, x):
             config.elem_dtype_grad_output_override or config.elem_dtype,
             config.block_size,
             config.gemm_kernel_choice,
+            config.use_fp8_dim1_cast_triton_kernel,
         )
         if self.bias is not None:
             y = y + self.bias
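Because the autograd function only branches on a boolean carried through the config, and the kernel launch is now a registered custom op, the dim1-cast path composes with torch.compile; that is what the updated test_linear_compile exercises. The sketch below is illustrative rather than taken from this commit: it assumes hardware that supports the chosen recipe (the updated test runs this combination with hp_dtype=torch.bfloat16 and the mxfp8_cublas recipe), and the model and tensor names are hypothetical.

import copy

import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

m_mx = nn.Sequential(nn.Linear(256, 512, bias=False, device="cuda", dtype=torch.bfloat16))
config = MXLinearConfig.from_recipe_name("mxfp8_cublas")
config.use_fp8_dim1_cast_triton_kernel = True
swap_linear_with_mx_linear(m_mx, config=config)

# compile with fullgraph=True, as in the test, so any graph break would surface
m_mx_c = torch.compile(copy.deepcopy(m_mx), fullgraph=True, backend="inductor")

x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
g = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)

y = m_mx_c(x)
y.backward(g)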
