@@ -467,65 +467,6 @@ def tensor_size_fp6x4_to_hpx3(orig_size, is_contiguous):
     return new_size
 
 
-@torch._dynamo.allow_in_graph
-class ToMXConstrFunc(torch.autograd.Function):
-    """
-    Differentiable cast to MX, no-op in backward
-    """
-
-    @staticmethod
-    def forward(
-        ctx,
-        data_hp,
-        elem_dtype,
-        block_size,
-        scaling_mode,
-        use_fp4_custom_triton_dequant_kernel,
-        gemm_kernel_choice,
-        pack_fp6,
-    ):
-        scale_e8m0_biased, data_lp = to_mx(
-            data_hp, elem_dtype, block_size, scaling_mode, pack_fp6=pack_fp6
-        )
-        return MXTensor(
-            scale_e8m0_biased,
-            data_lp,
-            elem_dtype,
-            block_size,
-            data_hp.dtype,
-            use_fp4_custom_triton_dequant_kernel,
-            gemm_kernel_choice,
-            pack_fp6,
-        )
-
-    @staticmethod
-    def backward(ctx, g):
-        return g, None, None, None, None, None, None
-
-
-@torch._dynamo.allow_in_graph
-class FromMXConstrFunc(torch.autograd.Function):
-    """
-    Differentiable cast from MX, no-op in backward
-    """
-
-    @staticmethod
-    def forward(ctx, tensor_lp, target_dtype):
-        return to_dtype(
-            tensor_lp._data,
-            tensor_lp._scale_e8m0,
-            tensor_lp._elem_dtype,
-            tensor_lp._block_size,
-            target_dtype,
-            tensor_lp._use_fp4_custom_triton_dequant_kernel,
-            tensor_lp._pack_fp6,
-        )
-
-    @staticmethod
-    def backward(ctx, g):
-        return g, None, None
-
-
 class MXTensor(torch.Tensor):
     def __new__(
         cls,
@@ -627,7 +568,15 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
         raise NotImplementedError(f"{func} not implemented")
 
     def to_dtype(self, target_dtype):
-        return FromMXConstrFunc.apply(self, target_dtype)
+        return to_dtype(
+            self._data,
+            self._scale_e8m0,
+            self._elem_dtype,
+            self._block_size,
+            target_dtype,
+            self._use_fp4_custom_triton_dequant_kernel,
+            self._pack_fp6,
+        )
 
     @staticmethod
     @torch._dynamo.allow_in_graph
@@ -640,11 +589,15 @@ def to_mx(
         gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED,
         pack_fp6: bool = False,
     ):
-        return ToMXConstrFunc.apply(
-            data_hp,
+        scale_e8m0_biased, data_lp = to_mx(
+            data_hp, elem_dtype, block_size, scaling_mode, pack_fp6
+        )
+        return MXTensor(
+            scale_e8m0_biased,
+            data_lp,
             elem_dtype,
             block_size,
-            scaling_mode,
+            data_hp.dtype,
             use_fp4_custom_triton_dequant_kernel,
             gemm_kernel_choice,
             pack_fp6,
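
Taken together, the diff deletes the ToMXConstrFunc / FromMXConstrFunc autograd.Function wrappers and inlines the casts directly into MXTensor.to_mx and MXTensor.to_dtype. A minimal round-trip sketch of the resulting call pattern is below; the import path torchao.prototype.mx_formats.mx_tensor, the torch.float8_e4m3fn element dtype, and the block size of 32 are assumptions for illustration, not taken from this diff.

import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor  # assumed module path

# High-precision input tensor.
x_hp = torch.randn(128, 128, dtype=torch.bfloat16)

# After this change, MXTensor.to_mx computes the biased e8m0 scales and the
# low-precision data and constructs the MXTensor directly, with no
# autograd.Function wrapper in between.
x_mx = MXTensor.to_mx(x_hp, torch.float8_e4m3fn, block_size=32)  # elem dtype / block size assumed

# Likewise, to_dtype now dequantizes inline via the module-level to_dtype helper.
x_back = x_mx.to_dtype(torch.bfloat16)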