diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py
index f61b212132..76f340dc78 100644
--- a/test/prototype/mx_formats/test_mx_tensor.py
+++ b/test/prototype/mx_formats/test_mx_tensor.py
@@ -443,18 +443,6 @@ def test_transpose(elem_dtype, fp4_triton):
     torch.testing.assert_close(tensor_mx_dq_t, tensor_mx_t_dq, atol=0, rtol=0)


-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
-def test_cast_autograd(elem_dtype):
-    x = torch.arange(8, device="cuda").bfloat16().requires_grad_()
-    grad = torch.arange(8, device="cuda").bfloat16() * 0.5
-    block_size = 8
-    x_mx = MXTensor.to_mx(x, elem_dtype, block_size)
-    x_dq = x_mx.to_dtype(torch.bfloat16)
-    x_dq.backward(gradient=grad)
-    torch.testing.assert_close(grad, x.grad, atol=0, rtol=0)
-
-
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_view(elem_dtype):
diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py
index 7f11913720..4a9ff498d5 100644
--- a/torchao/prototype/mx_formats/mx_tensor.py
+++ b/torchao/prototype/mx_formats/mx_tensor.py
@@ -467,65 +467,6 @@ def tensor_size_fp6x4_to_hpx3(orig_size, is_contiguous):
     return new_size


-@torch._dynamo.allow_in_graph
-class ToMXConstrFunc(torch.autograd.Function):
-    """
-    Differentiable cast to MX, no-op in backward
-    """
-
-    @staticmethod
-    def forward(
-        ctx,
-        data_hp,
-        elem_dtype,
-        block_size,
-        scaling_mode,
-        use_fp4_custom_triton_dequant_kernel,
-        gemm_kernel_choice,
-        pack_fp6,
-    ):
-        scale_e8m0_biased, data_lp = to_mx(
-            data_hp, elem_dtype, block_size, scaling_mode, pack_fp6=pack_fp6
-        )
-        return MXTensor(
-            scale_e8m0_biased,
-            data_lp,
-            elem_dtype,
-            block_size,
-            data_hp.dtype,
-            use_fp4_custom_triton_dequant_kernel,
-            gemm_kernel_choice,
-            pack_fp6,
-        )
-
-    @staticmethod
-    def backward(ctx, g):
-        return g, None, None, None, None, None, None
-
-
-@torch._dynamo.allow_in_graph
-class FromMXConstrFunc(torch.autograd.Function):
-    """
-    Differentiable cast from MX, no-op in backward
-    """
-
-    @staticmethod
-    def forward(ctx, tensor_lp, target_dtype):
-        return to_dtype(
-            tensor_lp._data,
-            tensor_lp._scale_e8m0,
-            tensor_lp._elem_dtype,
-            tensor_lp._block_size,
-            target_dtype,
-            tensor_lp._use_fp4_custom_triton_dequant_kernel,
-            tensor_lp._pack_fp6,
-        )
-
-    @staticmethod
-    def backward(ctx, g):
-        return g, None, None
-
-
 class MXTensor(torch.Tensor):
     def __new__(
         cls,
@@ -627,7 +568,15 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
         raise NotImplementedError(f"{func} not implemented")

     def to_dtype(self, target_dtype):
-        return FromMXConstrFunc.apply(self, target_dtype)
+        return to_dtype(
+            self._data,
+            self._scale_e8m0,
+            self._elem_dtype,
+            self._block_size,
+            target_dtype,
+            self._use_fp4_custom_triton_dequant_kernel,
+            self._pack_fp6,
+        )

     @staticmethod
     @torch._dynamo.allow_in_graph
@@ -640,11 +589,15 @@ def to_mx(
         gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED,
         pack_fp6: bool = False,
     ):
-        return ToMXConstrFunc.apply(
-            data_hp,
+        scale_e8m0_biased, data_lp = to_mx(
+            data_hp, elem_dtype, block_size, scaling_mode, pack_fp6
+        )
+        return MXTensor(
+            scale_e8m0_biased,
+            data_lp,
             elem_dtype,
             block_size,
-            scaling_mode,
+            data_hp.dtype,
             use_fp4_custom_triton_dequant_kernel,
             gemm_kernel_choice,
             pack_fp6,