
Commit 9853ba2

remove outdated Float8Linear workarounds
Summary:

These workarounds are no longer needed after #2356 and the corresponding improvements in PyTorch core.

Test Plan:

torchtitan bench on llama 3 8b on 8 H100s:

before
  rowwise
    Median Tokens/Second (excluding step 1): 7013.0
    Max Memory Usage: 37.19 GiB
  gw_hp
    Median Tokens/Second (excluding step 1): 7232.0
    Max Memory Usage: 37.13 GiB

after
  rowwise
    Median Tokens/Second (excluding step 1): 6984.5
    Max Memory Usage: 37.19 GiB
  gw_hp
    Median Tokens/Second (excluding step 1): 7319.5
    Max Memory Usage: 37.13 GiB

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: ae11ea7
ghstack-comment-id: 3113561383
Pull Request resolved: #2595
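For context on the headline metric in the test plan, the following is a minimal, hypothetical sketch of how a "Median Tokens/Second (excluding step 1)" figure can be computed from per-step throughput samples. The numbers and variable names are made up; this is not torchtitan's benchmarking code.

import statistics

# Hypothetical per-step throughput samples (tokens/second). Step 1 is
# typically excluded because the first step includes warmup/compilation
# overhead and would skew the statistic.
step_tokens_per_second = [5120.0, 7011.0, 7015.0, 7013.0, 7012.0]

median_tps = statistics.median(step_tokens_per_second[1:])
print(f"Median Tokens/Second (excluding step 1): {median_tps}")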
Parent commit: c6de9b4


1 file changed: torchao/float8/float8_linear.py (1 addition, 16 deletions)

--- a/torchao/float8/float8_linear.py
+++ b/torchao/float8/float8_linear.py
@@ -11,7 +11,7 @@
 
 import torch
 
-from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
+from torchao.float8.config import Float8LinearConfig, ScalingType
 from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
 from torchao.float8.float8_scaling_utils import (
     get_maybe_axiswise_dim,
@@ -128,21 +128,6 @@ def backward(ctx, grad_output):
         elif c.cast_config_weight_for_grad_input.scaling_type is ScalingType.DISABLED:
             weight_t_maybe_fp8_dim0 = weight_hp_t
         else:
-            if (
-                c.cast_config_weight_for_grad_input.scaling_granularity
-                is ScalingGranularity.AXISWISE
-            ):
-                # workaround from https://github.com/pytorch/pytorch/issues/141881
-                # to avoid saving float8 weight from forward to backward when
-                # FSDP is on: add a fake dependency on `grad_output`.
-                g_reshaped = grad_output.reshape(-1, grad_output.shape[-1]) * 0
-                zero = g_reshaped[:1] * 0
-                weight_hp_t = weight_hp_t + zero
-
-            # Note: we need https://github.com/pytorch/pytorch/issues/136267
-            # to be solved to have a chance to reuse max(abs(weight, dim=...))
-            # from the forward to get max(abs(weight)) here without reading
-            # the entire tensor.
             weight_t_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic(
                 weight_hp_t,
                 c.cast_config_weight_for_grad_input.target_dtype,
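For readers unfamiliar with the removed workaround, below is a minimal, self-contained sketch of the "fake dependency on grad_output" pattern it used, written against a plain matmul rather than torchao's Float8Linear. All names here (_FakeDependencyMM, weight_t, and so on) are hypothetical and not torchao code. Per the removed comment, the zero-valued add existed only to influence what gets saved from forward to backward when FSDP is on (avoiding saving the float8 weight); this commit removes it because it is no longer needed after #2356 and the corresponding PyTorch core improvements.

import torch


class _FakeDependencyMM(torch.autograd.Function):
    """Hypothetical stand-in for the Float8Linear backward; not torchao code."""

    @staticmethod
    def forward(ctx, input, weight_t):
        # weight_t: (in_features, out_features), i.e. already transposed.
        ctx.save_for_backward(input, weight_t)
        return input @ weight_t

    @staticmethod
    def backward(ctx, grad_output):
        input, weight_t = ctx.saved_tensors

        # The removed workaround: build a tensor that is numerically zero but
        # data-dependent on grad_output, and broadcast-add it to the weight.
        # In the real code this happened right before casting the weight to
        # float8, so that (with FSDP on) the float8 weight was not saved from
        # the forward to the backward.
        g_reshaped = grad_output.reshape(-1, grad_output.shape[-1]) * 0
        zero = g_reshaped[:1] * 0          # shape (1, out_features), all zeros
        weight_t = weight_t + zero         # numerically a no-op

        grad_input = grad_output @ weight_t.t()
        grad_weight_t = input.reshape(-1, input.shape[-1]).t() @ grad_output.reshape(
            -1, grad_output.shape[-1]
        )
        return grad_input, grad_weight_t


# Usage: numerically identical to a plain matmul; the extra zero-add only
# shapes what autograd / torch.compile keeps around for the backward pass.
x = torch.randn(4, 8, requires_grad=True)
w_t = torch.randn(8, 16, requires_grad=True)
_FakeDependencyMM.apply(x, w_t).sum().backward()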
