
Commit 40d01cd

MX: move block_size and elem_dtype into MXLinearConfig (#1689)
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
* Update [ghstack-poisoned]
1 parent c3bb80e commit 40d01cd
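
For reference, the call pattern after this change looks roughly like the following; this is a minimal sketch based on the README update in this commit, with the model shape and CUDA device taken from that example.

```python
import torch

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()

# previously: swap_linear_with_mx_linear(m, torch.float8_e4m3fn, block_size=32)
# now, block_size and elem_dtype travel on the config object:
config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)
swap_linear_with_mx_linear(m, config=config)
```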

4 files changed: +74 -75 lines

test/prototype/mx_formats/test_mx_linear.py (+19 -28)
@@ -11,6 +11,7 @@
 import torch
 import torch.nn as nn
 
+from torchao.prototype.mx_formats.config import MXLinearConfig
 from torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES
 from torchao.prototype.mx_formats.mx_linear import (
     MXInferenceLinear,
@@ -59,8 +60,13 @@ def test_linear_eager(elem_dtype, bias, input_shape):
         nn.Linear(8, 6, bias=bias, device="cuda"),
     )
     m_mx = copy.deepcopy(m)
-    block_size = 2
-    swap_linear_with_mx_linear(m_mx, *elem_dtype, block_size=block_size)
+    config = MXLinearConfig(
+        block_size=2,
+        elem_dtype=elem_dtype[0],
+        elem_dtype_weight_override=elem_dtype[1],
+        elem_dtype_grad_output_override=elem_dtype[2],
+    )
+    swap_linear_with_mx_linear(m_mx, config=config)
 
     x_ref = torch.randn(*input_shape, device="cuda").requires_grad_()
     x = copy.deepcopy(x_ref)
@@ -97,8 +103,8 @@ def test_activation_checkpointing():
         nn.Linear(4, 6, bias=True, device="cuda"),
         nn.Linear(6, 6, bias=True, device="cuda"),
     )
-    block_size = 2
-    swap_linear_with_mx_linear(m, elem_dtype, block_size=block_size)
+    config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
+    swap_linear_with_mx_linear(m, config=config)
 
     x = torch.randn(*input_shape, device="cuda").requires_grad_()
     g = torch.randn(*grad_shape, device="cuda")
@@ -133,8 +139,8 @@ def test_linear_compile(elem_dtype, bias, use_autocast):
     m_mx = nn.Sequential(
         nn.Linear(K, N, bias=bias, device="cuda"),
     )
-    block_size = 2
-    swap_linear_with_mx_linear(m_mx, elem_dtype, block_size=block_size)
+    config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
+    swap_linear_with_mx_linear(m_mx, config=config)
     m_mx_c = copy.deepcopy(m_mx)
     m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")
 
@@ -181,8 +187,8 @@ def test_inference_linear(elem_dtype, bias, input_shape):
     m = nn.Sequential(nn.Linear(4, 6, bias=bias, dtype=torch.bfloat16))
     m = m.cuda()
     m_mx = copy.deepcopy(m)
-    block_size = 2
-    swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size)
+    config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
+    swap_linear_with_mx_inference_linear(m_mx, config=config)
 
     x = torch.randn(*input_shape, device="cuda", dtype=torch.bfloat16)
     y_ref = m(x)
@@ -209,8 +215,8 @@ def test_inference_compile_simple(elem_dtype):
     m = nn.Sequential(nn.Linear(4, 6, bias=False, dtype=torch.bfloat16))
     m = m.cuda()
     m_mx = copy.deepcopy(m)
-    block_size = 2
-    swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size)
+    config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
+    swap_linear_with_mx_inference_linear(m_mx, config=config)
     m_mx = torch.compile(m_mx, fullgraph="true")
 
     x = torch.randn(2, 4, device="cuda", dtype=torch.bfloat16)
@@ -223,20 +229,6 @@ def test_inference_compile_simple(elem_dtype):
     assert sqnr >= 13.5
 
 
-def test_mx_linear_input_weight_gradient_dtypes():
-    m = nn.Sequential(nn.Linear(32, 32))
-    swap_linear_with_mx_linear(m, *SUPPORTED_ELEM_DTYPES[:3], block_size=32)
-    assert m[0].in_elem_dtype == SUPPORTED_ELEM_DTYPES[0]
-    assert m[0].w_elem_dtype == SUPPORTED_ELEM_DTYPES[1]
-    assert m[0].grad_elem_dtype == SUPPORTED_ELEM_DTYPES[2]
-
-    m = nn.Sequential(nn.Linear(32, 32))
-    swap_linear_with_mx_linear(m, torch.float8_e4m3fn, block_size=32)
-    assert m[0].in_elem_dtype == torch.float8_e4m3fn
-    assert m[0].w_elem_dtype == torch.float8_e4m3fn
-    assert m[0].grad_elem_dtype == torch.float8_e4m3fn
-
-
 def test_filter_fn():
     m1 = nn.Sequential(
         nn.Linear(32, 32),
@@ -245,12 +237,11 @@ def test_filter_fn():
     m2 = copy.deepcopy(m1)
     filter_fn = lambda mod, fqn: fqn != "1"  # noqa: E731
 
-    swap_linear_with_mx_linear(
-        m1, torch.float8_e4m3fn, block_size=32, filter_fn=filter_fn
-    )
+    config = MXLinearConfig(block_size=32)
+    swap_linear_with_mx_linear(m1, config=config, filter_fn=filter_fn)
     assert type(m1[0]) == MXLinear
     assert type(m1[1]) == torch.nn.Linear
 
-    swap_linear_with_mx_inference_linear(m2, torch.float8_e4m3fn, 32, filter_fn)  # noqa: E501
+    swap_linear_with_mx_inference_linear(m2, config=config, filter_fn=filter_fn)  # noqa: E501
     assert type(m2[0]) == MXInferenceLinear
     assert type(m2[1]) == torch.nn.Linear
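
The per-tensor overrides exercised in test_linear_eager above can also be set directly on the config; a minimal sketch, assuming torch.float8_e5m2 is among SUPPORTED_ELEM_DTYPES and a CUDA device is available as in the tests.

```python
import torch

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

# elem_dtype applies to activations, weights and gradients unless overridden
config = MXLinearConfig(
    block_size=32,
    elem_dtype=torch.float8_e4m3fn,
    elem_dtype_weight_override=torch.float8_e4m3fn,
    elem_dtype_grad_output_override=torch.float8_e5m2,  # assumed supported
)

m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
swap_linear_with_mx_linear(m, config=config)
```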

torchao/prototype/mx_formats/README.md (+6 -5)
@@ -41,10 +41,11 @@ This is a module to do MX training, the MX matmul is currently emulated.
 
 ```python
 from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
+from torchao.prototype.mx_formats.config import MXLinearConfig
 
 m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
-elem_dtype = torch.float8_e4m3fn
-swap_linear_with_mx_linear(m, elem_dtype, block_size=32)
+config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)
+swap_linear_with_mx_linear(m, config=config)
 
 # training loop (not shown)
 ```
@@ -55,11 +56,11 @@ This is a module to do MX inference, weights are in MX and matmul is in high pre
 
 ```python
 from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_inference_linear
+from torchao.prototype.mx_formats.config import MXLinearConfig
 
 m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
-elem_dtype = torch.float8_e4m3fn
-block_size = 32
-swap_linear_with_mx_inference_linear(m, elem_dtype, block_size)
+config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)
+swap_linear_with_mx_inference_linear(m, config=config)
 
 # do inference (not shown)
 ```

torchao/prototype/mx_formats/config.py (+31 -0)
@@ -5,9 +5,40 @@
 # LICENSE file in the root directory of this source tree.
 
 from dataclasses import dataclass
+from typing import Any, Optional
+
+import torch
+
+from torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES
 
 
 @dataclass
 class MXLinearConfig:
+    # block size for scaling, default is 32 to match
+    # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf,
+    # section 5.2
+    block_size: int = 32
+
+    # element dtype, used for activations, weights and gradients
+    elem_dtype: Any = torch.float8_e4m3fn
+
+    # overrides for element dtype for weights and gradients
+    # TODO(future PR): refactor to make this cleaner
+    elem_dtype_weight_override: Optional[Any] = None
+    elem_dtype_grad_output_override: Optional[Any] = None
+
     # If True, uses a custom triton kernel for fp4 dequantize
     use_fp4_custom_triton_dequant_kernel: bool = False
+
+    def __post_init__(self):
+        assert (
+            self.elem_dtype in SUPPORTED_ELEM_DTYPES
+        ), f"elem_dtype: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}"
+        if self.elem_dtype_weight_override is not None:
+            assert (
+                self.elem_dtype_weight_override in SUPPORTED_ELEM_DTYPES
+            ), f"elem_dtype_weight_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}"
+        if self.elem_dtype_grad_output_override is not None:
+            assert (
+                self.elem_dtype_grad_output_override in SUPPORTED_ELEM_DTYPES
+            ), f"elem_dtype_grad_output_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}"

torchao/prototype/mx_formats/mx_linear.py (+18 -42)
@@ -107,22 +107,11 @@ class MXLinear(torch.nn.Linear):
     def from_float(
         cls,
         mod,
-        elem_dtype,
-        elem_dtype_weight_override=None,
-        elem_dtype_grad_output_override=None,
-        *,
-        # TODO(next PR): move elem_dtype* and block size into config
-        config: MXLinearConfig = None,
-        block_size=32,
+        config: Optional[MXLinearConfig] = MXLinearConfig(),
     ):
+        # TODO(before land): remove this
+        assert isinstance(config, MXLinearConfig)
         mod.__class__ = MXLinear
-        mod.in_elem_dtype = elem_dtype
-        mod.w_elem_dtype = elem_dtype_weight_override or elem_dtype
-        mod.grad_elem_dtype = elem_dtype_grad_output_override or elem_dtype
-        mod.block_size = block_size
-        # TODO(next PR): fix this
-        if config is None:
-            config = MXLinearConfig()
         mod.config = config
         return mod
 
@@ -135,13 +124,14 @@ def forward(self, x):
         else:
             w = self.weight
 
+        config = self.config
         y = mx_mm.apply(
             x,
             w,
-            self.in_elem_dtype,
-            self.w_elem_dtype,
-            self.grad_elem_dtype,
-            self.block_size,
+            config.elem_dtype,
+            config.elem_dtype_weight_override or config.elem_dtype,
+            config.elem_dtype_grad_output_override or config.elem_dtype,
+            config.block_size,
         )
         if self.bias is not None:
             y = y + self.bias
@@ -158,9 +148,11 @@ class MXInferenceLinear(torch.nn.Linear):
 
     @classmethod
     @torch.no_grad()
-    def from_float(cls, mod, elem_dtype, block_size, config: MXLinearConfig):
-        # TODO(next PR): move elem_dtype and block_size into config
-
+    def from_float(
+        cls,
+        mod,
+        config: Optional[MXLinearConfig] = MXLinearConfig(),
+    ):
         with torch.device("meta"):
             super_kwargs = {
                 "in_features": mod.in_features,
@@ -171,10 +163,9 @@ def from_float(cls, mod, elem_dtype, block_size, config: MXLinearConfig):
         # TODO(future PR): set to new_mod.weight directly, will need to work
         # through some errors
         new_mod.weight_mx = MXTensor.to_mx(
-            mod.weight, elem_dtype, block_size=block_size
+            mod.weight, config.elem_dtype, block_size=config.block_size
         )
         new_mod.bias = mod.bias
-        new_mod.elem_dtype = elem_dtype
         new_mod.config = config
         return new_mod
 
@@ -213,13 +204,8 @@ def _is_linear(mod, fqn):
 
 def swap_linear_with_mx_linear(
     model,
-    elem_dtype,
-    elem_dtype_weight_override=None,
-    elem_dtype_grad_output_override=None,
     *,
-    # TODO(next PR): move elem_dtype* and block_size into config
     config: Optional[MXLinearConfig] = None,
-    block_size=32,
     filter_fn=None,
 ):
     if filter_fn is None:
@@ -232,24 +218,16 @@ def __fn(mod, fqn):
         combined_filter_fn = __fn
     replace_with_custom_fn_if_matches_filter(
         model,
-        lambda mod: MXLinear.from_float(
-            mod,
-            elem_dtype,
-            elem_dtype_weight_override,
-            elem_dtype_grad_output_override,
-            config=config,
-            block_size=block_size,
-        ),
+        lambda mod: MXLinear.from_float(mod, config=config),
         combined_filter_fn,
     )
 
 
 def swap_linear_with_mx_inference_linear(
     model,
-    elem_dtype,
-    block_size,
-    filter_fn=None,
+    *,
     config: Optional[MXLinearConfig] = None,
+    filter_fn=None,
 ):
     if filter_fn is None:
         combined_filter_fn = _is_linear
@@ -261,8 +239,6 @@ def __fn(mod, fqn):
         combined_filter_fn = __fn
     replace_with_custom_fn_if_matches_filter(
         model,
-        lambda mod: MXInferenceLinear.from_float(
-            mod, elem_dtype, block_size, config=config
-        ),
+        lambda mod: MXInferenceLinear.from_float(mod, config=config),
         combined_filter_fn,
    )
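
After this change both swap helpers take config and filter_fn as keyword arguments; a minimal sketch mirroring test_filter_fn, with module shapes taken from that test.

```python
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import (
    swap_linear_with_mx_inference_linear,
    swap_linear_with_mx_linear,
)

m1 = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32))
m2 = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32))
config = MXLinearConfig(block_size=32)

# swap every nn.Linear except the module whose fully qualified name is "1"
filter_fn = lambda mod, fqn: fqn != "1"  # noqa: E731
swap_linear_with_mx_linear(m1, config=config, filter_fn=filter_fn)
swap_linear_with_mx_inference_linear(m2, config=config, filter_fn=filter_fn)
```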
