Commit f90b29e
[float8nocompile] support option to not precompute fp8 tensor for backward (#1517)
1 parent e1cb44a commit f90b29e

3 files changed: +157 -22 lines

torchao/prototype/float8nocompile/float8nocompile_linear.py

+145 -19

@@ -16,6 +16,7 @@
     ToFP8ColumnMajor,
     ToFP8ColumnMajorT,
     ToFP8RowAndColumnMajor,
+    ToFP8RowMajor,
     ToFP8RowMajorTAndNonT,
 )
 from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
@@ -36,47 +37,54 @@ def __init__(self, *args, **kwargs):
         Additional arguments on top of `torch.nn.Linear`'s arguments:
         * `config`: Float8LinearConfig
         """
-        config = kwargs.pop("config")
-        kernel_algo = kwargs.pop("kernel_algo")
-        emulate = config.emulate
+        self.config = kwargs.pop("config")
+        self.kernel_algo = kwargs.pop("kernel_algo")
+        self.no_precompute_for_backward = kwargs.pop(
+            "no_precompute_for_backward", False
+        )
         super().__init__(*args, **kwargs)

-        self.config = config
-        self.kernel_algo = kernel_algo
-
         self.linear_mm_config = LinearMMConfig(
             # output
             ScaledMMConfig(
-                emulate,
+                self.config.emulate,
                 self.config.gemm_config_output.use_fast_accum,
                 False,
                 self.config.pad_inner_dim,
             ),
             # grad_input
             ScaledMMConfig(
-                emulate,
+                self.config.emulate,
                 self.config.gemm_config_grad_input.use_fast_accum,
                 False,
                 self.config.pad_inner_dim,
             ),
             # grad_weight
             ScaledMMConfig(
-                emulate,
+                self.config.emulate,
                 self.config.gemm_config_grad_weight.use_fast_accum,
                 False,
                 self.config.pad_inner_dim,
             ),
         )

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        # TODO(danielvegamyhre): support for FSDP once dependencies are implemented
-        output = matmul_with_args_in_hp.apply(
-            input,
-            self.weight,
-            self.config,
-            self.linear_mm_config,
-            self.kernel_algo,
-        )
+        if self.no_precompute_for_backward:
+            output = matmul_with_args_in_hp_no_precompute_for_backward.apply(
+                input,
+                self.weight,
+                self.config,
+                self.linear_mm_config,
+                self.kernel_algo,
+            )
+        else:
+            output = matmul_with_args_in_hp.apply(
+                input,
+                self.weight,
+                self.config,
+                self.linear_mm_config,
+                self.kernel_algo,
+            )
         return output

     @classmethod
@@ -85,6 +93,7 @@ def from_float(
         mod,
         config: Float8LinearConfig,  # only default config is supported, non-defaults silently ignored
         kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
+        no_precompute_for_backward: bool = False,
     ):
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
@@ -101,6 +110,7 @@ def from_float(
             bias=False,
            config=config,
             kernel_algo=kernel_algo,
+            no_precompute_for_backward=no_precompute_for_backward,
         )
         new_mod.weight = mod.weight
         new_mod.bias = mod.bias
@@ -110,8 +120,20 @@


 class matmul_with_args_in_hp(torch.autograd.Function):
+    """FP8 matmul with args in high precision to be used in a region without AC.
+    FP8 tensors only needed for backward are computed as part of kernels in the forward pass,
+    to reduce number of kernel dispatches and increase throughput, at the cost of higher
+    peak memory usage."""
+
     @staticmethod
-    def forward(ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo):
+    def forward(
+        ctx,
+        input_hp: torch.Tensor,
+        weight_hp: torch.Tensor,
+        config: Float8LinearConfig,
+        linear_mm_config: LinearMMConfig,
+        kernel_algo: KernelAlgorithm,
+    ):
         # reshape to be 2D for triton kernels
         orig_input_shape = input_hp.shape
         input_hp = input_hp.reshape(-1, input_hp.shape[-1])
@@ -138,6 +160,7 @@ def forward(ctx, input_hp, weight_hp, config, linear_mm_config, kernel_algo):
         ctx.config = config
         ctx.linear_mm_config = linear_mm_config
         ctx.kernel_algo = kernel_algo
+        ctx.no_precompute_for_backward = False

         # reshape back to expected dims
         output = output.reshape(*orig_input_shape[:-1], output.shape[-1])
@@ -178,15 +201,118 @@ def backward(ctx, grad_output):
         )
         grad_input = torch.mm(grad_output_fp8_row_major, weight_fp8_col_major)

+        # reshape grad input to match original shape
+        grad_input = grad_input.reshape(
+            *orig_grad_output_shape[:-1], grad_input.shape[-1]
+        )
+
         # grad_weight = grad_output_t @ input
         # apparently this variant is slightly faster than `grad_weight_t = input_t @ grad_output`
         # source: https://github.com/pytorch/ao/blob/fe5f11b2c58b452e01ba9ec7359629928b143619/torchao/float8/float8_linear.py#L84-L85
         grad_weight = torch.mm(grad_output_t_row_major, input_fp8_col_major)

+        # grad input shape
+        return grad_input, grad_weight, None, None, None, None
+
+
+class matmul_with_args_in_hp_no_precompute_for_backward(torch.autograd.Function):
+    """FP8 matmul with args in high precision to be used in a region with AC.
+    FP8 tensors only needed for backward are only computed in the backward pass
+    when needed, to reduce peak memory usage."""
+
+    @staticmethod
+    def forward(
+        ctx,
+        input_hp: torch.Tensor,
+        weight_hp: torch.Tensor,
+        config: Float8LinearConfig,
+        linear_mm_config: LinearMMConfig,
+        kernel_algo: KernelAlgorithm,
+    ):
+        # reshape to be 2D for triton kernels
+        orig_input_shape = input_hp.shape
+        input_hp = input_hp.reshape(-1, input_hp.shape[-1])
+
+        # output = input @ weight_t
+        input_fp8_row_major = ToFP8RowMajor.apply(
+            input_hp,
+            config.cast_config_input.target_dtype,
+            linear_mm_config,
+            GemmInputRole.INPUT,
+            kernel_algo,
+        )
+        weight_t_fp8_col_major = ToFP8ColumnMajorT.apply(
+            weight_hp,
+            config.cast_config_weight.target_dtype,
+            linear_mm_config,
+            GemmInputRole.WEIGHT,
+            kernel_algo,
+        )
+        output = torch.mm(input_fp8_row_major, weight_t_fp8_col_major)
+
+        # with AC we only will save the original hp input tensor and weight for backward,
+        # and do the necessary fp8 conversions during the backward pass.
+        ctx.save_for_backward(input_hp, weight_hp)
+        ctx.config = config
+        ctx.linear_mm_config = linear_mm_config
+        ctx.kernel_algo = kernel_algo
+        ctx.no_precompute_for_backward = True
+
+        # reshape back to expected dims
+        output = output.reshape(*orig_input_shape[:-1], output.shape[-1])
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # grad_output may not be contiguous in cases like:
+        # output.sum().backward() where grad is all 1s, so the (M,N) view of the scalar "1"
+        # results in a non-contiguous tensor with stride (0,0).
+        if not grad_output.is_contiguous():
+            grad_output = grad_output.contiguous()
+
+        input_hp, weight_hp = ctx.saved_tensors
+
+        # reshape to be 2D for triton kernels
+        orig_grad_output_shape = grad_output.shape
+        grad_output = grad_output.reshape(-1, grad_output.shape[-1])
+
+        # cast grad output to float8_e5m2 for backward
+        grad_output_fp8_row_major, grad_output_t_row_major = (
+            ToFP8RowMajorTAndNonT.apply(
+                grad_output,
+                ctx.config.cast_config_grad_output.target_dtype,
+                ctx.linear_mm_config,
+                GemmInputRole.GRAD_OUTPUT,
+                ctx.kernel_algo,
+            )
+        )
+
+        # grad_input = grad_output @ weight
+        weight_fp8_col_major = ToFP8ColumnMajor.apply(
+            weight_hp,
+            ctx.config.cast_config_weight.target_dtype,
+            ctx.linear_mm_config,
+            GemmInputRole.WEIGHT,
+            ctx.kernel_algo,
+        )
+        grad_input = torch.mm(grad_output_fp8_row_major, weight_fp8_col_major)
+
         # reshape grad input to match original shape
         grad_input = grad_input.reshape(
             *orig_grad_output_shape[:-1], grad_input.shape[-1]
         )

+        # grad_weight = grad_output_t @ input
+        # apparently this variant is slightly faster than `grad_weight_t = input_t @ grad_output`
+        # source: https://github.com/pytorch/ao/blob/fe5f11b2c58b452e01ba9ec7359629928b143619/torchao/float8/float8_linear.py#L84-L85
+        input_fp8_col_major = ToFP8ColumnMajor.apply(
+            input_hp,
+            ctx.config.cast_config_input.target_dtype,
+            ctx.linear_mm_config,
+            GemmInputRole.INPUT,
+            ctx.kernel_algo,
+        )
+        grad_weight = torch.mm(grad_output_t_row_major, input_fp8_col_major)
+
         # grad input shape
-        return grad_input, grad_weight, None, None, None
+        return grad_input, grad_weight, None, None, None, None

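For orientation, here is a minimal sketch of how the new flag reaches a single module through Float8LinearNoCompile.from_float. The import path for Float8LinearConfig and the toy layer sizes are assumptions for illustration, not taken from this diff:

import torch

from torchao.float8.config import Float8LinearConfig
from torchao.prototype.float8nocompile.float8nocompile_linear import (
    Float8LinearNoCompile,
)

# a plain high-precision linear layer to convert (sizes are illustrative)
linear = torch.nn.Linear(32, 64, bias=False, device="cuda", dtype=torch.bfloat16)

# with no_precompute_for_backward=True, forward() routes through
# matmul_with_args_in_hp_no_precompute_for_backward, which saves the
# high-precision input/weight and defers the fp8 casts needed only for
# the backward GEMMs until backward() runs, lowering peak memory.
fp8_linear = Float8LinearNoCompile.from_float(
    linear,
    config=Float8LinearConfig(),
    no_precompute_for_backward=True,
)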
torchao/prototype/float8nocompile/float8nocompile_linear_utils.py

+5 -1

@@ -27,6 +27,7 @@ def convert_to_float8_nocompile_training(
     config: Float8LinearConfig = None,
     module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
     kernel_algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
+    no_precompute_for_backward: bool = False,
 ) -> nn.Module:
     """
     Swaps `torch.nn.Linear` in `module` with `Float8LinearNoCompile`.
@@ -45,7 +46,10 @@ def convert_to_float8_nocompile_training(
         config = Float8LinearConfig()

     from_float = lambda m: Float8LinearNoCompile.from_float(
-        m, config=config, kernel_algo=kernel_algo
+        m,
+        config=config,
+        kernel_algo=kernel_algo,
+        no_precompute_for_backward=no_precompute_for_backward,
     )
     return swap_linear_layers(
         module,

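Putting the pieces together, a minimal usage sketch of the updated helper, assuming a CUDA device and a bfloat16 model along the lines of the fixtures in the test below; the two-layer Sequential and the tensor shapes are illustrative only:

import torch

from torchao.prototype.float8nocompile.float8nocompile_linear_utils import (
    convert_to_float8_nocompile_training,
)

model = torch.nn.Sequential(
    torch.nn.Linear(32, 64, bias=False),
    torch.nn.Linear(64, 32, bias=False),
).to(device="cuda", dtype=torch.bfloat16)

# swap eligible nn.Linear layers in place; with the new flag, the fp8 casts
# needed only for backward are deferred to the backward pass (lower peak
# memory at the cost of a few extra kernel dispatches), which pairs well
# with activation checkpointing.
convert_to_float8_nocompile_training(model, no_precompute_for_backward=True)

x = torch.randn(16, 32, device="cuda", dtype=torch.bfloat16, requires_grad=True)
model(x).sum().backward()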
torchao/prototype/float8nocompile/test/train_test.py

+7 -2

@@ -39,7 +39,10 @@ def model2():
 @pytest.mark.parametrize(
     "input_shape", [(16, 32), (1, 16, 32), (2, 16, 32), (128, 8192, 32)]
 )
-def test_model_weights_and_gradients(model1, model2, input_shape: tuple[int, int]):
+@pytest.mark.parametrize("no_precompute_for_backward", [True, False])
+def test_model_weights_and_gradients(
+    model1, model2, input_shape: tuple[int, int], no_precompute_for_backward: bool
+):
     assert torch.cuda.is_available()
     device = torch.device("cuda")

@@ -48,7 +51,9 @@ def test_model_weights_and_gradients(model1, model2, input_shape: tuple[int, int]):

     # compare production float8 linear conversion with no-compile version
     convert_to_float8_training(model2)
-    convert_to_float8_nocompile_training(model1)
+    convert_to_float8_nocompile_training(
+        model1, no_precompute_for_backward=no_precompute_for_backward
+    )

     input_tensor = torch.randn(
         *input_shape, requires_grad=True, dtype=torch.bfloat16, device=device

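The rest of the test body is unchanged and therefore not shown in these hunks. For context only, here is a hedged sketch of the kind of weight/gradient parity check such a test performs; the helper name, structure, and default tolerances below are hypothetical and not copied from train_test.py:

import torch

def check_parity(model_nocompile, model_prod, input_tensor):
    # hypothetical helper: run the same input through both converted models,
    # backprop, and compare outputs and gradients.
    x1 = input_tensor.clone().detach().requires_grad_(True)
    x2 = input_tensor.clone().detach().requires_grad_(True)

    out1 = model_nocompile(x1)
    out2 = model_prod(x2)
    out1.sum().backward()
    out2.sum().backward()

    assert torch.allclose(out1, out2)
    assert torch.allclose(x1.grad, x2.grad)
    for p1, p2 in zip(model_nocompile.parameters(), model_prod.parameters()):
        assert torch.allclose(p1.grad, p2.grad)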