Commit 2dcae71

mx: remove differentiable casts
Summary: This functionality is not used right now; removing it so we can move faster. We can add it back in the future if needed.

Test Plan:

```
pytest test/prototype/mx_formats/ -s -x
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: ea12f77
ghstack-comment-id: 2761693552
Pull Request resolved: #1978
1 parent 10c2936 commit 2dcae71

File tree

2 files changed: +16, -75 lines

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 0 additions & 12 deletions
```diff
@@ -443,18 +443,6 @@ def test_transpose(elem_dtype, fp4_triton):
     torch.testing.assert_close(tensor_mx_dq_t, tensor_mx_t_dq, atol=0, rtol=0)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
-def test_cast_autograd(elem_dtype):
-    x = torch.arange(8, device="cuda").bfloat16().requires_grad_()
-    grad = torch.arange(8, device="cuda").bfloat16() * 0.5
-    block_size = 8
-    x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
-    x_dq = x_mx.to_dtype(torch.bfloat16)
-    x_dq.backward(gradient=grad)
-    torch.testing.assert_close(grad, x.grad, atol=0, rtol=0)
-
-
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_view(elem_dtype):
```
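
The removed test exercised gradient flow through the MX quantize/dequantize round trip, relying on the no-op-backward autograd wrappers deleted in the second file below. If that behavior is needed again before it is re-added to the library, a caller could recover it with their own `torch.autograd.Function`. The sketch below is illustrative only: the import path, the `torch.float8_e4m3fn` element dtype, and the CUDA requirement are assumptions mirroring the removed test, and `MXCastSTE` is a hypothetical name, not part of this commit.

```python
import torch

# Assumed import path, derived from the file touched in this commit.
from torchao.prototype.mx_formats.mx_tensor import MXTensor


class MXCastSTE(torch.autograd.Function):
    """Quantize to MX and dequantize in forward; pass gradients through unchanged."""

    @staticmethod
    def forward(ctx, data_hp, elem_dtype, block_size):
        # autograd is disabled inside forward, so these are plain casts
        data_mx = MXTensor.to_mx(data_hp, elem_dtype, block_size)
        return data_mx.to_dtype(data_hp.dtype)

    @staticmethod
    def backward(ctx, g):
        # straight-through: gradient w.r.t. data_hp only
        return g, None, None


if torch.cuda.is_available():
    # same values as the removed test_cast_autograd
    x = torch.arange(8, device="cuda").bfloat16().requires_grad_()
    grad = torch.arange(8, device="cuda").bfloat16() * 0.5
    x_dq = MXCastSTE.apply(x, torch.float8_e4m3fn, 8)
    x_dq.backward(gradient=grad)
    torch.testing.assert_close(grad, x.grad, atol=0, rtol=0)
```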

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 16 additions & 63 deletions
```diff
@@ -467,65 +467,6 @@ def tensor_size_fp6x4_to_hpx3(orig_size, is_contiguous):
     return new_size
 
 
-@torch._dynamo.allow_in_graph
-class ToMXConstrFunc(torch.autograd.Function):
-    """
-    Differentiable cast to MX, no-op in backward
-    """
-
-    @staticmethod
-    def forward(
-        ctx,
-        data_hp,
-        elem_dtype,
-        block_size,
-        scaling_mode,
-        use_fp4_custom_triton_dequant_kernel,
-        gemm_kernel_choice,
-        pack_fp6,
-    ):
-        scale_e8m0_biased, data_lp = to_mx(
-            data_hp, elem_dtype, block_size, scaling_mode, pack_fp6=pack_fp6
-        )
-        return MXTensor(
-            scale_e8m0_biased,
-            data_lp,
-            elem_dtype,
-            block_size,
-            data_hp.dtype,
-            use_fp4_custom_triton_dequant_kernel,
-            gemm_kernel_choice,
-            pack_fp6,
-        )
-
-    @staticmethod
-    def backward(ctx, g):
-        return g, None, None, None, None, None, None
-
-
-@torch._dynamo.allow_in_graph
-class FromMXConstrFunc(torch.autograd.Function):
-    """
-    Differentiable cast from MX, no-op in backward
-    """
-
-    @staticmethod
-    def forward(ctx, tensor_lp, target_dtype):
-        return to_dtype(
-            tensor_lp._data,
-            tensor_lp._scale_e8m0,
-            tensor_lp._elem_dtype,
-            tensor_lp._block_size,
-            target_dtype,
-            tensor_lp._use_fp4_custom_triton_dequant_kernel,
-            tensor_lp._pack_fp6,
-        )
-
-    @staticmethod
-    def backward(ctx, g):
-        return g, None, None
-
-
 class MXTensor(torch.Tensor):
     def __new__(
         cls,
@@ -627,7 +568,15 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
         raise NotImplementedError(f"{func} not implemented")
 
     def to_dtype(self, target_dtype):
-        return FromMXConstrFunc.apply(self, target_dtype)
+        return to_dtype(
+            self._data,
+            self._scale_e8m0,
+            self._elem_dtype,
+            self._block_size,
+            target_dtype,
+            self._use_fp4_custom_triton_dequant_kernel,
+            self._pack_fp6,
+        )
 
     @staticmethod
     @torch._dynamo.allow_in_graph
@@ -640,11 +589,15 @@ def to_mx(
         gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED,
         pack_fp6: bool = False,
     ):
-        return ToMXConstrFunc.apply(
-            data_hp,
+        scale_e8m0_biased, data_lp = to_mx(
+            data_hp, elem_dtype, block_size, scaling_mode, pack_fp6
+        )
+        return MXTensor(
+            scale_e8m0_biased,
+            data_lp,
             elem_dtype,
             block_size,
-            scaling_mode,
+            data_hp.dtype,
             use_fp4_custom_triton_dequant_kernel,
             gemm_kernel_choice,
             pack_fp6,
```
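
With the wrappers gone, `MXTensor.to_mx` constructs the tensor directly from the module-level `to_mx` output (recording `data_hp.dtype` as the original dtype), and `MXTensor.to_dtype` calls the module-level `to_dtype` directly; neither cast attaches a custom backward. A minimal usage sketch, under the same assumptions as the earlier example (import path, element dtype, block size, and CUDA availability are illustrative, not taken from this commit):

```python
import torch

# Assumed import path, derived from the file touched above; element dtype and
# block size are illustrative choices.
from torchao.prototype.mx_formats.mx_tensor import MXTensor

if torch.cuda.is_available():
    x_hp = torch.randn(32, 32, device="cuda", dtype=torch.bfloat16)

    # Quantize: one shared e8m0 scale per block of 32 elements (block_size=32).
    x_mx = MXTensor.to_mx(x_hp, torch.float8_e4m3fn, 32)

    # Dequantize back to bfloat16; after this commit this is a plain call with
    # no custom backward attached.
    x_dq = x_mx.to_dtype(torch.bfloat16)

    # The round trip is lossy; check the worst-case element error.
    print((x_dq - x_hp).abs().max())
```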
