mx_formats: move training to the quantize_ API #1970

Merged: 42 commits, Mar 28, 2025
6 changes: 3 additions & 3 deletions benchmarks/float8/float8_roofline.py
@@ -61,8 +61,8 @@
Float8LinearConfig,
convert_to_float8_training,
)
from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
from torchao.prototype.mx_formats import MXLinearConfig
from torchao.quantization import quantize_
from torchao.testing.float8.roofline_utils import (
get_float8_mem_sympy,
get_gemm_time_sympy,
@@ -391,7 +391,7 @@ def run(
assert mx_recipe_name is not None
config = MXLinearConfig.from_recipe_name(mx_recipe_name)
m_fp8_dyn = copy.deepcopy(m_orig)
swap_linear_with_mx_linear(m_fp8_dyn, config=config)
quantize_(m_fp8_dyn, config=config)
m_fp8_dyn = torch.compile(m_fp8_dyn)
b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x, grad_output)

4 changes: 2 additions & 2 deletions benchmarks/float8/profile_lowp_training.py
@@ -46,9 +46,9 @@
convert_to_float8_training,
)
from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.utils import to_blocked
from torchao.quantization import quantize_

# don't truncate long kernel names
pd.options.display.max_colwidth = 100
@@ -379,7 +379,7 @@ def main(
if mx_recipe_name is None:
convert_to_float8_training(m_lowp, config=config)
else:
swap_linear_with_mx_linear(m_lowp, config=config)
quantize_(m_lowp, config=config)

# this function is only used for cast_only
to_mx_func = MXTensor.to_mx
18 changes: 9 additions & 9 deletions test/prototype/mx_formats/test_mx_linear.py
@@ -24,8 +24,8 @@
MXInferenceLinear,
MXLinear,
swap_linear_with_mx_inference_linear,
swap_linear_with_mx_linear,
)
from torchao.quantization import quantize_
from torchao.quantization.utils import compute_error
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_8,
@@ -98,7 +98,7 @@ def test_linear_eager_vs_hp(
elem_dtype_grad_output_override=elem_dtype[2],
use_fp8_dim1_cast_triton_kernel=use_fp8_dim1_cast_triton_kernel,
)
swap_linear_with_mx_linear(m_mx, config=config)
quantize_(m_mx, config)

x_ref = torch.randn(
*input_shape, device="cuda", dtype=torch.bfloat16
@@ -159,8 +159,8 @@ def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn):
config_emulated = MXLinearConfig(block_size=32, elem_dtype=elem_dtype)
config_real = MXLinearConfig.from_recipe_name(recipe_name)

swap_linear_with_mx_linear(m_emulated, config=config_emulated)
swap_linear_with_mx_linear(m_real, config=config_real)
quantize_(m_emulated, config=config_emulated)
quantize_(m_real, config=config_real)

y_emulated = m_emulated(x)
y_emulated.backward(g)
@@ -189,7 +189,7 @@ def test_activation_checkpointing():
nn.Linear(8, 8, bias=True, device="cuda"),
)
config = MXLinearConfig(block_size=4, elem_dtype=elem_dtype)
swap_linear_with_mx_linear(m, config=config)
quantize_(m, config=config)

x = torch.randn(*input_shape, device="cuda").requires_grad_()
g = torch.randn(*grad_shape, device="cuda")
@@ -252,7 +252,7 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_ke
config = MXLinearConfig.from_recipe_name(recipe_name)
config.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel

swap_linear_with_mx_linear(m_mx, config=config)
quantize_(m_mx, config=config)
m_mx_c = copy.deepcopy(m_mx)
m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")

@@ -339,10 +339,10 @@ def test_filter_fn():
nn.Linear(32, 32),
)
m2 = copy.deepcopy(m1)
filter_fn = lambda mod, fqn: fqn != "1" # noqa: E731
filter_fn = lambda mod, fqn: isinstance(mod, torch.nn.Linear) and fqn != "1" # noqa: E731

config = MXLinearConfig(block_size=32)
swap_linear_with_mx_linear(m1, config=config, filter_fn=filter_fn)
quantize_(m1, config=config, filter_fn=filter_fn)
assert type(m1[0]) == MXLinear
assert type(m1[1]) == torch.nn.Linear

@@ -354,7 +354,7 @@ def test_filter_fn():
def test_training_print_str():
m = nn.Sequential(nn.Linear(32, 32))
config = MXLinearConfig()
swap_linear_with_mx_linear(m, config=config)
quantize_(m, config=config)
s = str(m)
assert "bl_sz=32" in s
assert "kernel=emulated" in s
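The `test_filter_fn` change above reflects a small behavioral difference worth noting: `quantize_` applies the supplied `filter_fn` as-is, whereas `swap_linear_with_mx_linear` combined it with an internal `nn.Linear` check, so the `isinstance` test now lives inside the user-provided filter. A minimal sketch of the new call pattern (module sizes and the skipped fully-qualified name are illustrative):

```python
import torch.nn as nn

from torchao.prototype.mx_formats import MXLinearConfig
from torchao.quantization import quantize_

m = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32), nn.Linear(32, 32))

# quantize_ uses this filter directly, so it must include the nn.Linear check itself
def filter_fn(mod: nn.Module, fqn: str) -> bool:
    return isinstance(mod, nn.Linear) and fqn != "1"

# swaps children "0" and "2" to MXLinear, leaves child "1" as nn.Linear
quantize_(m, MXLinearConfig(block_size=32), filter_fn=filter_fn)
```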
20 changes: 10 additions & 10 deletions torchao/prototype/mx_formats/README.md
@@ -40,25 +40,24 @@ x_hp = x_mx.to_dtype(torch.float)
This is a module to do MX training; the MX matmul is currently emulated.

```python
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
from torchao.prototype.mx_formats.config import MXLinearConfig, MXGemmKernelChoice
import torch
from torchao.quantization import quantize_
from torchao.prototype.mx_formats import MXLinearConfig, MXGemmKernelChoice

# early prototype: on MX-enabled hardware, you can use the real MX gemm backed by
# torchao's CUTLASS kernels. In the future, we will also add cuBLAS kernel support.
gemm_kernel_choice = MXGemmKernelChoice.EMULATED

# on NVIDIA Blackwell GPUs, you can also use cuBLAS or CUTLASS mxfp8 kernels
# note: torch.compile support for both of these is WIP
# on NVIDIA Blackwell GPUs, you can use cuBLAS or CUTLASS mxfp8 kernels
gemm_kernel_choice = MXGemmKernelChoice.CUBLAS
# gemm_kernel_choice = MXGemmKernelChoice.CUTLASS
# gemm_kernel_choice = MXGemmKernelChoice.CUBLAS

# on older NVIDIA gpus, you can run training with emulated MX gemm
# gemm_kernel_choice = MXGemmKernelChoice.EMULATED

m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
config = MXLinearConfig(
elem_dtype=torch.float8_e4m3fn,
block_size=32,
gemm_kernel_choice=gemm_kernel_choice,
)
swap_linear_with_mx_linear(m, config=config)
quantize_(m, config)

# training loop (not shown)
```
@@ -68,6 +67,7 @@ swap_linear_with_mx_linear(m, config=config)
This is a module to do MX inference; weights are in MX and the matmul is in high precision.

```python
import torch
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_inference_linear
from torchao.prototype.mx_formats.config import MXLinearConfig

15 changes: 15 additions & 0 deletions torchao/prototype/mx_formats/__init__.py
@@ -0,0 +1,15 @@
from torchao.prototype.mx_formats.config import (
MXGemmKernelChoice,
MXLinearConfig,
MXLinearRecipeName,
)

# import mx_linear here to register the quantize_ transform logic
# ruff: noqa: I001
import torchao.prototype.mx_formats.mx_linear # noqa: F401

__all__ = [
"MXLinearConfig",
"MXGemmKernelChoice",
"MXLinearRecipeName",
]
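The import of `mx_linear` in this new `__init__.py` is load-bearing: the `@register_quantize_module_handler(MXLinearConfig)` decorator (see `mx_linear.py` below) only runs when that module is imported, so importing the package is what makes `quantize_` able to handle an `MXLinearConfig`. A minimal usage sketch, assuming a CUDA device (shapes are illustrative):

```python
import torch

# importing the package registers the MXLinearConfig handler as a side effect
from torchao.prototype.mx_formats import MXLinearConfig
from torchao.quantization import quantize_

m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
quantize_(m, MXLinearConfig(block_size=32))
```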
3 changes: 2 additions & 1 deletion torchao/prototype/mx_formats/config.py
@@ -10,6 +10,7 @@

import torch

from torchao.core.config import AOBaseConfig
from torchao.prototype.mx_formats.constants import (
DTYPE_FP4,
DTYPE_TO_SHORT_STR,
@@ -41,7 +42,7 @@ class MXLinearRecipeName(Enum):


@dataclass
class MXLinearConfig:
class MXLinearConfig(AOBaseConfig):
# block size for scaling, default is 32 to match
# https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf,
# section 5.2
27 changes: 7 additions & 20 deletions torchao/prototype/mx_formats/mx_linear.py
@@ -16,6 +16,9 @@
from torchao.prototype.mx_formats.config import MXGemmKernelChoice, MXLinearConfig
from torchao.prototype.mx_formats.custom_cast import triton_to_mxfp8_dim1
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.quantization.transform_module import (
register_quantize_module_handler,
)


@torch._dynamo.allow_in_graph
@@ -183,7 +186,7 @@ def from_float(
mod,
config: Optional[MXLinearConfig] = MXLinearConfig(),
):
# TODO(before land): remove this
assert isinstance(mod, torch.nn.Linear), f"unsupported type(mod) {type(mod)}"
assert isinstance(config, MXLinearConfig)
mod.__class__ = MXLinear
mod.config = config
@@ -290,25 +293,9 @@ def _is_linear(mod, fqn):
return isinstance(mod, torch.nn.Linear)


def swap_linear_with_mx_linear(
model,
*,
config: Optional[MXLinearConfig] = None,
filter_fn=None,
):
if filter_fn is None:
combined_filter_fn = _is_linear
else:

def __fn(mod, fqn):
return _is_linear(mod, fqn) and filter_fn(mod, fqn)

combined_filter_fn = __fn
replace_with_custom_fn_if_matches_filter(
model,
lambda mod: MXLinear.from_float(mod, config=config),
combined_filter_fn,
)
@register_quantize_module_handler(MXLinearConfig)
def _mx_linear_transform(module: torch.nn.Module, config: MXLinearConfig):
return MXLinear.from_float(module, config=config)


def swap_linear_with_mx_inference_linear(
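The `register_quantize_module_handler` decorator is what replaces the hand-rolled module swap deleted above: `quantize_` dispatches on the type of the config it receives and calls the registered transform on each module that passes the filter. The sketch below is only a conceptual model of that dispatch, not torchao's actual implementation; names like `_HANDLERS` and `apply_config` are invented for illustration.

```python
from typing import Callable, Dict, Type

import torch.nn as nn

# hypothetical registry: config type -> module transform (illustrative only)
_HANDLERS: Dict[Type, Callable[[nn.Module, object], nn.Module]] = {}

def register_handler(config_type: Type):
    """Decorator that records a transform for a given config type."""
    def decorator(fn: Callable[[nn.Module, object], nn.Module]):
        _HANDLERS[config_type] = fn
        return fn
    return decorator

def apply_config(model: nn.Module, config, filter_fn, prefix: str = "") -> None:
    """Recursively replace children accepted by filter_fn with handler(child, config)."""
    handler = _HANDLERS[type(config)]
    for name, child in list(model.named_children()):
        fqn = f"{prefix}{name}"
        if filter_fn(child, fqn):
            setattr(model, name, handler(child, config))
        else:
            apply_config(child, config, filter_fn, prefix=f"{fqn}.")
```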