 from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
 from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
 from torchao.float8.float8_scaling_utils import (
-    NoopFwToFloat8BwDynamic,
     get_maybe_axiswise_dim,
     hp_tensor_to_float8_dynamic,
 )
 from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
 
 
-def _cast_input_to_float8(
-    input: torch.Tensor,
-    scaling_type_input: ScalingType,
-    config: Float8LinearConfig,
-    linear_mm_config: LinearMMConfig,
-) -> torch.Tensor:
-    # Duplicate the autocast logic for F.linear, so that the output
-    # of our module has the right original precision
-    if torch.is_autocast_enabled():
-        # For now, hardcode to GPU's autocast dtype
-        # if we need CPU support in the future, we can add it
-        autocast_dtype = torch.get_autocast_gpu_dtype()
-        input = input.to(autocast_dtype)
-
-    if tensor_already_casted_to_fp8(input):
-        input_fp8 = input
-    else:
-        assert scaling_type_input is ScalingType.DYNAMIC
-        input_fp8 = hp_tensor_to_float8_dynamic(
-            input,
-            config.cast_config_input.target_dtype,
-            linear_mm_config,
-            gemm_input_role=GemmInputRole.INPUT,
-        )
-    return input_fp8
-
-
 def _get_weight_scale(
     weight: torch.Tensor,
     scaling_type_weight: ScalingType,
@@ -85,21 +57,6 @@ def _cast_weight_to_float8_t(
     return weight_fp8.t()
 
 
-def _cast_output_to_float8_in_bw(
-    output: torch.Tensor,
-    scaling_type_grad_output,
-    linear_mm_config: LinearMMConfig,
-    config: Float8LinearConfig,
-) -> torch.Tensor:
-    assert scaling_type_grad_output is ScalingType.DYNAMIC
-    output = NoopFwToFloat8BwDynamic.apply(
-        output,
-        linear_mm_config,
-        config.cast_config_grad_output.target_dtype,
-    )
-    return output
-
-
 @torch._dynamo.allow_in_graph
 class matmul_with_hp_or_float8_args(torch.autograd.Function):
     """
@@ -329,6 +286,14 @@ def __init__(self, *args, **kwargs):
         )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
+        # Duplicate the autocast logic for F.linear, so that the output
+        # of our module has the right original precision
+        if torch.is_autocast_enabled():
+            # For now, hardcode to GPU's autocast dtype
+            # if we need CPU support in the future, we can add it
+            autocast_dtype = torch.get_autocast_gpu_dtype()
+            input = input.to(autocast_dtype)
+
         has_any_axiswise_scaling = any(
             cc.scaling_granularity is ScalingGranularity.AXISWISE
             for cc in [
@@ -341,18 +306,11 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             ]
         )
 
-        input_maybe_fp8 = input
         weight_maybe_fp8_t = self.weight.t()
 
         # TODO(future PR): check for axiswise scaling for input, weight,
         # grad_output separately instead of together
         if not has_any_axiswise_scaling:
-            input_fp8 = _cast_input_to_float8(
-                input,
-                self.scaling_type_input,
-                self.config,
-                self.linear_mm_config,
-            )
             # If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight,
             # weight_scale should be saved.
             weight_scale = _get_weight_scale(
@@ -375,25 +333,15 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
                 weight_scale,
             )
 
-            input_maybe_fp8 = input_fp8
             weight_maybe_fp8_t = weight_fp8_t
 
         output = matmul_with_hp_or_float8_args.apply(
-            input_maybe_fp8,
+            input,
             weight_maybe_fp8_t,
             self.linear_mm_config,
             self.config,
         )
 
-        if not has_any_axiswise_scaling:
-            # Cast grad_output to float8_e5m2 during backward
-            output = _cast_output_to_float8_in_bw(
-                output,
-                self.scaling_type_grad_output,
-                self.linear_mm_config,
-                self.config,
-            )
-
         if self.bias is not None:
             output = output + self.bias.to(output.dtype)
         return output
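
For orientation, here is a minimal sketch of Float8Linear.forward as it reads after this change, reconstructed from the hunks above. The has_any_axiswise_scaling computation is elided and the helper argument lists for _get_weight_scale and _cast_weight_to_float8_t are inferred from the visible context, so treat this as an approximation rather than the verbatim file:

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Match F.linear's autocast behavior so the module output keeps the
        # original precision.
        if torch.is_autocast_enabled():
            input = input.to(torch.get_autocast_gpu_dtype())

        weight_maybe_fp8_t = self.weight.t()
        if not has_any_axiswise_scaling:
            # Tensorwise path: only the weight is pre-cast in forward now.
            # Input casting (forward) and grad_output casting (backward) are
            # handled inside matmul_with_hp_or_float8_args instead.
            weight_scale = _get_weight_scale(
                self.weight, self.scaling_type_weight, self.config
            )
            weight_maybe_fp8_t = _cast_weight_to_float8_t(
                self.weight, self.config, self.linear_mm_config, weight_scale
            )

        output = matmul_with_hp_or_float8_args.apply(
            input, weight_maybe_fp8_t, self.linear_mm_config, self.config
        )
        if self.bias is not None:
            output = output + self.bias.to(output.dtype)
        return output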