@@ -28,9 +28,9 @@
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
+    hp_tensor_and_scale_to_float8,
     LinearMMConfig,
     ScaledMMConfig,
-    ToFloat8ConstrFunc,
 )
 from float8_experimental.float8_utils import (
     compute_error,
@@ -66,7 +66,7 @@ def test_preserves_dtype(self) -> None:
         for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes):
             x1_hp = torch.randn(4, 4, dtype=hp_dtype)
             x1_s = tensor_to_scale(x1_hp, lp_dtype)
-            x2_lp = ToFloat8ConstrFunc.apply(x1_hp, x1_s, lp_dtype)
+            x2_lp = hp_tensor_and_scale_to_float8(x1_hp, x1_s, lp_dtype)
             x3_hp = x2_lp.to_original_precision()
             self.assertTrue(x3_hp.dtype == hp_dtype)

@@ -76,7 +76,7 @@ def test_differentiable_casts(self) -> None:
             x = torch.randn(1).requires_grad_()
             grad = torch.randn(1)
             x_s = tensor_to_scale(x, f8_dtype)
-            x_f8 = ToFloat8ConstrFunc.apply(x, x_s, f8_dtype)
+            x_f8 = hp_tensor_and_scale_to_float8(x, x_s, f8_dtype)
             x_f8_hp = x_f8.to_original_precision()
             x_f8_hp.backward(grad)
             # the gradient should be unchanged through both casts
@@ -85,7 +85,7 @@ def test_differentiable_casts(self) -> None:
     def test_split_cat(self):
         a = torch.rand(16, 16, dtype=torch.bfloat16)
         scale = tensor_to_scale(a, e4m3_dtype)
-        fp8_a = ToFloat8ConstrFunc.apply(a, scale, e4m3_dtype)
+        fp8_a = hp_tensor_and_scale_to_float8(a, scale, e4m3_dtype)

         splits = torch.split(fp8_a, 16)
         catted = torch.cat(splits, dim=0)
@@ -94,14 +94,14 @@ def test_split_cat(self):
     def test_index_put(self):
         a = torch.rand(16, dtype=torch.bfloat16)
         scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
-        fp8_a = ToFloat8ConstrFunc.apply(a, scale_a, torch.float8_e4m3fn)
+        fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, torch.float8_e4m3fn)

         index = torch.randint(0, 15, (16,), dtype=torch.long)

         b = torch.rand(16, 16, dtype=torch.bfloat16)
         scale_b = tensor_to_scale(b, torch.float8_e4m3fn)
-        fp8_b = ToFloat8ConstrFunc.apply(b, scale_a, torch.float8_e4m3fn)
-        fp8_b_bad = ToFloat8ConstrFunc.apply(b, scale_b, torch.float8_e4m3fn)
+        fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, torch.float8_e4m3fn)
+        fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, torch.float8_e4m3fn)

         with self.assertRaises(AssertionError):
             b[index] = fp8_a
@@ -112,7 +112,7 @@ def test_index_put(self):
     def test_copy_(self):
         a = torch.rand(16, dtype=torch.bfloat16)
         scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
-        fp8_a = ToFloat8ConstrFunc.apply(a, scale_a, torch.float8_e4m3fn)
+        fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, torch.float8_e4m3fn)

         b = torch.empty(16, dtype=torch.bfloat16)
         b.copy_(fp8_a)  # Should work
@@ -407,8 +407,8 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
         a_scale = tensor_to_scale(a, input_dtype).float()
         b_scale = tensor_to_scale(b, input_dtype).float()

-        a_fp8 = ToFloat8ConstrFunc.apply(a, a_scale, input_dtype)
-        b_fp8 = ToFloat8ConstrFunc.apply(b, b_scale, input_dtype)
+        a_fp8 = hp_tensor_and_scale_to_float8(a, a_scale, input_dtype)
+        b_fp8 = hp_tensor_and_scale_to_float8(b, b_scale, input_dtype)

         out_scaled_mm = addmm_float8_unwrapped(
             a_fp8._data,
@@ -447,14 +447,14 @@ def test_different_configs_error(self):
             ScaledMMConfig(True, False, False, False),
             ScaledMMConfig(True, False, False, False),
         )
-        a = ToFloat8ConstrFunc.apply(
+        a = hp_tensor_and_scale_to_float8(
             x_fp32,
             x_scale,
             fp8_dtype,
             linear_config_a,
             GemmInputRole.INPUT,
         )
-        b = ToFloat8ConstrFunc.apply(
+        b = hp_tensor_and_scale_to_float8(
             x_fp32,
             x_scale,
             fp8_dtype,
@@ -486,10 +486,10 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
         a_scale = tensor_to_scale(a, input_dtype).float()
         b_scale = tensor_to_scale(b, input_dtype).float()

-        a_fp8 = ToFloat8ConstrFunc.apply(
+        a_fp8 = hp_tensor_and_scale_to_float8(
             a, a_scale, input_dtype, None, GemmInputRole.INPUT
         )
-        b_fp8 = ToFloat8ConstrFunc.apply(
+        b_fp8 = hp_tensor_and_scale_to_float8(
             b, b_scale, input_dtype, None, GemmInputRole.WEIGHT
         )

@@ -506,14 +506,14 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
             scaled_mm_config, scaled_mm_config, scaled_mm_config
         )

-        a_fp8 = ToFloat8ConstrFunc.apply(
+        a_fp8 = hp_tensor_and_scale_to_float8(
             a,
             a_scale,
             input_dtype,
             pad_config,
             GemmInputRole.INPUT,
         )
-        b_fp8 = ToFloat8ConstrFunc.apply(
+        b_fp8 = hp_tensor_and_scale_to_float8(
             b,
             b_scale,
             input_dtype,
@@ -529,14 +529,14 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
             emulated_scaled_mm_config,
             emulated_scaled_mm_config,
         )
-        a_fp8 = ToFloat8ConstrFunc.apply(
+        a_fp8 = hp_tensor_and_scale_to_float8(
             a,
             a_scale,
             input_dtype,
             emulated_config,
             GemmInputRole.INPUT,
         )
-        b_fp8 = ToFloat8ConstrFunc.apply(
+        b_fp8 = hp_tensor_and_scale_to_float8(
             b,
             b_scale,
             input_dtype,
@@ -695,19 +695,19 @@ def test_fp8_tensor_statistics(self):

             # Overflow caused by a too large scaling factor
             s_overflow = torch.tensor(1e9)
-            fp8_overflow = ToFloat8ConstrFunc.apply(x1_hp, s_overflow, lp_dtype)
+            fp8_overflow = hp_tensor_and_scale_to_float8(x1_hp, s_overflow, lp_dtype)
             (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_overflow, lp_dtype)
             self.assertEqual((zero_cnt, max_cnt), (0, tensor_len))

             # Underflow caused by a too small scaling factor
             s_underflow = torch.tensor(1e-9)
-            fp8_underflow = ToFloat8ConstrFunc.apply(x1_hp, s_underflow, lp_dtype)
+            fp8_underflow = hp_tensor_and_scale_to_float8(x1_hp, s_underflow, lp_dtype)
             (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_underflow, lp_dtype)
             self.assertEqual((zero_cnt, max_cnt), (tensor_len, 0))

             # Both overflow and underflow
             x2_hp = torch.cat((x1_hp * 1e9, x1_hp * 1.0, x1_hp * 1e-9), 0)
-            fp8_over_underflow = ToFloat8ConstrFunc.apply(
+            fp8_over_underflow = hp_tensor_and_scale_to_float8(
                 x2_hp, torch.tensor(1.0), lp_dtype
             )
             (zero_cnt, max_cnt) = fp8_tensor_statistics(fp8_over_underflow, lp_dtype)
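For context, a minimal sketch of the call pattern after this rename, inferred only from the hunks above: the new hp_tensor_and_scale_to_float8 helper takes the same positional arguments that the old ToFloat8ConstrFunc.apply call sites passed (tensor, scale, float8 dtype, and optionally a LinearMMConfig and a GemmInputRole). The tensor_to_scale import location is assumed from the float8_utils import block above; this is illustrative, not a definitive API reference.

# Sketch assuming the positional-argument order shown in the diff hunks above.
import torch
from float8_experimental.float8_tensor import (
    GemmInputRole,
    hp_tensor_and_scale_to_float8,
)
from float8_experimental.float8_utils import tensor_to_scale  # assumed location

x = torch.randn(16, 16, dtype=torch.bfloat16)
scale = tensor_to_scale(x, torch.float8_e4m3fn)

# Before this change: ToFloat8ConstrFunc.apply(x, scale, torch.float8_e4m3fn)
x_fp8 = hp_tensor_and_scale_to_float8(x, scale, torch.float8_e4m3fn)

# Some test call sites also pass a LinearMMConfig (here None) and a GemmInputRole,
# mirroring the five-argument form used in test_pad_inner_dim above.
x_fp8_input = hp_tensor_and_scale_to_float8(
    x, scale, torch.float8_e4m3fn, None, GemmInputRole.INPUT
)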