File tree Expand file tree Collapse file tree 1 file changed +1
-16
lines changed Expand file tree Collapse file tree 1 file changed +1
-16
lines changed Original file line number Diff line number Diff line change 11
11
12
12
import torch
13
13
14
- from torchao .float8 .config import Float8LinearConfig , ScalingGranularity , ScalingType
14
+ from torchao .float8 .config import Float8LinearConfig , ScalingType
15
15
from torchao .float8 .distributed_utils import tensor_already_casted_to_fp8
16
16
from torchao .float8 .float8_scaling_utils import (
17
17
get_maybe_axiswise_dim ,
@@ -128,21 +128,6 @@ def backward(ctx, grad_output):
128
128
elif c .cast_config_weight_for_grad_input .scaling_type is ScalingType .DISABLED :
129
129
weight_t_maybe_fp8_dim0 = weight_hp_t
130
130
else :
131
- if (
132
- c .cast_config_weight_for_grad_input .scaling_granularity
133
- is ScalingGranularity .AXISWISE
134
- ):
135
- # workaround from https://github.com/pytorch/pytorch/issues/141881
136
- # to avoid saving float8 weight from forward to backward when
137
- # FSDP is on: add a fake dependency on `grad_output`.
138
- g_reshaped = grad_output .reshape (- 1 , grad_output .shape [- 1 ]) * 0
139
- zero = g_reshaped [:1 ] * 0
140
- weight_hp_t = weight_hp_t + zero
141
-
142
- # Note: we need https://github.com/pytorch/pytorch/issues/136267
143
- # to be solved to have a chance to reuse max(abs(weight, dim=...))
144
- # from the forward to get max(abs(weight)) here without reading
145
- # the entire tensor.
146
131
weight_t_maybe_fp8_dim0 = hp_tensor_to_float8_dynamic (
147
132
weight_hp_t ,
148
133
c .cast_config_weight_for_grad_input .target_dtype ,
You can’t perform that action at this time.
0 commit comments