 
 import torch
 
-import torchao.prototype.mx_formats.config as config
 from torchao.prototype.mx_formats.constants import (
     BLOCK_SIZE_DEFAULT,
     DTYPE_FP4,
@@ -239,7 +238,14 @@ def get_fp_scale(scale_e8m0):
     return s_fp
 
 
-def to_dtype(data_lp, scale_e8m0, elem_dtype, block_size, target_dtype):
+def to_dtype(
+    data_lp,
+    scale_e8m0,
+    elem_dtype,
+    block_size,
+    target_dtype,
+    use_fp4_custom_triton_dequant_kernel,
+):
     orig_shape = data_lp.shape
     is_transposed = not data_lp.is_contiguous()
     # if the underlying data is transposed, convert to row major before
@@ -258,7 +264,7 @@ def to_dtype(data_lp, scale_e8m0, elem_dtype, block_size, target_dtype):
         data_hp = f6_e3m2_unpacked_to_f32(data_lp)
         data_hp = data_hp.to(target_dtype)
     elif elem_dtype == DTYPE_FP4:
-        if config.use_fp4_custom_triton_dequant_kernel:
+        if use_fp4_custom_triton_dequant_kernel:
             data_hp_rescaled = triton_f4_to_scaled_bf16(
                 data_lp,
                 scale_e8m0,
@@ -318,17 +324,29 @@ class ToMXConstrFunc(torch.autograd.Function):
318324 """
319325
320326 @staticmethod
321- def forward (ctx , data_hp , elem_dtype , block_size , scaling_mode ):
327+ def forward (
328+ ctx ,
329+ data_hp ,
330+ elem_dtype ,
331+ block_size ,
332+ scaling_mode ,
333+ use_fp4_custom_triton_dequant_kernel ,
334+ ):
322335 scale_e8m0_biased , data_lp = to_mx (
323336 data_hp , elem_dtype , block_size , scaling_mode
324337 )
325338 return MXTensor (
326- scale_e8m0_biased , data_lp , elem_dtype , block_size , data_hp .dtype
339+ scale_e8m0_biased ,
340+ data_lp ,
341+ elem_dtype ,
342+ block_size ,
343+ data_hp .dtype ,
344+ use_fp4_custom_triton_dequant_kernel ,
327345 )
328346
329347 @staticmethod
330348 def backward (ctx , g ):
331- return g , None , None , None
349+ return g , None , None , None , None
332350
333351
334352@torch ._dynamo .allow_in_graph
@@ -345,6 +363,7 @@ def forward(ctx, tensor_lp, target_dtype):
             tensor_lp._elem_dtype,
             tensor_lp._block_size,
             target_dtype,
+            tensor_lp._use_fp4_custom_triton_dequant_kernel,
         )
 
     @staticmethod
@@ -360,6 +379,7 @@ def __new__(
         elem_dtype,
         block_size,
         orig_dtype,
+        use_fp4_custom_triton_dequant_kernel,
     ):
         new_size = data_bits.size()
         if elem_dtype == DTYPE_FP4:
@@ -417,6 +437,9 @@ def __new__(
         self._elem_dtype = elem_dtype
         self._block_size = block_size
         self._orig_dtype = orig_dtype
+        self._use_fp4_custom_triton_dequant_kernel = (
+            use_fp4_custom_triton_dequant_kernel
+        )
         return self
 
     def __repr__(self):
@@ -443,14 +466,22 @@ def to_mx(
         elem_dtype: Union[torch.dtype, str],
         block_size: int = BLOCK_SIZE_DEFAULT,
         scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
+        use_fp4_custom_triton_dequant_kernel: bool = False,
     ):
-        return ToMXConstrFunc.apply(data_hp, elem_dtype, block_size, scaling_mode)
+        return ToMXConstrFunc.apply(
+            data_hp,
+            elem_dtype,
+            block_size,
+            scaling_mode,
+            use_fp4_custom_triton_dequant_kernel,
+        )
 
     def __tensor_flatten__(self):
         ctx = {
             "_elem_dtype": self._elem_dtype,
             "_block_size": self._block_size,
             "_orig_dtype": self._orig_dtype,
+            "_use_fp4_custom_triton_dequant_kernel": self._use_fp4_custom_triton_dequant_kernel,
         }
         return ["_scale_e8m0", "_data"], ctx
 
@@ -467,6 +498,7 @@ def __tensor_unflatten__(
467498 metadata ["_elem_dtype" ],
468499 metadata ["_block_size" ],
469500 metadata ["_orig_dtype" ],
501+ metadata ["_use_fp4_custom_triton_dequant_kernel" ],
470502 )
471503
472504 # Do not force the MXTensor type on the returned tensor
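
For reference, a minimal sketch of how the new per-tensor flag could be exercised end to end. It assumes the class lives in `torchao.prototype.mx_formats.mx_tensor`, that `MXTensor.to_dtype` dispatches through `FromMXConstrFunc` as in the hunks above, and that a CUDA device is available for the Triton fp4 dequant path; the tensor shape and `block_size=32` are illustrative choices, not part of this change.

```python
import torch

from torchao.prototype.mx_formats.constants import DTYPE_FP4
from torchao.prototype.mx_formats.mx_tensor import MXTensor  # assumed module path

# Illustrative input; the custom Triton fp4 dequant kernel targets CUDA devices.
x = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")

# The kernel choice is now an explicit per-tensor argument instead of a
# module-level config read, so two MXTensors can use different dequant paths.
x_mx_triton = MXTensor.to_mx(
    x, DTYPE_FP4, block_size=32, use_fp4_custom_triton_dequant_kernel=True
)
x_mx_eager = MXTensor.to_mx(
    x, DTYPE_FP4, block_size=32, use_fp4_custom_triton_dequant_kernel=False
)

# Dequantization picks up the flag stored on each tensor
# (self._use_fp4_custom_triton_dequant_kernel).
x_hp_triton = x_mx_triton.to_dtype(torch.bfloat16)
x_hp_eager = x_mx_eager.to_dtype(torch.bfloat16)
```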