Commit 1dfafdd

mx_formats: move training to the quantize_ API
Summary: Moves the MX training code to the `quantize_` API, and removes the custom linear swapping function. The inference code will be moved in a separate PR, since that will require splitting the config between training and inference.

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: 8ebcac398082d5480cae251a7ca67b39c94716f9
ghstack-comment-id: 2755847180
Pull Request resolved: #1970
1 parent c03dfe2 commit 1dfafdd
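
For context, a minimal sketch of the user-facing change (the model below is illustrative; the calls mirror what the diffs in this commit use): the dedicated `swap_linear_with_mx_linear` function goes away, and MX training is enabled through the generic `quantize_` entry point, which dispatches on the config type.

import torch
import torch.nn as nn

from torchao.prototype.mx_formats import MXLinearConfig
from torchao.quantization import quantize_

m = nn.Sequential(nn.Linear(32, 32)).cuda()
config = MXLinearConfig(block_size=32)

# before this commit:
#   swap_linear_with_mx_linear(m, config=config)
# after this commit:
quantize_(m, config=config)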

File tree

7 files changed
+50 -45 lines changed

benchmarks/float8/float8_roofline.py (+3 -3)

@@ -61,8 +61,8 @@
     Float8LinearConfig,
     convert_to_float8_training,
 )
-from torchao.prototype.mx_formats.config import MXLinearConfig
-from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
+from torchao.prototype.mx_formats import MXLinearConfig
+from torchao.quantization import quantize_
 from torchao.testing.float8.roofline_utils import (
     get_float8_mem_sympy,
     get_gemm_time_sympy,
@@ -391,7 +391,7 @@ def run(
         assert mx_recipe_name is not None
         config = MXLinearConfig.from_recipe_name(mx_recipe_name)
         m_fp8_dyn = copy.deepcopy(m_orig)
-        swap_linear_with_mx_linear(m_fp8_dyn, config=config)
+        quantize_(m_fp8_dyn, config=config)
         m_fp8_dyn = torch.compile(m_fp8_dyn)
         b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x, grad_output)

benchmarks/float8/profile_lowp_training.py (+2 -2)

@@ -46,9 +46,9 @@
     convert_to_float8_training,
 )
 from torchao.prototype.mx_formats.config import MXLinearConfig
-from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.prototype.mx_formats.utils import to_blocked
+from torchao.quantization import quantize_

 # don't truncate long kernel names
 pd.options.display.max_colwidth = 100
@@ -379,7 +379,7 @@ def main(
     if mx_recipe_name is None:
         convert_to_float8_training(m_lowp, config=config)
     else:
-        swap_linear_with_mx_linear(m_lowp, config=config)
+        quantize_(m_lowp, config=config)

     # this function is only used for cast_only
     to_mx_func = MXTensor.to_mx

test/prototype/mx_formats/test_mx_linear.py (+11 -9)

@@ -24,8 +24,8 @@
     MXInferenceLinear,
     MXLinear,
     swap_linear_with_mx_inference_linear,
-    swap_linear_with_mx_linear,
 )
+from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
     TORCH_VERSION_AT_LEAST_2_8,
@@ -98,7 +98,7 @@ def test_linear_eager_vs_hp(
         elem_dtype_grad_output_override=elem_dtype[2],
         use_fp8_dim1_cast_triton_kernel=use_fp8_dim1_cast_triton_kernel,
     )
-    swap_linear_with_mx_linear(m_mx, config=config)
+    quantize_(m_mx, config)

     x_ref = torch.randn(
         *input_shape, device="cuda", dtype=torch.bfloat16
@@ -159,8 +159,8 @@ def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn):
     config_emulated = MXLinearConfig(block_size=32, elem_dtype=elem_dtype)
     config_real = MXLinearConfig.from_recipe_name(recipe_name)

-    swap_linear_with_mx_linear(m_emulated, config=config_emulated)
-    swap_linear_with_mx_linear(m_real, config=config_real)
+    quantize_(m_emulated, config=config_emulated)
+    quantize_(m_real, config=config_real)

     y_emulated = m_emulated(x)
     y_emulated.backward(g)
@@ -189,7 +189,7 @@ def test_activation_checkpointing():
         nn.Linear(8, 8, bias=True, device="cuda"),
     )
     config = MXLinearConfig(block_size=4, elem_dtype=elem_dtype)
-    swap_linear_with_mx_linear(m, config=config)
+    quantize_(m, config=config)

     x = torch.randn(*input_shape, device="cuda").requires_grad_()
     g = torch.randn(*grad_shape, device="cuda")
@@ -252,7 +252,7 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_ke
     config = MXLinearConfig.from_recipe_name(recipe_name)
     config.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel

-    swap_linear_with_mx_linear(m_mx, config=config)
+    quantize_(m_mx, config=config)
     m_mx_c = copy.deepcopy(m_mx)
     m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")

@@ -339,10 +339,12 @@ def test_filter_fn():
         nn.Linear(32, 32),
     )
     m2 = copy.deepcopy(m1)
-    filter_fn = lambda mod, fqn: fqn != "1"  # noqa: E731
+    filter_fn = lambda mod, fqn: isinstance(mod, torch.nn.Linear) and fqn != "1"  # noqa: E731

     config = MXLinearConfig(block_size=32)
-    swap_linear_with_mx_linear(m1, config=config, filter_fn=filter_fn)
+    print("before", m1)
+    quantize_(m1, config=config, filter_fn=filter_fn)
+    print("after", m1)
     assert type(m1[0]) == MXLinear
     assert type(m1[1]) == torch.nn.Linear

@@ -354,7 +356,7 @@ def test_filter_fn():
 def test_training_print_str():
     m = nn.Sequential(nn.Linear(32, 32))
     config = MXLinearConfig()
-    swap_linear_with_mx_linear(m, config=config)
+    quantize_(m, config=config)
     s = str(m)
     assert "bl_sz=32" in s
     assert "kernel=emulated" in s

torchao/prototype/mx_formats/README.md (+10 -10)

@@ -40,25 +40,24 @@ x_hp = x_mx.to_dtype(torch.float)
 This is a module to do MX training, the MX matmul is currently emulated.

 ```python
-from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
-from torchao.prototype.mx_formats.config import MXLinearConfig, MXGemmKernelChoice
+import torch
+from torchao.quantization import quantize_
+from torchao.prototype.mx_formats import MXLinearConfig, MXGemmKernelChoice

-# early prototype: on MX-enabled hardware, you can use the real MX gemm backed by
-# torchao's CUTLASS kernels. In the future, we will also add cuBLAS kernel support.
-gemm_kernel_choice = MXGemmKernelChoice.EMULATED
-
-# on NVIDIA Blackwell GPUs, you can also use cuBLAS or CUTLASS mxfp8 kernels
-# note: torch.compile support for both of these is WIP
+# on NVIDIA Blackwell GPUs, you can use cuBLAS or CUTLASS mxfp8 kernels
+gemm_kernel_choice = MXGemmKernelChoice.CUBLAS
 # gemm_kernel_choice = MXGemmKernelChoice.CUTLASS
-# gemm_kernel_choice = MXGemmKernelChoice.CUBLAS
+
+# on older NVIDIA gpus, you can run training with emulated MX gemm
+# gemm_kernel_choice = MXGemmKernelChoice.EMULATED

 m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
 config = MXLinearConfig(
     elem_dtype=torch.float8_e4m3fn,
     block_size=32,
     gemm_kernel_choice=gemm_kernel_choice,
 )
-swap_linear_with_mx_linear(m, config=config)
+quantize_(m, config)

 # training loop (not shown)
 ```
@@ -68,6 +67,7 @@ swap_linear_with_mx_linear(m, config=config)
 This is a module to do MX inference, weights are in MX and matmul is in high precision.

 ```python
+import torch
 from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_inference_linear
 from torchao.prototype.mx_formats.config import MXLinearConfig

torchao/prototype/mx_formats/__init__.py (new file, +15)

@@ -0,0 +1,15 @@
+from torchao.prototype.mx_formats.config import (
+    MXGemmKernelChoice,
+    MXLinearConfig,
+    MXLinearRecipeName,
+)
+
+# import mx_linear here to register the quantize_ transform logic
+# ruff: noqa: I001
+import torchao.prototype.mx_formats.mx_linear  # noqa: F401
+
+__all__ = [
+    "MXLinearConfig",
+    "MXGemmKernelChoice",
+    "MXLinearRecipeName",
+]
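
A note on the side-effect import in the new file above (assumed to be the package `__init__.py`, which the new `from torchao.prototype.mx_formats import ...` call sites imply): the `quantize_` handler for `MXLinearConfig` only gets registered when `mx_linear` is imported, so importing it at package level ensures the handler exists before user code calls `quantize_`. A minimal sketch of the resulting usage:

import torch.nn as nn

# this import also pulls in mx_linear, registering the MXLinearConfig handler
from torchao.prototype.mx_formats import MXLinearConfig
from torchao.quantization import quantize_

m = nn.Sequential(nn.Linear(32, 32))
quantize_(m, MXLinearConfig(block_size=32))  # dispatches to the MXLinear transform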

torchao/prototype/mx_formats/config.py (+2 -1)

@@ -10,6 +10,7 @@

 import torch

+from torchao.core.config import AOBaseConfig
 from torchao.prototype.mx_formats.constants import (
     DTYPE_FP4,
     DTYPE_TO_SHORT_STR,
@@ -41,7 +42,7 @@ class MXLinearRecipeName(Enum):


 @dataclass
-class MXLinearConfig:
+class MXLinearConfig(AOBaseConfig):
     # block size for scaling, default is 32 to match
     # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf,
     # section 5.2

torchao/prototype/mx_formats/mx_linear.py (+7 -20)

@@ -16,6 +16,9 @@
 from torchao.prototype.mx_formats.config import MXGemmKernelChoice, MXLinearConfig
 from torchao.prototype.mx_formats.custom_cast import triton_to_mxfp8_dim1
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)


 @torch._dynamo.allow_in_graph
@@ -183,7 +186,7 @@ def from_float(
         mod,
         config: Optional[MXLinearConfig] = MXLinearConfig(),
     ):
-        # TODO(before land): remove this
+        assert isinstance(mod, torch.nn.Linear), f"unsupported type(mod) {type(mod)}"
         assert isinstance(config, MXLinearConfig)
         mod.__class__ = MXLinear
         mod.config = config
@@ -290,25 +293,9 @@ def _is_linear(mod, fqn):
     return isinstance(mod, torch.nn.Linear)


-def swap_linear_with_mx_linear(
-    model,
-    *,
-    config: Optional[MXLinearConfig] = None,
-    filter_fn=None,
-):
-    if filter_fn is None:
-        combined_filter_fn = _is_linear
-    else:
-
-        def __fn(mod, fqn):
-            return _is_linear(mod, fqn) and filter_fn(mod, fqn)
-
-        combined_filter_fn = __fn
-    replace_with_custom_fn_if_matches_filter(
-        model,
-        lambda mod: MXLinear.from_float(mod, config=config),
-        combined_filter_fn,
-    )
+@register_quantize_module_handler(MXLinearConfig)
+def _mx_linear_transform(module: torch.nn.Module, config: MXLinearConfig):
+    return MXLinear.from_float(module, config=config)


 def swap_linear_with_mx_inference_linear(
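
Taken together, the config.py and mx_linear.py changes show the pattern that replaces the custom swap function: the config subclasses `AOBaseConfig`, and a module handler registered for that config type performs the conversion when `quantize_` dispatches on the config. A hedged sketch of the same pattern for a hypothetical config (everything named `My*` is illustrative, not part of this diff):

from dataclasses import dataclass

import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_
from torchao.quantization.transform_module import register_quantize_module_handler


@dataclass
class MyLinearConfig(AOBaseConfig):
    # hypothetical config, for illustration only
    block_size: int = 32


@register_quantize_module_handler(MyLinearConfig)
def _my_linear_transform(module: torch.nn.Module, config: MyLinearConfig):
    # a real handler would return a converted module (as MXLinear.from_float does);
    # here we just tag the module so the dispatch is visible
    module.block_size = config.block_size
    return module


m = torch.nn.Sequential(torch.nn.Linear(32, 32))
quantize_(m, MyLinearConfig(block_size=64))  # the handler runs on each nn.Linear
assert m[0].block_size == 64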
