Commit 3eb18e7

float8 rowwise training: add FSDP workaround (#1629)
Summary: Adds the workaround from pytorch/pytorch#141881 to the torchao float8 rowwise recipe, to reduce memory usage when FSDP is on.

Test Plan: Tested in torchtitan, training LLaMa 3 8B on 8 H100s with rowwise scaling; peak memory decreased from 67 GiB to 59 GiB.

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 48fdd31 commit 3eb18e7

File tree

1 file changed: +9 -0 lines changed


torchao/float8/float8_linear.py

@@ -159,6 +159,15 @@ def backward(ctx, grad_output):
         elif c.cast_config_weight_for_grad_input.scaling_type is ScalingType.DISABLED:
             weight_t_maybe_fp8_dim0 = weight_hp_t
         else:
+            if (
+                c.cast_config_weight_for_grad_input.scaling_granularity
+                is ScalingGranularity.AXISWISE
+            ):
+                # workaround from https://github.com/pytorch/pytorch/issues/141881
+                # to avoid saving float8 weight from forward to backward when
+                # FSDP is on
+                weight_hp_t = weight_hp_t + (grad_output_reshaped[0, 0] * 0)
+
             # Note: we need https://github.com/pytorch/pytorch/issues/136267
             # to be solved to have a chance to reuse max(abs(weight, dim=...))
             # from the forward to get max(abs(weight)) here without reading
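As context for the one-line change above, here is a minimal, self-contained sketch of the zero-multiply dependency trick inside an ordinary custom autograd Function. The ScaledMM class, the tensor shapes, and the bfloat16 round trip are illustrative assumptions, not code from this repo. Per the commit message and pytorch/pytorch#141881, making the recomputed weight depend on grad_output is what avoids saving the float8-cast weight from forward to backward when FSDP is on; in plain eager mode the extra term is simply a numerical no-op.

import torch

class ScaledMM(torch.autograd.Function):
    """Illustrative only: a matmul whose backward re-derives a lower-precision
    weight, mirroring the structure of the patched backward above."""

    @staticmethod
    def forward(ctx, input, weight):
        ctx.save_for_backward(input, weight)
        return input @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        # stand-in for the lower-precision cast used only to compute grad_input
        weight_for_grad_input = weight.to(torch.bfloat16).to(weight.dtype)
        # the workaround pattern: adds zero, so it is numerically a no-op,
        # but it creates a graph dependency on grad_output
        # (see https://github.com/pytorch/pytorch/issues/141881)
        weight_for_grad_input = weight_for_grad_input + (grad_output[0, 0] * 0)
        grad_input = grad_output @ weight_for_grad_input
        grad_weight = grad_output.t() @ input
        return grad_input, grad_weight

# usage: shapes are arbitrary, chosen only so the matmuls line up
x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=True)
ScaledMM.apply(x, w).sum().backward()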
