@@ -21,7 +21,6 @@

 import torch

-import torchao.prototype.mx_formats.config as config
 from torchao.prototype.mx_formats.constants import (
     BLOCK_SIZE_DEFAULT,
     DTYPE_FP4,
@@ -239,7 +238,14 @@ def get_fp_scale(scale_e8m0):
     return s_fp


-def to_dtype(data_lp, scale_e8m0, elem_dtype, block_size, target_dtype):
+def to_dtype(
+    data_lp,
+    scale_e8m0,
+    elem_dtype,
+    block_size,
+    target_dtype,
+    use_fp4_custom_triton_dequant_kernel,
+):
     orig_shape = data_lp.shape
     is_transposed = not data_lp.is_contiguous()
     # if the underlying data is transposed, convert to row major before
@@ -258,7 +264,7 @@ def to_dtype(data_lp, scale_e8m0, elem_dtype, block_size, target_dtype):
         data_hp = f6_e3m2_unpacked_to_f32(data_lp)
         data_hp = data_hp.to(target_dtype)
     elif elem_dtype == DTYPE_FP4:
-        if config.use_fp4_custom_triton_dequant_kernel:
+        if use_fp4_custom_triton_dequant_kernel:
             data_hp_rescaled = triton_f4_to_scaled_bf16(
                 data_lp,
                 scale_e8m0,
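The two hunks above thread the fp4 kernel choice into `to_dtype` as an explicit argument instead of reading the module-level `config` object at call time. A self-contained sketch of that pattern, using hypothetical helper names rather than torchao code:

```python
import torch

def _dequant_reference(data_lp: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # stand-in for the plain PyTorch unpack-and-scale path
    return data_lp.to(torch.float32) * scale

def _dequant_custom(data_lp: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # stand-in for a fused custom-kernel path
    return torch.mul(data_lp.to(torch.float32), scale)

def to_dtype_sketch(data_lp, scale, target_dtype, use_custom_kernel: bool):
    # kernel selection is a plain function argument, not global state
    dequant = _dequant_custom if use_custom_kernel else _dequant_reference
    return dequant(data_lp, scale).to(target_dtype)

y = to_dtype_sketch(
    torch.arange(8, dtype=torch.uint8),
    torch.full((8,), 0.5),
    torch.bfloat16,
    use_custom_kernel=True,
)
```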
@@ -318,17 +324,29 @@ class ToMXConstrFunc(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx, data_hp, elem_dtype, block_size, scaling_mode):
+    def forward(
+        ctx,
+        data_hp,
+        elem_dtype,
+        block_size,
+        scaling_mode,
+        use_fp4_custom_triton_dequant_kernel,
+    ):
         scale_e8m0_biased, data_lp = to_mx(
             data_hp, elem_dtype, block_size, scaling_mode
         )
         return MXTensor(
-            scale_e8m0_biased, data_lp, elem_dtype, block_size, data_hp.dtype
+            scale_e8m0_biased,
+            data_lp,
+            elem_dtype,
+            block_size,
+            data_hp.dtype,
+            use_fp4_custom_triton_dequant_kernel,
         )

     @staticmethod
     def backward(ctx, g):
-        return g, None, None, None
+        return g, None, None, None, None


 @torch._dynamo.allow_in_graph
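Because `forward` gained one more non-tensor argument, `backward` has to return one more `None`: autograd expects exactly one gradient slot per `forward` input (after `ctx`), with `None` for non-differentiable arguments. A minimal standalone illustration (toy example, not torchao code):

```python
import torch

class ScaleByFactor(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor, use_fast_path):
        # two non-tensor arguments alongside the tensor input
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, g):
        # one slot per forward input (x, factor, use_fast_path):
        # a real gradient for x, None for the non-tensor arguments
        return g * ctx.factor, None, None

x = torch.randn(4, requires_grad=True)
ScaleByFactor.apply(x, 2.0, True).sum().backward()
print(x.grad)  # each entry equals 2.0
```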
@@ -345,6 +363,7 @@ def forward(ctx, tensor_lp, target_dtype):
             tensor_lp._elem_dtype,
             tensor_lp._block_size,
             target_dtype,
+            tensor_lp._use_fp4_custom_triton_dequant_kernel,
         )

     @staticmethod
@@ -360,6 +379,7 @@ def __new__(
         elem_dtype,
         block_size,
         orig_dtype,
+        use_fp4_custom_triton_dequant_kernel,
     ):
         new_size = data_bits.size()
         if elem_dtype == DTYPE_FP4:
@@ -417,6 +437,9 @@ def __new__(
         self._elem_dtype = elem_dtype
         self._block_size = block_size
         self._orig_dtype = orig_dtype
+        self._use_fp4_custom_triton_dequant_kernel = (
+            use_fp4_custom_triton_dequant_kernel
+        )
         return self

     def __repr__(self):
@@ -443,14 +466,22 @@ def to_mx(
         elem_dtype: Union[torch.dtype, str],
         block_size: int = BLOCK_SIZE_DEFAULT,
         scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR,
+        use_fp4_custom_triton_dequant_kernel: bool = False,
     ):
-        return ToMXConstrFunc.apply(data_hp, elem_dtype, block_size, scaling_mode)
+        return ToMXConstrFunc.apply(
+            data_hp,
+            elem_dtype,
+            block_size,
+            scaling_mode,
+            use_fp4_custom_triton_dequant_kernel,
+        )

     def __tensor_flatten__(self):
         ctx = {
             "_elem_dtype": self._elem_dtype,
             "_block_size": self._block_size,
             "_orig_dtype": self._orig_dtype,
+            "_use_fp4_custom_triton_dequant_kernel": self._use_fp4_custom_triton_dequant_kernel,
         }
         return ["_scale_e8m0", "_data"], ctx

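With the flag surfaced on `MXTensor.to_mx`, callers choose the dequant kernel per tensor. A hedged usage sketch, assuming the prototype module path shown in this diff, that `MXTensor.to_dtype` is the dequant entry point (as the `FromMXConstrFunc` hunk suggests), and a CUDA device with Triton available for the custom fp4 kernel:

```python
import torch
from torchao.prototype.mx_formats.constants import DTYPE_FP4
from torchao.prototype.mx_formats.mx_tensor import MXTensor

x = torch.randn(32, 32, dtype=torch.bfloat16, device="cuda")
x_mx = MXTensor.to_mx(
    x,
    DTYPE_FP4,
    block_size=32,
    use_fp4_custom_triton_dequant_kernel=True,  # now travels with the tensor
)
x_hp = x_mx.to_dtype(torch.bfloat16)  # dequantizes via the custom fp4 kernel
```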
@@ -467,6 +498,7 @@ def __tensor_unflatten__(
             metadata["_elem_dtype"],
             metadata["_block_size"],
             metadata["_orig_dtype"],
+            metadata["_use_fp4_custom_triton_dequant_kernel"],
         )

     # Do not force the MXTensor type on the returned tensor
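Storing the flag in the `__tensor_flatten__` metadata and reading it back in `__tensor_unflatten__` lets it survive the tensor-subclass flatten/unflatten round trip. A sketch of that round trip, assuming the standard `(inner_tensors, metadata, outer_size, outer_stride)` unflatten signature and a CUDA device for the MX cast:

```python
import torch
from torchao.prototype.mx_formats.constants import DTYPE_FP4
from torchao.prototype.mx_formats.mx_tensor import MXTensor

x_mx = MXTensor.to_mx(
    torch.randn(32, 32, dtype=torch.bfloat16, device="cuda"),
    DTYPE_FP4,
    use_fp4_custom_triton_dequant_kernel=True,
)
inner_names, ctx = x_mx.__tensor_flatten__()       # (["_scale_e8m0", "_data"], metadata)
inner = {name: getattr(x_mx, name) for name in inner_names}
# outer_size/outer_stride are not used by the construction shown in the hunk above
x_mx2 = MXTensor.__tensor_unflatten__(inner, ctx, x_mx.size(), None)
assert x_mx2._use_fp4_custom_triton_dequant_kernel
```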