mx: add ceil and RNE rounding modes to the cast from fp32 to e8m0 #1643

Merged · 5 commits · Feb 10, 2025
Changes from all commits
18 changes: 15 additions & 3 deletions test/prototype/mx_formats/test_mx_tensor.py
@@ -18,6 +18,7 @@
from torchao.prototype.mx_formats.mx_tensor import (
E8M0_EXPONENT_NAN_VAL,
MXTensor,
ScaleCalculationMode,
to_dtype,
)
from torchao.quantization.utils import compute_error
@@ -47,8 +48,10 @@ def run_before_and_after_tests():
torch._dynamo.reset()


- def _test_mx(data_hp, elem_dtype, block_size):
- data_mx = MXTensor.to_mx(data_hp, elem_dtype, block_size)
+ def _test_mx(
+ data_hp, elem_dtype, block_size, scale_calculation_mode=ScaleCalculationMode.FLOOR
+ ):
+ data_mx = MXTensor.to_mx(data_hp, elem_dtype, block_size, scale_calculation_mode)
data_mx_dq = data_mx.to_dtype(data_hp.dtype)

def assert_sqnr_gt_threshold(orig, new, threshold):
@@ -61,7 +64,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
assert sqnr >= threshold

if elem_dtype is torch.float8_e4m3fn:
- assert_sqnr_gt_threshold(data_hp, data_mx_dq, 20.0)
+ assert_sqnr_gt_threshold(data_hp, data_mx_dq, 18.0)
else:
assert_sqnr_gt_threshold(data_hp, data_mx_dq, 14.0)

@@ -74,6 +77,15 @@ def test_hello_world(elem_dtype):
_test_mx(data, elem_dtype, block_size)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("scale_calculation_mode", [s for s in ScaleCalculationMode])
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_realistic_numerics(elem_dtype, scale_calculation_mode):
data = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
block_size = 32
_test_mx(data, elem_dtype, block_size, scale_calculation_mode)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_all_zeros(elem_dtype):
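A quick way to see what the new parametrization exercises, as a hedged sketch (it assumes a torchao build containing this PR, and reuses compute_error, the SQNR helper the test file above already imports):

import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor, ScaleCalculationMode
from torchao.quantization.utils import compute_error

# Quantize one bf16 tensor to MX under each scale calculation mode and
# compare reconstruction quality (higher SQNR is better); arguments are
# passed positionally, as in _test_mx above: data, elem_dtype, block_size,
# scale calculation mode.
data = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
for mode in ScaleCalculationMode:
    data_mx = MXTensor.to_mx(data, torch.float8_e4m3fn, 32, mode)
    data_dq = data_mx.to_dtype(torch.bfloat16)
    print(f"{mode.name}: SQNR = {compute_error(data, data_dq):.1f} dB")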
71 changes: 61 additions & 10 deletions torchao/prototype/mx_formats/mx_tensor.py
@@ -16,6 +16,7 @@
* Zeros: N/A
"""

from enum import Enum, auto
from typing import Dict, Union

import torch
@@ -53,11 +54,38 @@
unpack_uint4,
)

# TODO(later): read from somewhere else?
SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23
EBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1
EBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3
EBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2
EBITS_F8_E4M3, MBITS_F8_E4M3 = 4, 3
EBITS_F8_E5M2, MBITS_F8_E5M2 = 5, 2


class ScaleCalculationMode(Enum):
"""
Enum representing the different methods for calculating MX block scaling.
There are three methods available:
FLOOR: This method is recommended by the OCP MX Spec 1.0 and uses X = 2^(floor(log2(max_abs(v))) - max_exp).
It can overflow for large values and performs poorly for gradient quantization.
CEIL: This method uses X = 2^(ceil(log2(max_abs(v))) - max_exp). It avoids the
overflow, but small values may be flushed to zero by the larger scaling factor.
EVEN: This method is a trade-off between FLOOR and CEIL. It uses X = 2^(floor(log2(rne(max_abs(v)))) - max_exp),
where rne rounds max_abs to the element dtype's precision before the floor is taken.
It provides better accuracy for MX4 training compared to FLOOR and CEIL.
The default is FLOOR, matching the OCP MX Spec 1.0.
"""

FLOOR = auto()
CEIL = auto()
EVEN = auto()


def to_mx(
data_hp: torch.Tensor,
elem_dtype: Union[torch.dtype, str],
block_size: int,
scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
):
"""
Takes a high precision tensor and converts to MX scale and raw data, in
Expand Down Expand Up @@ -88,25 +116,45 @@ def to_mx(
# where the values are zero.
eps = F32_MIN_NORMAL * (max_abs == 0).type(max_abs.dtype)

- # Find largest power of 2 less than or equal to max_abs.
- largest_p2_lt_max_abs = torch.floor(torch.log2(max_abs + eps))

# Set X to be the largest power-of-two less than or equal to
# max_abs(v), divided by the largest power of two representable
- # in the element data type
+ # in the element data type, and get the mbits at the same time
if elem_dtype == torch.float8_e4m3fn:
target_max_pow2 = F8E4M3_MAX_POW2
mbits = MBITS_F8_E4M3
elif elem_dtype == torch.float8_e5m2:
target_max_pow2 = F8E5M2_MAX_POW2
mbits = MBITS_F8_E5M2
elif elem_dtype == DTYPE_FP6_E2M3:
target_max_pow2 = F6_E2M3_MAX_POW2
mbits = MBITS_F6_E2M3
elif elem_dtype == DTYPE_FP6_E3M2:
target_max_pow2 = F6_E3M2_MAX_POW2
mbits = MBITS_F6_E3M2
elif elem_dtype == DTYPE_FP4:
target_max_pow2 = F4_E2M1_MAX_POW2
mbits = MBITS_F4_E2M1
else:
raise AssertionError("unsupported")
scale_e8m0_unbiased = largest_p2_lt_max_abs - target_max_pow2
raise AssertionError("unsupported element dtype")

# rounding before calculating the largest power of 2
# X = 2^(floor(log2(rounding(max_abs(v)))-max_exp))
if scaling_mode == ScaleCalculationMode.EVEN:
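# (Mechanics: reinterpret the fp32 bits as int32, add half a ULP of the
# element dtype's mantissa width, then clear all fp32 mantissa bits; the
# exponent increments only when max_abs rounds up into the next binade.
# NaNs are masked out first because the bit arithmetic would corrupt them.)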
nan_mask = torch.isnan(max_abs)
max_abs = max_abs.to(torch.float32).view(torch.int32)
val_to_add = 1 << (MBITS_F32 - mbits - 1)
mask = ((1 << (EBITS_F32 + SBITS)) - 1) << MBITS_F32
max_abs = (max_abs + val_to_add) & mask
max_abs = max_abs.view(torch.float32)
max_abs[nan_mask] = torch.tensor(float("nan"), device=max_abs.device)

# Calculate the scale for different modes
if scaling_mode in (ScaleCalculationMode.FLOOR, ScaleCalculationMode.EVEN):
scale_e8m0_unbiased = torch.floor(torch.log2(max_abs + eps)) - target_max_pow2
elif scaling_mode == ScaleCalculationMode.CEIL:
scale_e8m0_unbiased = torch.ceil(torch.log2(max_abs + eps)) - target_max_pow2
else:
raise AssertionError("unsupported scaling calculation mode")

# Clamp to exponents that can be represented in e8m0
scale_e8m0_unbiased = torch.clamp(
@@ -270,15 +318,17 @@ class ToMXConstrFunc(torch.autograd.Function):
"""

@staticmethod
- def forward(ctx, data_hp, elem_dtype, block_size):
- scale_e8m0_biased, data_lp = to_mx(data_hp, elem_dtype, block_size)
+ def forward(ctx, data_hp, elem_dtype, block_size, scaling_mode):
+ scale_e8m0_biased, data_lp = to_mx(
+ data_hp, elem_dtype, block_size, scaling_mode
+ )
return MXTensor(
scale_e8m0_biased, data_lp, elem_dtype, block_size, data_hp.dtype
)

@staticmethod
def backward(ctx, g):
- return g, None, None
+ return g, None, None, None


@torch._dynamo.allow_in_graph
@@ -392,8 +442,9 @@ def to_mx(
data_hp: torch.Tensor,
elem_dtype: Union[torch.dtype, str],
block_size: int = BLOCK_SIZE_DEFAULT,
scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
):
- return ToMXConstrFunc.apply(data_hp, elem_dtype, block_size)
+ return ToMXConstrFunc.apply(data_hp, elem_dtype, block_size, scaling_mode)

def __tensor_flatten__(self):
ctx = {
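To make the three modes concrete, here is a standalone sketch of the scale-exponent computation for a single block maximum. It mirrors the logic above but is not the PR's code: it uses plain Python floats instead of tensors, hardcodes the e4m3 element dtype (target_max_pow2 = 8, mbits = 3), and omits the final clamp to the e8m0 range.

import math
import struct

SBITS, EBITS_F32, MBITS_F32 = 1, 8, 23  # fp32 layout: sign, exponent, mantissa bits


def scale_exponent(max_abs: float, mode: str, target_max_pow2: int = 8, mbits: int = 3) -> int:
    if mode == "even":
        # Add half a ULP of the element dtype's mantissa to the fp32 bit
        # pattern, then clear all fp32 mantissa bits: the exponent bumps up
        # only when max_abs rounds into the next binade at target precision.
        bits = struct.unpack("<I", struct.pack("<f", max_abs))[0]
        half_ulp = 1 << (MBITS_F32 - mbits - 1)
        keep_sign_and_exp = ((1 << (EBITS_F32 + SBITS)) - 1) << MBITS_F32
        bits = (bits + half_ulp) & keep_sign_and_exp
        max_abs = struct.unpack("<f", struct.pack("<I", bits))[0]
        return math.floor(math.log2(max_abs)) - target_max_pow2
    if mode == "floor":
        return math.floor(math.log2(max_abs)) - target_max_pow2
    assert mode == "ceil"
    return math.ceil(math.log2(max_abs)) - target_max_pow2


# 496 = 1.9375 * 2^8 sits exactly half an e4m3 ULP below 512, so EVEN rounds
# up and agrees with CEIL there; at 447 it stays down and agrees with FLOOR.
for v in (447.0, 496.0, 512.0):
    print(v, {m: scale_exponent(v, m) for m in ("floor", "ceil", "even")})

The middle case shows the overflow FLOOR is prone to: with a scale of 2^0, a block max of 496 stays at 496 and exceeds e4m3's maximum of 448, so it would be clamped by a saturating cast, while CEIL and EVEN pick 2^1 and map it to 248.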