
mx: remove differentiable casts #1978


Merged
65 commits merged on Apr 1, 2025
Commits (65)
af6ae2f  Update (vkuzo, Mar 21, 2025)
45120de  Update (vkuzo, Mar 21, 2025)
5527e72  Update (vkuzo, Mar 21, 2025)
478b9e1  Update (vkuzo, Mar 21, 2025)
571775d  Update (vkuzo, Mar 21, 2025)
fd30558  Update (vkuzo, Mar 21, 2025)
b0cd056  Update (vkuzo, Mar 21, 2025)
26b49fd  Update (vkuzo, Mar 21, 2025)
ba10a02  Update (vkuzo, Mar 21, 2025)
483cdfd  Update (vkuzo, Mar 21, 2025)
32005c9  Update (vkuzo, Mar 21, 2025)
e341c2e  Update (vkuzo, Mar 24, 2025)
7ecd79f  Update (vkuzo, Mar 24, 2025)
ca3c4cf  Update (vkuzo, Mar 24, 2025)
0de11cf  Update (vkuzo, Mar 24, 2025)
912e4dc  Update (vkuzo, Mar 24, 2025)
fb5662a  Update (vkuzo, Mar 25, 2025)
f245d64  Update (vkuzo, Mar 26, 2025)
9e5b8f8  Update (vkuzo, Mar 26, 2025)
e5bdecb  Update (vkuzo, Mar 26, 2025)
4c2ad8c  Update (vkuzo, Mar 26, 2025)
c1ceef1  Update (vkuzo, Mar 26, 2025)
65bfff0  Update (vkuzo, Mar 26, 2025)
0ff3a93  Update (vkuzo, Mar 26, 2025)
71a5548  Update (vkuzo, Mar 26, 2025)
0576d0d  Update (vkuzo, Mar 26, 2025)
f98453f  Update (vkuzo, Mar 27, 2025)
81dc214  Update (vkuzo, Mar 27, 2025)
5d60f24  Update (vkuzo, Mar 27, 2025)
a313055  Update (vkuzo, Mar 27, 2025)
798abfc  Update (vkuzo, Mar 27, 2025)
4933b66  Update (vkuzo, Mar 27, 2025)
d9e60c1  Update (vkuzo, Mar 27, 2025)
884f065  Update (vkuzo, Mar 27, 2025)
41b1f9d  Update (vkuzo, Mar 27, 2025)
5cc2755  Update (vkuzo, Mar 27, 2025)
af1f386  Update (vkuzo, Mar 27, 2025)
8691bd4  Update (vkuzo, Mar 27, 2025)
1a0993d  Update (vkuzo, Mar 27, 2025)
b053f97  Update (vkuzo, Mar 27, 2025)
9e335ce  Update (vkuzo, Mar 27, 2025)
87756f9  Update (vkuzo, Mar 28, 2025)
d0a0fd1  Update (vkuzo, Mar 28, 2025)
cf9dfe4  Update (vkuzo, Mar 28, 2025)
beafdd9  Update (vkuzo, Mar 28, 2025)
45abedf  Update (vkuzo, Mar 28, 2025)
af87eee  Update (vkuzo, Mar 28, 2025)
db67393  Update (vkuzo, Mar 28, 2025)
a679de7  Update (vkuzo, Mar 28, 2025)
28dedc0  Update (vkuzo, Mar 28, 2025)
02d5065  Update (vkuzo, Mar 28, 2025)
d1bf83a  Update (vkuzo, Mar 28, 2025)
84c77d7  Update (vkuzo, Mar 28, 2025)
749564b  Update (vkuzo, Mar 28, 2025)
f63479e  Update (vkuzo, Mar 28, 2025)
c603f09  Update (vkuzo, Mar 28, 2025)
42fb0e9  Update (vkuzo, Mar 28, 2025)
a16f576  Update (vkuzo, Mar 28, 2025)
b890654  Update (vkuzo, Mar 28, 2025)
83e1e2e  Update (vkuzo, Mar 28, 2025)
c41cc19  Update (vkuzo, Mar 28, 2025)
8ba3018  Update (vkuzo, Mar 28, 2025)
a62e00b  Update (vkuzo, Mar 28, 2025)
195d904  Update (vkuzo, Mar 28, 2025)
8a5050e  Update (vkuzo, Apr 1, 2025)
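For context, the casts removed in this PR were implemented as torch.autograd.Function subclasses whose backward simply returned the incoming gradient, making the quantize/dequantize pair transparent to autograd. A minimal, self-contained sketch of that general pattern (written for illustration only; NoOpBackwardCast is a hypothetical name, not the torchao code) is:

# Hypothetical sketch of a "differentiable cast": forward does a cast,
# backward is a no-op that passes the gradient through unchanged.
import torch

class NoOpBackwardCast(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Stand-in for a real low-precision quantize/dequantize round trip.
        return x.to(torch.bfloat16).to(x.dtype)

    @staticmethod
    def backward(ctx, g):
        # No-op backward: gradient flows through unchanged.
        return g

x = torch.randn(8, requires_grad=True)
y = NoOpBackwardCast.apply(x)
y.sum().backward()
print(x.grad)  # all ones: the cast was transparent to autograd

With this PR, MXTensor.to_mx and MXTensor.to_dtype no longer go through such wrappers, so gradients do not flow through the casts.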
12 changes: 0 additions & 12 deletions test/prototype/mx_formats/test_mx_tensor.py
@@ -443,18 +443,6 @@ def test_transpose(elem_dtype, fp4_triton):
     torch.testing.assert_close(tensor_mx_dq_t, tensor_mx_t_dq, atol=0, rtol=0)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
-def test_cast_autograd(elem_dtype):
-    x = torch.arange(8, device="cuda").bfloat16().requires_grad_()
-    grad = torch.arange(8, device="cuda").bfloat16() * 0.5
-    block_size = 8
-    x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
-    x_dq = x_mx.to_dtype(torch.bfloat16)
-    x_dq.backward(gradient=grad)
-    torch.testing.assert_close(grad, x.grad, atol=0, rtol=0)
-
-
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_view(elem_dtype):
79 changes: 16 additions & 63 deletions torchao/prototype/mx_formats/mx_tensor.py
@@ -467,65 +467,6 @@ def tensor_size_fp6x4_to_hpx3(orig_size, is_contiguous):
     return new_size


-@torch._dynamo.allow_in_graph
-class ToMXConstrFunc(torch.autograd.Function):
-    """
-    Differentiable cast to MX, no-op in backward
-    """
-
-    @staticmethod
-    def forward(
-        ctx,
-        data_hp,
-        elem_dtype,
-        block_size,
-        scaling_mode,
-        use_fp4_custom_triton_dequant_kernel,
-        gemm_kernel_choice,
-        pack_fp6,
-    ):
-        scale_e8m0_biased, data_lp = to_mx(
-            data_hp, elem_dtype, block_size, scaling_mode, pack_fp6=pack_fp6
-        )
-        return MXTensor(
-            scale_e8m0_biased,
-            data_lp,
-            elem_dtype,
-            block_size,
-            data_hp.dtype,
-            use_fp4_custom_triton_dequant_kernel,
-            gemm_kernel_choice,
-            pack_fp6,
-        )
-
-    @staticmethod
-    def backward(ctx, g):
-        return g, None, None, None, None, None, None
-
-
-@torch._dynamo.allow_in_graph
-class FromMXConstrFunc(torch.autograd.Function):
-    """
-    Differentiable cast from MX, no-op in backward
-    """
-
-    @staticmethod
-    def forward(ctx, tensor_lp, target_dtype):
-        return to_dtype(
-            tensor_lp._data,
-            tensor_lp._scale_e8m0,
-            tensor_lp._elem_dtype,
-            tensor_lp._block_size,
-            target_dtype,
-            tensor_lp._use_fp4_custom_triton_dequant_kernel,
-            tensor_lp._pack_fp6,
-        )
-
-    @staticmethod
-    def backward(ctx, g):
-        return g, None, None
-
-
 class MXTensor(torch.Tensor):
     def __new__(
         cls,
@@ -627,7 +568,15 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
         raise NotImplementedError(f"{func} not implemented")

     def to_dtype(self, target_dtype):
-        return FromMXConstrFunc.apply(self, target_dtype)
+        return to_dtype(
+            self._data,
+            self._scale_e8m0,
+            self._elem_dtype,
+            self._block_size,
+            target_dtype,
+            self._use_fp4_custom_triton_dequant_kernel,
+            self._pack_fp6,
+        )

     @staticmethod
     @torch._dynamo.allow_in_graph
@@ -640,11 +589,15 @@ def to_mx(
         gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED,
         pack_fp6: bool = False,
     ):
-        return ToMXConstrFunc.apply(
-            data_hp,
+        scale_e8m0_biased, data_lp = to_mx(
+            data_hp, elem_dtype, block_size, scaling_mode, pack_fp6
+        )
+        return MXTensor(
+            scale_e8m0_biased,
+            data_lp,
             elem_dtype,
             block_size,
-            scaling_mode,
+            data_hp.dtype,
             use_fp4_custom_triton_dequant_kernel,
             gemm_kernel_choice,
             pack_fp6,
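Based on the removed test_cast_autograd above and the new to_mx/to_dtype bodies in this diff, a rough sketch of the call flow after this change looks like the following; the import path matches the file edited here, but the element dtype and device choices are assumptions:

# Sketch only: adapted from the removed test; elem_dtype and device are assumptions.
import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor

x = torch.arange(8, device="cuda").bfloat16()
block_size = 8

# Quantize: MXTensor.to_mx now calls the module-level to_mx() and constructs
# the MXTensor directly instead of going through ToMXConstrFunc.apply.
x_mx = MXTensor.to_mx(x, torch.float8_e4m3fn, block_size)

# Dequantize: to_dtype() is now a plain function call instead of
# FromMXConstrFunc.apply, so neither direction participates in autograd.
x_dq = x_mx.to_dtype(torch.bfloat16)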