
Commit 1d75c8f

Authored Feb 6, 2025
Support mixed MX element dtype in mx_mm function and MXLinear. (#1667)
* Support mixed MX element dtype in the `mx_mm` function. Following the MXFP and quantization literature, it is useful to support different element dtypes for activations, weights and gradients.
* Support an (input, weight, gradient) element dtype tuple in the `MXLinear` layer factory method. Passing a tuple of 3 element dtypes avoids introducing a breaking change in the current interface of `MXLinear` and `swap_linear_with_mx_linear`. Additional unit test coverage has been added for `MXLinear`.
* Use the default `elem_dtype` argument with optional weight/grad overrides.
1 parent 867a91f commit 1d75c8f
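For reference, a minimal usage sketch of the updated `swap_linear_with_mx_linear` interface introduced by this commit (keyword-only `block_size` and `filter_fn`, optional weight/grad_output dtype overrides). The repeated `torch.float8_e4m3fn` is only a placeholder; any element dtype in `SUPPORTED_ELEM_DTYPES` should be accepted, and a CUDA device is assumed as in the README example.

```python
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

# Single element dtype for input, weight and grad_output (unchanged call pattern).
m = nn.Sequential(nn.Linear(32, 32)).cuda()
swap_linear_with_mx_linear(m, torch.float8_e4m3fn, block_size=32)

# Per-tensor overrides for weight and grad_output. The fp8 dtype is repeated
# here only as a placeholder; any entry of SUPPORTED_ELEM_DTYPES should work.
m2 = nn.Sequential(nn.Linear(32, 32)).cuda()
swap_linear_with_mx_linear(
    m2,
    torch.float8_e4m3fn,  # input / activation element dtype
    elem_dtype_weight_override=torch.float8_e4m3fn,
    elem_dtype_grad_output_override=torch.float8_e4m3fn,
    block_size=32,
)
```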

File tree: 3 files changed (+88, -26 lines)

test/prototype/mx_formats/test_mx_linear.py (+26, -6)
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import copy
+import itertools
 
 import pytest
 import torch
@@ -41,13 +42,16 @@ def run_around_tests():
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
+@pytest.mark.parametrize(
+    "elem_dtype", itertools.product(SUPPORTED_ELEM_DTYPES, repeat=3)
+)
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("input_shape", [(4, 8), (1, 4, 8), (1, 1, 4, 8)])
 def test_linear_eager(elem_dtype, bias, input_shape):
     """
     Smoke test for training linear module with mx weight
     """
+    # elem_dtype is a tuple of (input, weight, gradient) dtypes.
     grad_shape = list(input_shape)
     grad_shape[-1] = 6
 
@@ -56,7 +60,7 @@ def test_linear_eager(elem_dtype, bias, input_shape):
     )
     m_mx = copy.deepcopy(m)
     block_size = 2
-    swap_linear_with_mx_linear(m_mx, elem_dtype, block_size)
+    swap_linear_with_mx_linear(m_mx, *elem_dtype, block_size=block_size)
 
     x_ref = torch.randn(*input_shape, device="cuda").requires_grad_()
     x = copy.deepcopy(x_ref)
@@ -72,7 +76,7 @@ def test_linear_eager(elem_dtype, bias, input_shape):
     w_g_sqnr = compute_error(m[0].weight.grad, getattr(m_mx, "0").weight.grad)
     x_g_sqnr = compute_error(x_ref.grad, x.grad)
 
-    if elem_dtype is torch.float8_e4m3fn:
+    if elem_dtype == (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e4m3fn):
         assert y_sqnr >= 18.0
         assert w_g_sqnr >= 18.0
         assert x_g_sqnr >= 12.0
@@ -94,7 +98,7 @@ def test_activation_checkpointing():
         nn.Linear(6, 6, bias=True, device="cuda"),
     )
     block_size = 2
-    swap_linear_with_mx_linear(m, elem_dtype, block_size)
+    swap_linear_with_mx_linear(m, elem_dtype, block_size=block_size)
 
     x = torch.randn(*input_shape, device="cuda").requires_grad_()
     g = torch.randn(*grad_shape, device="cuda")
@@ -130,7 +134,7 @@ def test_linear_compile(elem_dtype, bias, use_autocast):
         nn.Linear(K, N, bias=bias, device="cuda"),
     )
     block_size = 2
-    swap_linear_with_mx_linear(m_mx, elem_dtype, block_size)
+    swap_linear_with_mx_linear(m_mx, elem_dtype, block_size=block_size)
     m_mx_c = copy.deepcopy(m_mx)
     m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")
 
@@ -219,6 +223,20 @@ def test_inference_compile_simple(elem_dtype):
     assert sqnr >= 13.5
 
 
+def test_mx_linear_input_weight_gradient_dtypes():
+    m = nn.Sequential(nn.Linear(32, 32))
+    swap_linear_with_mx_linear(m, *SUPPORTED_ELEM_DTYPES[:3], block_size=32)
+    assert m[0].in_elem_dtype == SUPPORTED_ELEM_DTYPES[0]
+    assert m[0].w_elem_dtype == SUPPORTED_ELEM_DTYPES[1]
+    assert m[0].grad_elem_dtype == SUPPORTED_ELEM_DTYPES[2]
+
+    m = nn.Sequential(nn.Linear(32, 32))
+    swap_linear_with_mx_linear(m, torch.float8_e4m3fn, block_size=32)
+    assert m[0].in_elem_dtype == torch.float8_e4m3fn
+    assert m[0].w_elem_dtype == torch.float8_e4m3fn
+    assert m[0].grad_elem_dtype == torch.float8_e4m3fn
+
+
 def test_filter_fn():
     m1 = nn.Sequential(
         nn.Linear(32, 32),
@@ -227,7 +245,9 @@ def test_filter_fn():
     m2 = copy.deepcopy(m1)
     filter_fn = lambda mod, fqn: fqn != "1"  # noqa: E731
 
-    swap_linear_with_mx_linear(m1, torch.float8_e4m3fn, 32, filter_fn)
+    swap_linear_with_mx_linear(
+        m1, torch.float8_e4m3fn, block_size=32, filter_fn=filter_fn
+    )
     assert type(m1[0]) == MXLinear
     assert type(m1[1]) == torch.nn.Linear

torchao/prototype/mx_formats/README.md (+4, -5)
@@ -2,8 +2,8 @@
 
 This is a POC of training and inference with tensors in the MX format from the OCP spec (https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) in native PyTorch.
 
-Note that the current version of the code is written for readability and 
-numerical correctness and not yet for optimal performance. We welcome 
+Note that the current version of the code is written for readability and
+numerical correctness and not yet for optimal performance. We welcome
 contributions on performance improvements.
 
 Note that there are no BC guarantees at the moment and we plan to evolve
@@ -44,8 +44,7 @@ from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
 
 m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
 elem_dtype = torch.float8_e4m3fn
-block_size = 32
-swap_linear_with_mx_linear(m, elem_dtype, block_size)
+swap_linear_with_mx_linear(m, elem_dtype, block_size=32)
 
 # training loop (not shown)
 ```
@@ -93,7 +92,7 @@ python torchao/prototype/mx_formats/benchmarks/bench_qdq.py
 
 ## floating point format convenience functions
 
-We have a convenience script which summarizes the various properties of 
+We have a convenience script which summarizes the various properties of
 floating point formats:
 
 ```bash

torchao/prototype/mx_formats/mx_linear.py (+58, -15)
@@ -23,25 +23,31 @@ class mx_mm(torch.autograd.Function):
     # 1. input @ weight_t = output (forward pass)
     # 2. grad_output @ weight = grad_input (backward pass)
     # 3. input_t @ grad_output = grad_weight (backward pass)
+    #
+    # input, weight and grad_output can each have their own MX element dtype.
 
     @staticmethod
     def forward(
         ctx,
         input_hp: torch.Tensor,
         weight_hp: torch.Tensor,
-        elem_dtype: Any,
+        in_elem_dtype: Any,
+        w_elem_dtype: Any,
+        grad_elem_dtype: Any,
         block_size: int,
     ):
         ctx.save_for_backward(input_hp, weight_hp)
-        ctx.elem_dtype = elem_dtype
+        ctx.in_elem_dtype = in_elem_dtype
+        ctx.w_elem_dtype = w_elem_dtype
+        ctx.grad_elem_dtype = grad_elem_dtype
         ctx.block_size = block_size
 
         # input @ weight_t = output
         input_orig_shape = input_hp.shape
         input_hp_r = input_hp.reshape(-1, input_orig_shape[-1])
 
-        input_mx_r_dim0 = MXTensor.to_mx(input_hp_r, elem_dtype, block_size)
-        weight_mx_dim0 = MXTensor.to_mx(weight_hp, elem_dtype, block_size)
+        input_mx_r_dim0 = MXTensor.to_mx(input_hp_r, in_elem_dtype, block_size)
+        weight_mx_dim0 = MXTensor.to_mx(weight_hp, w_elem_dtype, block_size)
         output = torch.mm(input_mx_r_dim0, weight_mx_dim0.t())
         output = output.reshape(*input_orig_shape[:-1], output.shape[-1])
 
@@ -51,7 +57,9 @@ def forward(
     def backward(ctx, grad_output_hp: torch.Tensor):
         input_hp, weight_hp = ctx.saved_tensors
         weight_hp_t_c = weight_hp.t().contiguous()
-        elem_dtype = ctx.elem_dtype
+        in_elem_dtype = ctx.in_elem_dtype
+        w_elem_dtype = ctx.w_elem_dtype
+        grad_elem_dtype = ctx.grad_elem_dtype
         block_size = ctx.block_size
 
         grad_output_orig_shape = grad_output_hp.shape
@@ -61,24 +69,26 @@ def backward(ctx, grad_output_hp: torch.Tensor):
         input_hp_r = input_hp.reshape(-1, input_hp_orig_shape[-1])
 
         # grad_output @ weight = grad_input
-        grad_output_mx_dim0 = MXTensor.to_mx(grad_output_hp_r, elem_dtype, block_size)
-        weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, elem_dtype, block_size)
+        grad_output_mx_dim0 = MXTensor.to_mx(
+            grad_output_hp_r, grad_elem_dtype, block_size
+        )
+        weight_mx_dim1 = MXTensor.to_mx(weight_hp_t_c, w_elem_dtype, block_size)
         grad_input = torch.mm(grad_output_mx_dim0, weight_mx_dim1.t())
         grad_input = grad_input.reshape(
             *grad_output_orig_shape[:-1], grad_input.shape[-1]
         )
 
         # input_t @ grad_output = grad_weight
         grad_output_mx_dim1 = MXTensor.to_mx(
-            grad_output_hp_r.t().contiguous(), elem_dtype, block_size
+            grad_output_hp_r.t().contiguous(), grad_elem_dtype, block_size
         )
         input_t_mx_dim0_tmp = MXTensor.to_mx(
-            input_hp_r.t().contiguous(), elem_dtype, block_size
+            input_hp_r.t().contiguous(), in_elem_dtype, block_size
        )
         input_t_mx_dim0 = input_t_mx_dim0_tmp.t()
         grad_weight = torch.mm(grad_output_mx_dim1, input_t_mx_dim0)
 
-        return grad_input, grad_weight, None, None
+        return grad_input, grad_weight, None, None, None, None
 
 
 class MXLinear(torch.nn.Linear):
@@ -87,13 +97,25 @@ class MXLinear(torch.nn.Linear):
     matmul is emulated since there is no hardware support yet. Activations,
     weights and grads are casted to MX and back to high precision for each
     matmul.
+
+    Input, weight and grad_output can each have their own MX element dtype.
     """
 
     @classmethod
     @torch.no_grad()
-    def from_float(cls, mod, elem_dtype, block_size):
+    def from_float(
+        cls,
+        mod,
+        elem_dtype,
+        elem_dtype_weight_override=None,
+        elem_dtype_grad_output_override=None,
+        *,
+        block_size=32,
+    ):
         mod.__class__ = MXLinear
-        mod.elem_dtype = elem_dtype
+        mod.in_elem_dtype = elem_dtype
+        mod.w_elem_dtype = elem_dtype_weight_override or elem_dtype
+        mod.grad_elem_dtype = elem_dtype_grad_output_override or elem_dtype
         mod.block_size = block_size
         return mod
 
@@ -106,7 +128,14 @@ def forward(self, x):
         else:
             w = self.weight
 
-        y = mx_mm.apply(x, w, self.elem_dtype, self.block_size)
+        y = mx_mm.apply(
+            x,
+            w,
+            self.in_elem_dtype,
+            self.w_elem_dtype,
+            self.grad_elem_dtype,
+            self.block_size,
+        )
         if self.bias is not None:
             y = y + self.bias
         return y
@@ -172,7 +201,15 @@ def _is_linear(mod, fqn):
     return isinstance(mod, torch.nn.Linear)
 
 
-def swap_linear_with_mx_linear(model, elem_dtype, block_size, filter_fn=None):
+def swap_linear_with_mx_linear(
+    model,
+    elem_dtype,
+    elem_dtype_weight_override=None,
+    elem_dtype_grad_output_override=None,
+    *,
+    block_size=32,
+    filter_fn=None,
+):
     if filter_fn is None:
         combined_filter_fn = _is_linear
     else:
@@ -183,7 +220,13 @@ def __fn(mod, fqn):
         combined_filter_fn = __fn
     replace_with_custom_fn_if_matches_filter(
         model,
-        lambda mod: MXLinear.from_float(mod, elem_dtype, block_size),
+        lambda mod: MXLinear.from_float(
+            mod,
+            elem_dtype,
+            elem_dtype_weight_override,
+            elem_dtype_grad_output_override,
+            block_size=block_size,
+        ),
         combined_filter_fn,
     )
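A minimal sketch of how the `MXLinear.from_float` overrides resolve, based on the new signature above: when no weight or grad_output override is passed, both fall back to `elem_dtype`.

```python
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.mx_linear import MXLinear

# With no overrides, the weight and grad_output element dtypes
# fall back to elem_dtype (see from_float above).
lin = MXLinear.from_float(nn.Linear(32, 32), torch.float8_e4m3fn, block_size=32)
assert lin.in_elem_dtype == torch.float8_e4m3fn
assert lin.w_elem_dtype == torch.float8_e4m3fn
assert lin.grad_elem_dtype == torch.float8_e4m3fn
```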
