Commit 3fd7a06

Update
[ghstack-poisoned]
1 parent 1c1ad9c

File tree

1 file changed: +1 -16 lines changed


torchao/float8/float8_linear.py

Lines changed: 1 addition & 16 deletions

@@ -11,7 +11,7 @@
 
 import torch
 
-from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
+from torchao.float8.config import Float8LinearConfig, ScalingType
 from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
 from torchao.float8.float8_scaling_utils import (
     get_maybe_axiswise_dim,
@@ -128,21 +128,6 @@ def backward(ctx, grad_output):
         elif c.cast_config_weight_for_grad_input.scaling_type is ScalingType.DISABLED:
             weight_t_maybe_fp8_dim0 = weight_hp_t
         else:
-            if (
-                c.cast_config_weight_for_grad_input.scaling_granularity
-                is ScalingGranularity.AXISWISE
-            ):
-                # workaround from https://github.com/pytorch/pytorch/issues/141881
-                # to avoid saving float8 weight from forward to backward when
-                # FSDP is on: add a fake dependency on `grad_output`.
-                g_reshaped = grad_output.reshape(-1, grad_output.shape[-1]) * 0
-                zero = g_reshaped[:1] * 0
-                weight_hp_t = weight_hp_t + zero
-
-            # Note: we need https://github.com/pytorch/pytorch/issues/136267
-            # to be solved to have a chance to reuse max(abs(weight, dim=...))
-            # from the forward to get max(abs(weight)) here without reading
-            # the entire tensor.
             weight_t_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic(
                 weight_hp_t,
                 c.cast_config_weight_for_grad_input.target_dtype,
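For context on the deletion: the removed branch implemented the workaround from https://github.com/pytorch/pytorch/issues/141881, building a zero-valued tensor from `grad_output` and adding it to the high-precision weight, so that the float8 weight did not have to be saved from forward to backward when FSDP is on. A minimal standalone sketch of that trick follows; the helper name is illustrative, not torchao API.

import torch

# Sketch of the fake-dependency trick removed by this commit; the function
# name is hypothetical and not part of torchao.
def add_fake_dependency(weight_hp_t: torch.Tensor, grad_output: torch.Tensor) -> torch.Tensor:
    # A zero tensor derived from `grad_output` keeps a dependency on it
    # while leaving the values of `weight_hp_t` unchanged.
    g_reshaped = grad_output.reshape(-1, grad_output.shape[-1]) * 0
    zero = g_reshaped[:1] * 0  # shape (1, N), all zeros
    return weight_hp_t + zero  # broadcasts over rows; numerically a no-op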

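The retained path still casts the weight on the fly via hp_tensor_to_float8_dynamic. Conceptually, dynamic float8 casting computes a scale from the tensor's current amax and quantizes; the sketch below is a simplified tensorwise illustration of that idea, not the actual torchao signature (the real function also handles axiswise granularity and mm configuration).

import torch

# Simplified tensorwise dynamic float8 cast; an illustration only.
def cast_to_float8_dynamic(t: torch.Tensor, float8_dtype=torch.float8_e4m3fn):
    fp8_max = torch.finfo(float8_dtype).max
    amax = t.abs().max().clamp(min=1e-12)  # avoid division by zero
    scale = fp8_max / amax
    t_fp8 = (t * scale).clamp(-fp8_max, fp8_max).to(float8_dtype)
    return t_fp8, scale  # the caller undoes `scale` after the fp8 matmul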