     ToFP8ColumnMajor,
     ToFP8ColumnMajorT,
     ToFP8RowAndColumnMajor,
+    ToFP8RowMajor,
     ToFP8RowMajorTAndNonT,
 )
 from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
@@ -36,47 +37,54 @@ def __init__(self, *args, **kwargs):
         Additional arguments on top of `torch.nn.Linear`'s arguments:
         * `config`: Float8LinearConfig
         """
-        config = kwargs.pop("config")
-        kernel_algo = kwargs.pop("kernel_algo")
-        emulate = config.emulate
+        self.config = kwargs.pop("config")
+        self.kernel_algo = kwargs.pop("kernel_algo")
+        self.no_precompute_for_backward = kwargs.pop(
+            "no_precompute_for_backward", False
+        )
         super().__init__(*args, **kwargs)
 
-        self.config = config
-        self.kernel_algo = kernel_algo
-
         self.linear_mm_config = LinearMMConfig(
             # output
             ScaledMMConfig(
-                emulate,
+                self.config.emulate,
                 self.config.gemm_config_output.use_fast_accum,
                 False,
                 self.config.pad_inner_dim,
             ),
             # grad_input
             ScaledMMConfig(
-                emulate,
+                self.config.emulate,
                 self.config.gemm_config_grad_input.use_fast_accum,
                 False,
                 self.config.pad_inner_dim,
             ),
             # grad_weight
             ScaledMMConfig(
-                emulate,
+                self.config.emulate,
                 self.config.gemm_config_grad_weight.use_fast_accum,
                 False,
                 self.config.pad_inner_dim,
             ),
         )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        # TODO(danielvegamyhre): support for FSDP once dependencies are implemented
-        output = matmul_with_args_in_hp.apply(
-            input,
-            self.weight,
-            self.config,
-            self.linear_mm_config,
-            self.kernel_algo,
-        )
+        if self.no_precompute_for_backward:
+            output = matmul_with_args_in_hp_no_precompute_for_backward.apply(
+                input,
+                self.weight,
+                self.config,
+                self.linear_mm_config,
+                self.kernel_algo,
+            )
+        else:
+            output = matmul_with_args_in_hp.apply(
+                input,
+                self.weight,
+                self.config,
+                self.linear_mm_config,
+                self.kernel_algo,
+            )
         return output
 
     @classmethod
@@ -85,6 +93,7 @@ def from_float(
         mod,
         config: Float8LinearConfig,  # only default config is supported, non-defaults silently ignored
         kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
+        no_precompute_for_backward: bool = False,
     ):
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
@@ -101,6 +110,7 @@ def from_float(
             bias=False,
             config=config,
             kernel_algo=kernel_algo,
+            no_precompute_for_backward=no_precompute_for_backward,
         )
         new_mod.weight = mod.weight
         new_mod.bias = mod.bias
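For orientation, here is a minimal conversion sketch against the updated `from_float` signature. It is an illustration only, not part of the PR: the import path for `Float8LinearNoCompile` is assumed from the repo layout, an fp8-capable GPU is required, and only the default `Float8LinearConfig` is supported, per the inline comment above.

```python
import torch

from torchao.float8.config import Float8LinearConfig
from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import KernelAlgorithm
# assumed module path for the class defined in this file
from torchao.prototype.float8nocompile.float8nocompile_linear import Float8LinearNoCompile

lin = torch.nn.Linear(4096, 4096, bias=False, device="cuda", dtype=torch.bfloat16)
fp8_lin = Float8LinearNoCompile.from_float(
    lin,
    config=Float8LinearConfig(),
    kernel_algo=KernelAlgorithm.ATOMIC_MAX,
    no_precompute_for_backward=True,  # new flag: defer fp8 casts needed for backward
)

x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
fp8_lin(x).sum().backward()
```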
@@ -110,8 +120,20 @@ def from_float(
 
 
 class matmul_with_args_in_hp(torch.autograd.Function):
+    """FP8 matmul with args in high precision, to be used in a region without AC (activation checkpointing).
+    FP8 tensors needed only for backward are computed as part of the kernels in the forward pass,
+    to reduce the number of kernel dispatches and increase throughput, at the cost of higher
+    peak memory usage."""
+
     @staticmethod
-    def forward(ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo):
+    def forward(
+        ctx,
+        input_hp: torch.Tensor,
+        weight_hp: torch.Tensor,
+        config: Float8LinearConfig,
+        linear_mm_config: LinearMMConfig,
+        kernel_algo: KernelAlgorithm,
+    ):
         # reshape to be 2D for triton kernels
         orig_input_shape = input_hp.shape
         input_hp = input_hp.reshape(-1, input_hp.shape[-1])
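The docstring above (and the one on the new class later in this diff) describes a precompute-in-forward versus recompute-in-backward trade-off. As a rough, self-contained illustration of that choice only, not code from this PR, and with `float16` casts standing in for the fp8 Triton kernels, the two strategies look like this:

```python
import torch


class ToyMatmulPrecomputeForBackward(torch.autograd.Function):
    """Cast the operands backward will need during forward: fewer casts at backward
    time, but the low-precision copies stay alive until backward runs."""

    @staticmethod
    def forward(ctx, x, w):
        x_lp = x.to(torch.float16)
        w_lp = w.to(torch.float16)
        # precompute the layouts backward will need and keep them alive
        ctx.save_for_backward(x_lp.t().contiguous(), w_lp)
        return torch.mm(x_lp, w_lp.t()).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_out):
        x_lp_t, w_lp = ctx.saved_tensors
        g_lp = grad_out.to(torch.float16)
        grad_x = torch.mm(g_lp, w_lp).to(grad_out.dtype)            # dL/dx = g @ w
        grad_w = torch.mm(g_lp.t(), x_lp_t.t()).to(grad_out.dtype)  # dL/dw = g^T @ x
        return grad_x, grad_w


class ToyMatmulRecomputeInBackward(torch.autograd.Function):
    """Save only the high-precision operands and cast again in backward: extra casts,
    but lower peak memory when used under activation checkpointing."""

    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return torch.mm(x.to(torch.float16), w.to(torch.float16).t()).to(x.dtype)

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        g_lp = grad_out.to(torch.float16)
        grad_x = torch.mm(g_lp, w.to(torch.float16)).to(grad_out.dtype)
        grad_w = torch.mm(g_lp.t(), x.to(torch.float16)).to(grad_out.dtype)
        return grad_x, grad_w
```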
@@ -138,6 +160,7 @@ def forward(ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo):
         ctx.config = config
         ctx.linear_mm_config = linear_mm_config
         ctx.kernel_algo = kernel_algo
+        ctx.no_precompute_for_backward = False
 
         # reshape back to expected dims
         output = output.reshape(*orig_input_shape[:-1], output.shape[-1])
@@ -178,15 +201,118 @@ def backward(ctx, grad_output):
         )
         grad_input = torch.mm(grad_output_fp8_row_major, weight_fp8_col_major)
 
+        # reshape grad input to match original shape
+        grad_input = grad_input.reshape(
+            *orig_grad_output_shape[:-1], grad_input.shape[-1]
+        )
+
         # grad_weight = grad_output_t @ input
         # apparently this variant is slightly faster than `grad_weight_t = input_t @ grad_output`
         # source: https://github.com/pytorch/ao/blob/fe5f11b2c58b452e01ba9ec7359629928b143619/torchao/float8/float8_linear.py#L84-L85
         grad_weight = torch.mm(grad_output_t_row_major, input_fp8_col_major)
 
+        # grad input shape
+        return grad_input, grad_weight, None, None, None, None
+
+
+class matmul_with_args_in_hp_no_precompute_for_backward(torch.autograd.Function):
+    """FP8 matmul with args in high precision, to be used in a region with AC (activation checkpointing).
+    FP8 tensors needed only for backward are computed on demand in the backward pass,
+    to reduce peak memory usage."""
+
+    @staticmethod
+    def forward(
+        ctx,
+        input_hp: torch.Tensor,
+        weight_hp: torch.Tensor,
+        config: Float8LinearConfig,
+        linear_mm_config: LinearMMConfig,
+        kernel_algo: KernelAlgorithm,
+    ):
+        # reshape to be 2D for triton kernels
+        orig_input_shape = input_hp.shape
+        input_hp = input_hp.reshape(-1, input_hp.shape[-1])
+
+        # output = input @ weight_t
+        input_fp8_row_major = ToFP8RowMajor.apply(
+            input_hp,
+            config.cast_config_input.target_dtype,
+            linear_mm_config,
+            GemmInputRole.INPUT,
+            kernel_algo,
+        )
+        weight_t_fp8_col_major = ToFP8ColumnMajorT.apply(
+            weight_hp,
+            config.cast_config_weight.target_dtype,
+            linear_mm_config,
+            GemmInputRole.WEIGHT,
+            kernel_algo,
+        )
+        output = torch.mm(input_fp8_row_major, weight_t_fp8_col_major)
+
+        # with AC we will only save the original hp input tensor and weight for backward,
+        # and do the necessary fp8 conversions during the backward pass.
+        ctx.save_for_backward(input_hp, weight_hp)
+        ctx.config = config
+        ctx.linear_mm_config = linear_mm_config
+        ctx.kernel_algo = kernel_algo
+        ctx.no_precompute_for_backward = True
+
+        # reshape back to expected dims
+        output = output.reshape(*orig_input_shape[:-1], output.shape[-1])
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # grad_output may not be contiguous in cases like:
+        # output.sum().backward() where grad is all 1s, so the (M,N) view of the scalar "1"
+        # results in a non-contiguous tensor with stride (0,0).
+        if not grad_output.is_contiguous():
+            grad_output = grad_output.contiguous()
+
+        input_hp, weight_hp = ctx.saved_tensors
+
+        # reshape to be 2D for triton kernels
+        orig_grad_output_shape = grad_output.shape
+        grad_output = grad_output.reshape(-1, grad_output.shape[-1])
+
+        # cast grad output to float8_e5m2 for backward
+        grad_output_fp8_row_major, grad_output_t_row_major = (
+            ToFP8RowMajorTAndNonT.apply(
+                grad_output,
+                ctx.config.cast_config_grad_output.target_dtype,
+                ctx.linear_mm_config,
+                GemmInputRole.GRAD_OUTPUT,
+                ctx.kernel_algo,
+            )
+        )
+
+        # grad_input = grad_output @ weight
+        weight_fp8_col_major = ToFP8ColumnMajor.apply(
+            weight_hp,
+            ctx.config.cast_config_weight.target_dtype,
+            ctx.linear_mm_config,
+            GemmInputRole.WEIGHT,
+            ctx.kernel_algo,
+        )
+        grad_input = torch.mm(grad_output_fp8_row_major, weight_fp8_col_major)
+
         # reshape grad input to match original shape
         grad_input = grad_input.reshape(
             *orig_grad_output_shape[:-1], grad_input.shape[-1]
         )
 
+        # grad_weight = grad_output_t @ input
+        # apparently this variant is slightly faster than `grad_weight_t = input_t @ grad_output`
+        # source: https://github.com/pytorch/ao/blob/fe5f11b2c58b452e01ba9ec7359629928b143619/torchao/float8/float8_linear.py#L84-L85
+        input_fp8_col_major = ToFP8ColumnMajor.apply(
+            input_hp,
+            ctx.config.cast_config_input.target_dtype,
+            ctx.linear_mm_config,
+            GemmInputRole.INPUT,
+            ctx.kernel_algo,
+        )
+        grad_weight = torch.mm(grad_output_t_row_major, input_fp8_col_major)
+
         # grad input shape
-        return grad_input, grad_weight, None, None, None
+        return grad_input, grad_weight, None, None, None, None
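Per the docstring on `matmul_with_args_in_hp_no_precompute_for_backward`, the `no_precompute_for_backward=True` path is intended for activation-checkpointed regions, where fp8 copies held from the original forward would be wasted memory. A rough end-to-end sketch under stated assumptions: the `convert_for_ac` helper is hypothetical, the import path for `Float8LinearNoCompile` is assumed, and an fp8-capable GPU is required.

```python
import torch
from torch.utils.checkpoint import checkpoint

from torchao.float8.config import Float8LinearConfig
# assumed module path for the class defined in this file
from torchao.prototype.float8nocompile.float8nocompile_linear import Float8LinearNoCompile


def convert_for_ac(block: torch.nn.Module) -> torch.nn.Module:
    """Hypothetical helper: swap nn.Linear children for Float8LinearNoCompile,
    deferring the fp8 casts needed for backward to the checkpointed recomputation."""
    for name, child in block.named_children():
        if isinstance(child, torch.nn.Linear):
            setattr(
                block,
                name,
                Float8LinearNoCompile.from_float(
                    child,
                    config=Float8LinearConfig(),
                    no_precompute_for_backward=True,
                ),
            )
        else:
            convert_for_ac(child)
    return block


block = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096, bias=False, device="cuda", dtype=torch.bfloat16),
    torch.nn.Linear(4096, 4096, bias=False, device="cuda", dtype=torch.bfloat16),
)
block = convert_for_ac(block)

x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
# AC region: activations (and the fp8 casts) are redone during backward
out = checkpoint(block, x, use_reentrant=False)
out.sum().backward()
```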