MX: move block_size and elem_dtype into MXLinearConfig #1689

Merged · 11 commits · Feb 14, 2025
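In short, the public swap entry points now take a single `MXLinearConfig` instead of loose `elem_dtype` / `block_size` arguments. A minimal before/after sketch of the migration, mirroring the README diff below (assumes a CUDA device, as in the README):

```python
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

m = nn.Sequential(nn.Linear(32, 32)).cuda()

# before this PR:
#   swap_linear_with_mx_linear(m, torch.float8_e4m3fn, block_size=32)

# after this PR:
config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)
swap_linear_with_mx_linear(m, config=config)
```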
47 changes: 19 additions & 28 deletions test/prototype/mx_formats/test_mx_linear.py
@@ -11,6 +11,7 @@
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES
from torchao.prototype.mx_formats.mx_linear import (
MXInferenceLinear,
@@ -59,8 +60,13 @@ def test_linear_eager(elem_dtype, bias, input_shape):
nn.Linear(8, 6, bias=bias, device="cuda"),
)
m_mx = copy.deepcopy(m)
block_size = 2
swap_linear_with_mx_linear(m_mx, *elem_dtype, block_size=block_size)
config = MXLinearConfig(
block_size=2,
elem_dtype=elem_dtype[0],
elem_dtype_weight_override=elem_dtype[1],
elem_dtype_grad_output_override=elem_dtype[2],
)
swap_linear_with_mx_linear(m_mx, config=config)

x_ref = torch.randn(*input_shape, device="cuda").requires_grad_()
x = copy.deepcopy(x_ref)
@@ -97,8 +103,8 @@ def test_activation_checkpointing():
nn.Linear(4, 6, bias=True, device="cuda"),
nn.Linear(6, 6, bias=True, device="cuda"),
)
block_size = 2
swap_linear_with_mx_linear(m, elem_dtype, block_size=block_size)
config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
swap_linear_with_mx_linear(m, config=config)

x = torch.randn(*input_shape, device="cuda").requires_grad_()
g = torch.randn(*grad_shape, device="cuda")
@@ -133,8 +139,8 @@ def test_linear_compile(elem_dtype, bias, use_autocast):
m_mx = nn.Sequential(
nn.Linear(K, N, bias=bias, device="cuda"),
)
block_size = 2
swap_linear_with_mx_linear(m_mx, elem_dtype, block_size=block_size)
config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
swap_linear_with_mx_linear(m_mx, config=config)
m_mx_c = copy.deepcopy(m_mx)
m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor")

@@ -181,8 +187,8 @@ def test_inference_linear(elem_dtype, bias, input_shape):
m = nn.Sequential(nn.Linear(4, 6, bias=bias, dtype=torch.bfloat16))
m = m.cuda()
m_mx = copy.deepcopy(m)
block_size = 2
swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size)
config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
swap_linear_with_mx_inference_linear(m_mx, config=config)

x = torch.randn(*input_shape, device="cuda", dtype=torch.bfloat16)
y_ref = m(x)
@@ -209,8 +215,8 @@ def test_inference_compile_simple(elem_dtype):
m = nn.Sequential(nn.Linear(4, 6, bias=False, dtype=torch.bfloat16))
m = m.cuda()
m_mx = copy.deepcopy(m)
block_size = 2
swap_linear_with_mx_inference_linear(m_mx, elem_dtype, block_size)
config = MXLinearConfig(block_size=2, elem_dtype=elem_dtype)
swap_linear_with_mx_inference_linear(m_mx, config=config)
m_mx = torch.compile(m_mx, fullgraph="true")

x = torch.randn(2, 4, device="cuda", dtype=torch.bfloat16)
@@ -223,20 +229,6 @@ def test_inference_compile_simple(elem_dtype):
assert sqnr >= 13.5


def test_mx_linear_input_weight_gradient_dtypes():
m = nn.Sequential(nn.Linear(32, 32))
swap_linear_with_mx_linear(m, *SUPPORTED_ELEM_DTYPES[:3], block_size=32)
assert m[0].in_elem_dtype == SUPPORTED_ELEM_DTYPES[0]
assert m[0].w_elem_dtype == SUPPORTED_ELEM_DTYPES[1]
assert m[0].grad_elem_dtype == SUPPORTED_ELEM_DTYPES[2]

m = nn.Sequential(nn.Linear(32, 32))
swap_linear_with_mx_linear(m, torch.float8_e4m3fn, block_size=32)
assert m[0].in_elem_dtype == torch.float8_e4m3fn
assert m[0].w_elem_dtype == torch.float8_e4m3fn
assert m[0].grad_elem_dtype == torch.float8_e4m3fn


def test_filter_fn():
m1 = nn.Sequential(
nn.Linear(32, 32),
@@ -245,12 +237,11 @@ def test_filter_fn():
m2 = copy.deepcopy(m1)
filter_fn = lambda mod, fqn: fqn != "1" # noqa: E731

swap_linear_with_mx_linear(
m1, torch.float8_e4m3fn, block_size=32, filter_fn=filter_fn
)
config = MXLinearConfig(block_size=32)
swap_linear_with_mx_linear(m1, config=config, filter_fn=filter_fn)
assert type(m1[0]) == MXLinear
assert type(m1[1]) == torch.nn.Linear

swap_linear_with_mx_inference_linear(m2, torch.float8_e4m3fn, 32, filter_fn) # noqa: E501
swap_linear_with_mx_inference_linear(m2, config=config, filter_fn=filter_fn) # noqa: E501
assert type(m2[0]) == MXInferenceLinear
assert type(m2[1]) == torch.nn.Linear
11 changes: 6 additions & 5 deletions torchao/prototype/mx_formats/README.md
@@ -41,10 +41,11 @@ This is a module to do MX training, the MX matmul is currently emulated.

```python
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
from torchao.prototype.mx_formats.config import MXLinearConfig

m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
elem_dtype = torch.float8_e4m3fn
swap_linear_with_mx_linear(m, elem_dtype, block_size=32)
config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)
swap_linear_with_mx_linear(m, config=config)

# training loop (not shown)
```
@@ -55,11 +56,11 @@ This is a module to do MX inference, weights are in MX and matmul is in high pre

```python
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_inference_linear
from torchao.prototype.mx_formats.config import MXLinearConfig

m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
elem_dtype = torch.float8_e4m3fn
block_size = 32
swap_linear_with_mx_inference_linear(m, elem_dtype, block_size)
config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)
swap_linear_with_mx_inference_linear(m, config=config)

# do inference (not shown)
```
31 changes: 31 additions & 0 deletions torchao/prototype/mx_formats/config.py
@@ -5,9 +5,40 @@
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, Optional

import torch

from torchao.prototype.mx_formats.constants import SUPPORTED_ELEM_DTYPES


@dataclass
class MXLinearConfig:
# block size for scaling, default is 32 to match
# https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf,
# section 5.2
block_size: int = 32

# element dtype, used for activations, weights and gradients
elem_dtype: Any = torch.float8_e4m3fn

# overrides for element dtype for weights and gradients
# TODO(future PR): refactor to make this cleaner
elem_dtype_weight_override: Optional[Any] = None
elem_dtype_grad_output_override: Optional[Any] = None

# If True, uses a custom triton kernel for fp4 dequantize
use_fp4_custom_triton_dequant_kernel: bool = False
Contributor: you think that we will want to keep this public?

Contributor Author: unlikely, but IMO we can punt that until later


def __post_init__(self):
assert (
self.elem_dtype in SUPPORTED_ELEM_DTYPES
), f"elem_dtype: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}"
if self.elem_dtype_weight_override is not None:
assert (
self.elem_dtype_weight_override in SUPPORTED_ELEM_DTYPES
), f"elem_dtype_weight_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}"
if self.elem_dtype_grad_output_override is not None:
assert (
self.elem_dtype_grad_output_override in SUPPORTED_ELEM_DTYPES
), f"elem_dtype_grad_output_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}"
60 changes: 18 additions & 42 deletions torchao/prototype/mx_formats/mx_linear.py
@@ -107,22 +107,11 @@ class MXLinear(torch.nn.Linear):
def from_float(
cls,
mod,
elem_dtype,
elem_dtype_weight_override=None,
elem_dtype_grad_output_override=None,
*,
# TODO(next PR): move elem_dtype* and block size into config
config: MXLinearConfig = None,
block_size=32,
config: Optional[MXLinearConfig] = MXLinearConfig(),
):
# TODO(before land): remove this
assert isinstance(config, MXLinearConfig)
mod.__class__ = MXLinear
mod.in_elem_dtype = elem_dtype
mod.w_elem_dtype = elem_dtype_weight_override or elem_dtype
mod.grad_elem_dtype = elem_dtype_grad_output_override or elem_dtype
mod.block_size = block_size
# TODO(next PR): fix this
if config is None:
config = MXLinearConfig()
mod.config = config
return mod

@@ -135,13 +124,14 @@ def forward(self, x):
else:
w = self.weight

config = self.config
y = mx_mm.apply(
x,
w,
self.in_elem_dtype,
self.w_elem_dtype,
self.grad_elem_dtype,
self.block_size,
config.elem_dtype,
config.elem_dtype_weight_override or config.elem_dtype,
config.elem_dtype_grad_output_override or config.elem_dtype,
config.block_size,
)
if self.bias is not None:
y = y + self.bias
@@ -158,9 +148,11 @@ class MXInferenceLinear(torch.nn.Linear):

@classmethod
@torch.no_grad()
def from_float(cls, mod, elem_dtype, block_size, config: MXLinearConfig):
# TODO(next PR): move elem_dtype and block_size into config

def from_float(
cls,
mod,
config: Optional[MXLinearConfig] = MXLinearConfig(),
):
with torch.device("meta"):
super_kwargs = {
"in_features": mod.in_features,
@@ -171,10 +163,9 @@ def from_float(cls, mod, elem_dtype, block_size, config: MXLinearConfig):
# TODO(future PR): set to new_mod.weight directly, will need to work
# through some errors
new_mod.weight_mx = MXTensor.to_mx(
mod.weight, elem_dtype, block_size=block_size
mod.weight, config.elem_dtype, block_size=config.block_size
)
new_mod.bias = mod.bias
new_mod.elem_dtype = elem_dtype
new_mod.config = config
return new_mod

@@ -213,13 +204,8 @@ def _is_linear(mod, fqn):

def swap_linear_with_mx_linear(
model,
elem_dtype,
elem_dtype_weight_override=None,
elem_dtype_grad_output_override=None,
*,
# TODO(next PR): move elem_dtype* and block_size into config
config: Optional[MXLinearConfig] = None,
block_size=32,
filter_fn=None,
):
if filter_fn is None:
@@ -232,24 +218,16 @@ def __fn(mod, fqn):
combined_filter_fn = __fn
replace_with_custom_fn_if_matches_filter(
model,
lambda mod: MXLinear.from_float(
mod,
elem_dtype,
elem_dtype_weight_override,
elem_dtype_grad_output_override,
config=config,
block_size=block_size,
),
lambda mod: MXLinear.from_float(mod, config=config),
combined_filter_fn,
)


def swap_linear_with_mx_inference_linear(
model,
elem_dtype,
block_size,
filter_fn=None,
*,
config: Optional[MXLinearConfig] = None,
filter_fn=None,
):
if filter_fn is None:
combined_filter_fn = _is_linear
@@ -261,8 +239,6 @@ def __fn(mod, fqn):
combined_filter_fn = __fn
replace_with_custom_fn_if_matches_filter(
model,
lambda mod: MXInferenceLinear.from_float(
mod, elem_dtype, block_size, config=config
),
lambda mod: MXInferenceLinear.from_float(mod, config=config),
combined_filter_fn,
)
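Putting the new signatures together, both swap helpers now take `config` and `filter_fn` as keyword arguments. A short usage sketch mirroring `test_filter_fn` above (layer sizes are illustrative, and a CUDA device is assumed as in the tests):

```python
import copy

import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import (
    swap_linear_with_mx_inference_linear,
    swap_linear_with_mx_linear,
)

config = MXLinearConfig(block_size=32, elem_dtype=torch.float8_e4m3fn)

m_train = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32)).cuda()
m_infer = copy.deepcopy(m_train)

# skip the module with fully qualified name "1", i.e. the second nn.Linear
filter_fn = lambda mod, fqn: fqn != "1"  # noqa: E731

swap_linear_with_mx_linear(m_train, config=config, filter_fn=filter_fn)
swap_linear_with_mx_inference_linear(m_infer, config=config, filter_fn=filter_fn)
```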