     GemmInputRole,
     LinearMMConfig,
 )
+
 from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
+    hp_to_fp8_col_major,
+    hp_to_fp8_col_major_t,
+    hp_to_fp8_row_and_col_major,
+    hp_to_fp8_row_major,
+    hp_to_fp8_row_major_t,
     KernelAlgorithm,
-    triton_hp_tensor_to_float8_dynamic,
 )

-# avoid division by zero when calculating scale
-# TODO: align this value with NVIDIA's assumptions (current value is a guess)
-EPS = 1e-12

+class ToFP8RowAndColumnMajor(torch.autograd.Function):
+    """
+    A differentiable conversion to fp8.
+    * forward: convert from high precision to float8 and produces both row-major and column-major outputs
+    * backward: pass the gradient without changes
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        tensor: torch.Tensor,
+        float8_dtype: torch.dtype,
+        linear_mm_config: LinearMMConfig,
+        gemm_input_role: GemmInputRole,
+        kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
+    ):
+        fp8_row_major, fp8_col_major = hp_to_fp8_row_and_col_major(
+            tensor,
+            float8_dtype,
+            linear_mm_config,
+            gemm_input_role,
+            algo=kernel_algo,
+        )
+        return fp8_row_major, fp8_col_major

-def hp_tensor_to_float8nocompile_dynamic(
-    hp_tensor: torch.Tensor,
-    float8_dtype: torch.dtype,
-    linear_mm_config: LinearMMConfig,
-    gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
-) -> Float8Tensor:
+    @staticmethod
+    def backward(ctx, g):
+        return g, None, None, None, None
+
+
+class ToFP8RowMajor(torch.autograd.Function):
     """
-    Given a high precision tensor `hp_tensor`,
-    scales `hp_tensor` dynamically and returns a `Float8Tensor` of the result.
-
-    Args:
-        hp_tensor: the tensor to convert
-        float8_dtype: the float8 dtype to use
-        linear_mm_config: Defines the configuration for the scaled_mm for
-            the 3 fwd/bwd gemms of linear
-        gemm_input_role: Defines the role of this tensor (input, weight or grad_output) in
-            the 3 fwd/bwd gemms of linear
+    A differentiable conversion to fp8 in row-major layout.
+    * forward: convert from high precision to float8 with row-major memory layout
+    * backward: pass the gradient without changes
     """
-    # TODO(danielvegamyhre): replace this torch implementation with custom triton kernel
-    # torch.compile and eager show different numerics for 1.0 / float32,
-    # upcast to float64 to ensure same numeric between compile and eager
-    amax = torch.max(torch.abs(hp_tensor)).to(torch.float64)
-    scale = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
-    scale = scale.to(torch.float32)  # scale must be fp32
-    return _ToFloat8ConstrFunc.apply(
-        hp_tensor,
-        scale,
-        float8_dtype,
-        linear_mm_config,
-        gemm_input_role,
-        None,
-    )
-
-
-class Float8NoCompileConversionFunc(torch.autograd.Function):
+
+    @staticmethod
+    def forward(
+        ctx,
+        tensor: torch.Tensor,
+        float8_dtype: torch.dtype,
+        linear_mm_config: LinearMMConfig,
+        gemm_input_role: GemmInputRole,
+        kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
+    ):
+        fp8_row_major = hp_to_fp8_row_major(
+            tensor,
+            float8_dtype,
+            linear_mm_config,
+            gemm_input_role,
+            algo=kernel_algo,
+        )
+        return fp8_row_major
+
+    @staticmethod
+    def backward(ctx, g):
+        return g, None, None, None, None
+
+
+class ToFP8RowMajorT(torch.autograd.Function):
     """
-    A differentiable conversion to fp8.
-    * forward: convert from high precision to float8
+    A differentiable conversion to fp8 with transposed dimensions in row-major layout.
+    * forward: convert from high precision to float8 with transposed dimensions with row-major memory layout
     * backward: pass the gradient without changes
     """

@@ -76,24 +103,25 @@ def forward(
         gemm_input_role: GemmInputRole,
         kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
     ):
-        return triton_hp_tensor_to_float8_dynamic(
+        fp8_row_major_t = hp_to_fp8_row_major_t(
             tensor,
             float8_dtype,
             linear_mm_config,
             gemm_input_role,
             algo=kernel_algo,
         )
+        return fp8_row_major_t

     @staticmethod
     def backward(ctx, g):
-        return g, None, None, None, None, None
+        return g, None, None, None, None


-class NoopFwToFloat8NoCompileBwDynamic(torch.autograd.Function):
+class ToFP8ColumnMajor(torch.autograd.Function):
     """
-    A differentiable conversion to fp8.
-    * forward: no-op
-    * backward: convert to float8 with tensor-wise dynamic scaling
+    A differentiable conversion to fp8 in column-major layout.
+    * forward: convert from high precision to float8 with column-major memory layout
+    * backward: pass the gradient without changes
     """

     @staticmethod
@@ -102,20 +130,48 @@ def forward(
         tensor: torch.Tensor,
         float8_dtype: torch.dtype,
         linear_mm_config: LinearMMConfig,
+        gemm_input_role: GemmInputRole,
         kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
     ):
-        ctx.linear_mm_config = linear_mm_config
-        ctx.target_dtype = float8_dtype
-        ctx.kernel_algo = kernel_algo
-        return tensor
+        fp8_col_major = hp_to_fp8_col_major(
+            tensor,
+            float8_dtype,
+            linear_mm_config,
+            gemm_input_role,
+            algo=kernel_algo,
+        )
+        return fp8_col_major
+
+    @staticmethod
+    def backward(ctx, g):
+        return g, None, None, None, None
+
+
+class ToFP8ColumnMajorT(torch.autograd.Function):
+    """
+    A differentiable conversion to fp8 with transposed dimensions in column-major layout.
+    * forward: convert from high precision to float8 with transposed dimensions in column-major memory layout.
+    * backward: pass the gradient without changes
+    """

     @staticmethod
-    def backward(ctx, gradY):
-        fp8_tensor = triton_hp_tensor_to_float8_dynamic(
-            gradY,
-            ctx.target_dtype,
-            ctx.linear_mm_config,
-            GemmInputRole.GRAD_OUTPUT,
-            ctx.kernel_algo,
+    def forward(
+        ctx,
+        tensor: torch.Tensor,
+        float8_dtype: torch.dtype,
+        linear_mm_config: LinearMMConfig,
+        gemm_input_role: GemmInputRole,
+        kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
+    ):
+        fp8_col_major_t = hp_to_fp8_col_major_t(
+            tensor,
+            float8_dtype,
+            linear_mm_config,
+            gemm_input_role,
+            algo=kernel_algo,
         )
-        return fp8_tensor, None, None, None
+        return fp8_col_major_t
+
+    @staticmethod
+    def backward(ctx, g):
+        return g, None, None, None, None
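For reference, here is a minimal usage sketch (not part of this diff) of the new conversion functions. The import path for the autograd functions, the tensor shapes, and constructing `LinearMMConfig()` with its defaults are assumptions made for illustration; `GemmInputRole` and `torch.float8_e4m3fn` are the existing torchao/PyTorch names, and the underlying Triton kernels assume a CUDA device.

```python
import torch

from torchao.float8.float8_tensor import GemmInputRole, LinearMMConfig

# Hypothetical import path for the autograd functions defined in this file.
from torchao.prototype.float8nocompile.float8nocompile_scaling_utils import (
    ToFP8RowAndColumnMajor,
    ToFP8RowMajorT,
)

x = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda")
w = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")
config = LinearMMConfig()  # assumed to be constructible with its defaults

# Convert the input once into both layouts; both fp8 tensors come from one call
# to hp_to_fp8_row_and_col_major (kernel_algo defaults to KernelAlgorithm.ATOMIC_MAX).
x_fp8_row, x_fp8_col = ToFP8RowAndColumnMajor.apply(
    x, torch.float8_e4m3fn, config, GemmInputRole.INPUT
)

# Convert the weight with transposed dimensions in row-major layout.
w_t_fp8_row = ToFP8RowMajorT.apply(
    w, torch.float8_e4m3fn, config, GemmInputRole.WEIGHT
)

# In the backward pass, each function passes the incoming gradient through
# unchanged and returns None for the non-tensor forward arguments.
```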