
mx formats: create MXLinearConfig #1688


Merged
merged 8 commits on Feb 14, 2025
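For orientation, here is a minimal usage sketch of the new config together with the training swap entry point touched in this PR. This is not from the PR itself: the leading `model`/`elem_dtype` positional parameters, the CUDA device, and the bfloat16 setup are assumptions inferred from the hunks below.

```python
# Minimal sketch, assuming torchao with this PR applied and a CUDA device.
import torch

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.constants import DTYPE_FP4
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

model = torch.nn.Sequential(torch.nn.Linear(256, 256)).cuda().to(torch.bfloat16)

# MXLinearConfig currently carries only the fp4 triton dequant flag;
# per the TODOs below, elem_dtype/block_size move into it in a later PR.
config = MXLinearConfig(use_fp4_custom_triton_dequant_kernel=True)

swap_linear_with_mx_linear(
    model,
    DTYPE_FP4,  # element dtype choice is illustrative
    config=config,
    block_size=32,
)
```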
23 changes: 16 additions & 7 deletions test/prototype/mx_formats/test_mx_tensor.py
@@ -7,7 +7,6 @@
import pytest
import torch

from torchao.prototype.mx_formats import config
from torchao.prototype.mx_formats.constants import (
DTYPE_FP4,
DTYPE_FP6_E2M3,
@@ -139,8 +138,14 @@ def test_exponent_nan_out(elem_dtype):
else:
raise AssertionError("unsupported")
block_size = 2
use_fp4_custom_triton_dequant_kernel = False
tensor_mx = MXTensor(
scale_e8m0_bits, data_bits, elem_dtype, block_size, torch.float
scale_e8m0_bits,
data_bits,
elem_dtype,
block_size,
torch.float,
use_fp4_custom_triton_dequant_kernel,
)
tensor_hp = tensor_mx.to_dtype(torch.float)
assert torch.all(torch.isnan(tensor_hp[0:1]))
@@ -188,15 +193,16 @@ def test_transpose(elem_dtype, fp4_triton):
M, K = 128, 256
block_size = 32
tensor_hp = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size)
config.use_fp4_custom_triton_dequant_kernel = fp4_triton
tensor_mx = MXTensor.to_mx(
tensor_hp,
elem_dtype,
block_size,
use_fp4_custom_triton_dequant_kernel=fp4_triton,
)
tensor_mx_dq_t = tensor_mx.to_dtype(tensor_hp.dtype).t()
config.use_fp4_custom_triton_dequant_kernel = False

tensor_mx_t = tensor_mx.t()
config.use_fp4_custom_triton_dequant_kernel = fp4_triton
tensor_mx_t_dq = tensor_mx_t.to_dtype(tensor_hp.dtype)
config.use_fp4_custom_triton_dequant_kernel = False

assert tensor_mx_dq_t.shape == tensor_mx_t_dq.shape
torch.testing.assert_close(tensor_mx_dq_t, tensor_mx_t_dq, atol=0, rtol=0)
@@ -258,18 +264,21 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):

to_dtype_c = torch.compile(to_dtype, fullgraph=True)

use_fp4_custom_triton_dequant_kernel = False
x_mx_dq = to_dtype(
x_mx._data,
x_mx._scale_e8m0,
x_mx._elem_dtype,
x_mx._block_size,
hp_dtype, # noqa: E501
use_fp4_custom_triton_dequant_kernel,
Contributor:
Why is this needed? I did a quick pass through the RFC and didn't see any reference to this functionality

Contributor (Author):
we'll probably delete it in a future PR. The current PR is a refactor with no logic changes.

)
x_mx_c_dq = to_dtype_c(
x_mx_c._data,
x_mx_c._scale_e8m0,
x_mx_c._elem_dtype,
x_mx_c._block_size,
hp_dtype,
use_fp4_custom_triton_dequant_kernel,
)
torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0)
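As the review exchange above indicates, this PR only re-plumbs the existing flag: the module-level `config.use_fp4_custom_triton_dequant_kernel` global becomes an explicit argument stored on the tensor. A sketch of the new call shape, mirroring the updated tests (the CUDA device and shapes are assumptions):

```python
# Sketch: the flag travels with the MXTensor instead of global module state.
import torch

from torchao.prototype.mx_formats.constants import DTYPE_FP4
from torchao.prototype.mx_formats.mx_tensor import MXTensor

x_hp = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)
x_mx = MXTensor.to_mx(
    x_hp,
    DTYPE_FP4,
    block_size=32,
    use_fp4_custom_triton_dequant_kernel=True,
)
# dequantize reads the flag stored on the tensor; no global toggling needed
x_dq = x_mx.to_dtype(torch.bfloat16)
```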
15 changes: 13 additions & 2 deletions torchao/prototype/mx_formats/config.py
@@ -1,2 +1,13 @@
# If True, uses a custom triton kernel for fp4 dequantize
use_fp4_custom_triton_dequant_kernel = False
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass


@dataclass
class MXLinearConfig:
    # If True, uses a custom triton kernel for fp4 dequantize
    use_fp4_custom_triton_dequant_kernel: bool = False
22 changes: 19 additions & 3 deletions torchao/prototype/mx_formats/mx_linear.py
@@ -8,11 +8,12 @@
Defines the prototype UX for converting a model to use mx weights
"""

from typing import Any
from typing import Any, Optional

import torch
import torch.nn.functional as F

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_tensor import MXTensor


@@ -110,13 +111,19 @@ def from_float(
elem_dtype_weight_override=None,
elem_dtype_grad_output_override=None,
*,
# TODO(next PR): move elem_dtype* and block size into config
config: MXLinearConfig = None,
block_size=32,
):
mod.__class__ = MXLinear
mod.in_elem_dtype = elem_dtype
mod.w_elem_dtype = elem_dtype_weight_override or elem_dtype
mod.grad_elem_dtype = elem_dtype_grad_output_override or elem_dtype
mod.block_size = block_size
# TODO(next PR): fix this
if config is None:
config = MXLinearConfig()
mod.config = config
return mod

def forward(self, x):
@@ -151,7 +158,9 @@ class MXInferenceLinear(torch.nn.Linear):

@classmethod
@torch.no_grad()
def from_float(cls, mod, elem_dtype, block_size):
def from_float(cls, mod, elem_dtype, block_size, config: MXLinearConfig):
# TODO(next PR): move elem_dtype and block_size into config

with torch.device("meta"):
super_kwargs = {
"in_features": mod.in_features,
@@ -166,6 +175,7 @@ def from_float(cls, mod, elem_dtype, block_size):
)
new_mod.bias = mod.bias
new_mod.elem_dtype = elem_dtype
new_mod.config = config
return new_mod

@torch.no_grad()
@@ -207,6 +217,8 @@ def swap_linear_with_mx_linear(
elem_dtype_weight_override=None,
elem_dtype_grad_output_override=None,
*,
# TODO(next PR): move elem_dtype* and block_size into config
config: Optional[MXLinearConfig] = None,
block_size=32,
filter_fn=None,
):
@@ -225,6 +237,7 @@ def __fn(mod, fqn):
elem_dtype,
elem_dtype_weight_override,
elem_dtype_grad_output_override,
config=config,
block_size=block_size,
),
combined_filter_fn,
@@ -236,6 +249,7 @@ def swap_linear_with_mx_inference_linear(
elem_dtype,
block_size,
filter_fn=None,
config: Optional[MXLinearConfig] = None,
):
if filter_fn is None:
combined_filter_fn = _is_linear
@@ -247,6 +261,8 @@ def __fn(mod, fqn):
combined_filter_fn = __fn
replace_with_custom_fn_if_matches_filter(
model,
lambda mod: MXInferenceLinear.from_float(mod, elem_dtype, block_size),
lambda mod: MXInferenceLinear.from_float(
mod, elem_dtype, block_size, config=config
),
combined_filter_fn,
)
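The inference entry point gains the same optional `config` parameter and attaches it to the swapped module. A hypothetical call follows; the first positional parameter name (`model`) is an assumption, since it sits outside the visible hunk.

```python
# Hypothetical sketch, assuming torchao with this PR applied and a CUDA device.
import torch

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.constants import DTYPE_FP6_E2M3
from torchao.prototype.mx_formats.mx_linear import (
    swap_linear_with_mx_inference_linear,
)

model = torch.nn.Sequential(torch.nn.Linear(256, 256)).cuda().to(torch.bfloat16)

# config is optional today; from_float receives it and stores it as
# new_mod.config on the swapped MXInferenceLinear.
swap_linear_with_mx_inference_linear(
    model,
    DTYPE_FP6_E2M3,
    block_size=32,
    config=MXLinearConfig(),
)
```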
6 changes: 4 additions & 2 deletions torchao/prototype/mx_formats/mx_ops.py
@@ -54,6 +54,7 @@ def mx_desugar_op(aten_op, args, kwargs=None):
old._elem_dtype,
old._block_size,
old._orig_dtype,
old._use_fp4_custom_triton_dequant_kernel,
)
return new

@@ -82,6 +83,7 @@ def mx_t(aten_op, args, kwargs=None):
old._elem_dtype,
old._block_size,
old._orig_dtype,
old._use_fp4_custom_triton_dequant_kernel,
)
return new

@@ -120,6 +122,7 @@ def mx_view_op(aten_op, args, kwargs=None):
args[0]._elem_dtype,
args[0]._block_size,
args[0]._orig_dtype,
args[0]._use_fp4_custom_triton_dequant_kernel,
)


@@ -130,7 +133,6 @@ def autocast_to_copy(aten_op, args, kwargs=None):
tensor.
"""
assert isinstance(args[0], MXTensor)
# print('before', args[0], args[0].dtype, args[0]._orig_dtype)
assert (
len(kwargs) == 1 and "dtype" in kwargs
), "Only support dtype kwarg for autocast"
@@ -144,6 +146,6 @@
args[0]._elem_dtype,
args[0]._block_size,
kwargs["dtype"],
args[0]._use_fp4_custom_triton_dequant_kernel,
)
# print('after', res, res.dtype, res._orig_dtype)
return res
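Each of these op overrides now re-wraps its result with the flag taken from the input MXTensor, so tensor ops preserve the dequant-kernel choice. A small sketch follows (the underscore attribute is internal and is used here only to illustrate the propagation):

```python
# Sketch: transpose goes through the mx_t override, so the rebuilt MXTensor
# keeps the same use_fp4_custom_triton_dequant_kernel value as its input.
import torch

from torchao.prototype.mx_formats.constants import DTYPE_FP4
from torchao.prototype.mx_formats.mx_tensor import MXTensor

x_mx = MXTensor.to_mx(
    torch.randn(128, 256, device="cuda", dtype=torch.bfloat16),
    DTYPE_FP4,
    block_size=32,
    use_fp4_custom_triton_dequant_kernel=True,
)
x_mx_t = x_mx.t()
assert x_mx_t._use_fp4_custom_triton_dequant_kernel
```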
46 changes: 39 additions & 7 deletions torchao/prototype/mx_formats/mx_tensor.py
@@ -21,7 +21,6 @@

import torch

import torchao.prototype.mx_formats.config as config
from torchao.prototype.mx_formats.constants import (
BLOCK_SIZE_DEFAULT,
DTYPE_FP4,
@@ -239,7 +238,14 @@ def get_fp_scale(scale_e8m0):
return s_fp


def to_dtype(data_lp, scale_e8m0, elem_dtype, block_size, target_dtype):
def to_dtype(
data_lp,
scale_e8m0,
elem_dtype,
block_size,
target_dtype,
use_fp4_custom_triton_dequant_kernel,
):
orig_shape = data_lp.shape
is_transposed = not data_lp.is_contiguous()
# if the underlying data is transposed, convert to row major before
@@ -258,7 +264,7 @@ def to_dtype(data_lp, scale_e8m0, elem_dtype, block_size, target_dtype):
data_hp = f6_e3m2_unpacked_to_f32(data_lp)
data_hp = data_hp.to(target_dtype)
elif elem_dtype == DTYPE_FP4:
if config.use_fp4_custom_triton_dequant_kernel:
if use_fp4_custom_triton_dequant_kernel:
data_hp_rescaled = triton_f4_to_scaled_bf16(
data_lp,
scale_e8m0,
@@ -318,17 +324,29 @@ class ToMXConstrFunc(torch.autograd.Function):
"""

@staticmethod
def forward(ctx, data_hp, elem_dtype, block_size, scaling_mode):
def forward(
ctx,
data_hp,
elem_dtype,
block_size,
scaling_mode,
use_fp4_custom_triton_dequant_kernel,
):
scale_e8m0_biased, data_lp = to_mx(
data_hp, elem_dtype, block_size, scaling_mode
)
return MXTensor(
scale_e8m0_biased, data_lp, elem_dtype, block_size, data_hp.dtype
scale_e8m0_biased,
data_lp,
elem_dtype,
block_size,
data_hp.dtype,
use_fp4_custom_triton_dequant_kernel,
)

@staticmethod
def backward(ctx, g):
return g, None, None, None
return g, None, None, None, None


@torch._dynamo.allow_in_graph
@@ -345,6 +363,7 @@ def forward(ctx, tensor_lp, target_dtype):
tensor_lp._elem_dtype,
tensor_lp._block_size,
target_dtype,
tensor_lp._use_fp4_custom_triton_dequant_kernel,
)

@staticmethod
@@ -360,6 +379,7 @@ def __new__(
elem_dtype,
block_size,
orig_dtype,
use_fp4_custom_triton_dequant_kernel,
):
new_size = data_bits.size()
if elem_dtype == DTYPE_FP4:
@@ -417,6 +437,9 @@ def __new__(
self._elem_dtype = elem_dtype
self._block_size = block_size
self._orig_dtype = orig_dtype
self._use_fp4_custom_triton_dequant_kernel = (
use_fp4_custom_triton_dequant_kernel
)
return self

def __repr__(self):
@@ -443,14 +466,22 @@ def to_mx(
elem_dtype: Union[torch.dtype, str],
block_size: int = BLOCK_SIZE_DEFAULT,
scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
use_fp4_custom_triton_dequant_kernel: bool = False,
):
return ToMXConstrFunc.apply(data_hp, elem_dtype, block_size, scaling_mode)
return ToMXConstrFunc.apply(
data_hp,
elem_dtype,
block_size,
scaling_mode,
use_fp4_custom_triton_dequant_kernel,
)

def __tensor_flatten__(self):
ctx = {
"_elem_dtype": self._elem_dtype,
"_block_size": self._block_size,
"_orig_dtype": self._orig_dtype,
"_use_fp4_custom_triton_dequant_kernel": self._use_fp4_custom_triton_dequant_kernel,
}
return ["_scale_e8m0", "_data"], ctx

@@ -467,6 +498,7 @@ def __tensor_unflatten__(
metadata["_elem_dtype"],
metadata["_block_size"],
metadata["_orig_dtype"],
metadata["_use_fp4_custom_triton_dequant_kernel"],
)

# Do not force the MXTensor type on the returned tensor
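The flag is also part of the `__tensor_flatten__` metadata, which is what lets subclass reconstruction (for example under `torch.compile`) rebuild an MXTensor with the same dequant-kernel setting. A sketch under the same assumptions as the examples above:

```python
# Sketch: the flag rides along in the tensor-subclass flatten context.
import torch

from torchao.prototype.mx_formats.constants import DTYPE_FP4
from torchao.prototype.mx_formats.mx_tensor import MXTensor

x_mx = MXTensor.to_mx(
    torch.randn(64, 64, device="cuda", dtype=torch.bfloat16),
    DTYPE_FP4,
    block_size=32,
    use_fp4_custom_triton_dequant_kernel=True,
)
inner_tensor_names, ctx = x_mx.__tensor_flatten__()
assert ctx["_use_fp4_custom_triton_dequant_kernel"] is True
```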