@@ -125,7 +125,7 @@ def backward(ctx, go):
             fp8_scale_grad_output,
             e5m2_dtype,
             linear_mm_config=ctx.linear_mm_config,
-            gemm_input_role=GemmInputRole.DL_DY,
+            gemm_input_role=GemmInputRole.GRAD_OUTPUT,
         )
         empty_grads = None, None, None, None, None, None
         return res, *empty_grads
@@ -273,21 +273,21 @@ def convert_amax_buffer_to_float32(self):
             if self._buffers[key] is not None:
                 self._buffers[key] = self._buffers[key].to(torch.float32)

-    def cast_x_to_float8(
-        self, x: torch.Tensor, is_amax_initialized: bool
+    def cast_input_to_float8(
+        self, input: torch.Tensor, is_amax_initialized: bool
     ) -> torch.Tensor:
         # Duplicate the autocast logic for F.linear, so that the output
         # of our module has the right original precision
         if torch.is_autocast_enabled():
             # For now, hardcode to GPU's autocast dtype
             # if we need CPU support in the future, we can add it
             autocast_dtype = torch.get_autocast_gpu_dtype()
-            x = x.to(autocast_dtype)
+            input = input.to(autocast_dtype)

         if self.scaling_type_input is TensorScalingType.DELAYED:
             scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             _maybe_initialize_amaxes_scales_for_float8_cast(
-                x,
+                input,
                 self.fp8_amax_input,
                 self.fp8_amax_history_input,
                 self.fp8_scale_input,
@@ -296,29 +296,29 @@ def cast_x_to_float8(
                 is_amax_initialized,
                 reduce_amax=True,
             )
-            x_fp8 = Float8Tensor.to_float8(
-                x,
+            input_fp8 = Float8Tensor.to_float8(
+                input,
                 self.fp8_scale_input,
                 e4m3_dtype,
                 self.fp8_amax_input,
                 linear_mm_config=self.linear_mm_config,
-                gemm_input_role=GemmInputRole.X,
+                gemm_input_role=GemmInputRole.INPUT,
             )
         else:
             assert self.scaling_type_input is TensorScalingType.DYNAMIC
-            x_fp8 = cast_to_float8_e4m3_dynamic(x, self.linear_mm_config)
-        return x_fp8
+            input_fp8 = cast_to_float8_e4m3_dynamic(input, self.linear_mm_config)
+        return input_fp8

-    def cast_w_to_float8(
-        self, w: torch.Tensor, is_amax_initialized: bool
+    def cast_weight_to_float8(
+        self, weight: torch.Tensor, is_amax_initialized: bool
     ) -> torch.Tensor:
         if self.scaling_type_weight is TensorScalingType.DELAYED:
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
-                w_fp8 = self.weight
+                weight_fp8 = self.weight
             else:
                 scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
                 _maybe_initialize_amaxes_scales_for_float8_cast(
-                    w,
+                    weight,
                     self.fp8_amax_weight,
                     self.fp8_amax_history_weight,
                     self.fp8_scale_weight,
@@ -328,29 +328,31 @@ def cast_w_to_float8(
                     reduce_amax=False,
                 )

-                w_fp8 = Float8Tensor.to_float8(
-                    w,
+                weight_fp8 = Float8Tensor.to_float8(
+                    weight,
                     self.fp8_scale_weight,
                     e4m3_dtype,
                     self.fp8_amax_weight,
                     linear_mm_config=self.linear_mm_config,
-                    gemm_input_role=GemmInputRole.W,
+                    gemm_input_role=GemmInputRole.WEIGHT,
                 )
         else:
             assert self.scaling_type_weight is TensorScalingType.DYNAMIC
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
-                w_fp8 = self.weight
+                weight_fp8 = self.weight
             else:
-                w_fp8 = cast_to_float8_e4m3_dynamic(
-                    self.weight, self.linear_mm_config, gemm_input_role=GemmInputRole.W
+                weight_fp8 = cast_to_float8_e4m3_dynamic(
+                    self.weight,
+                    self.linear_mm_config,
+                    gemm_input_role=GemmInputRole.WEIGHT,
                 )
-        return w_fp8
+        return weight_fp8

-    def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
+    def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
         if self.scaling_type_grad_output is TensorScalingType.DELAYED:
             scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
-            y = NoopFwToFloat8E5M2Bw.apply(
-                y,
+            output = NoopFwToFloat8E5M2Bw.apply(
+                output,
                 self.fp8_amax_grad_output,
                 self.fp8_amax_history_grad_output,
                 self.fp8_scale_grad_output,
@@ -360,10 +362,10 @@ def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
             )
         else:
             assert self.scaling_type_grad_output is TensorScalingType.DYNAMIC
-            y = cast_to_float8_e5m2_dynamic_bw(y, self.linear_mm_config)
-        return y
+            output = cast_to_float8_e5m2_dynamic_bw(output, self.linear_mm_config)
+        return output

-    def float8_pre_forward(self, x):
+    def float8_pre_forward(self, input):
         if not self.enable_pre_and_post_forward:
             return
         if (
@@ -374,7 +376,7 @@ def float8_pre_forward(self, x):
             raise AssertionError(
                 "amaxes and scales not synced, please call `sync_float8_amax_and_scale_history` before forward"
             )
-        self.last_seen_input_dtype = x.dtype
+        self.last_seen_input_dtype = input.dtype

     def float8_post_forward(self):
         if not self.enable_pre_and_post_forward:
@@ -388,25 +390,25 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.has_any_delayed_scaling:
             self.float8_pre_forward(input)

-        x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
-        w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
+        input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized)
+        weight_fp8 = self.cast_weight_to_float8(self.weight, self.is_amax_initialized)

-        y = torch.matmul(x_fp8, w_fp8.t())
+        output = torch.matmul(input_fp8, weight_fp8.t())

-        # Cast gradY to float8_e5m2 during backward
-        y = self.cast_y_to_float8_in_bw(y)
+        # Cast grad_output to float8_e5m2 during backward
+        output = self.cast_output_to_float8_in_bw(output)

         if self.bias is not None:
-            y = y + self.bias.to(y.dtype)
+            output = output + self.bias.to(output.dtype)

         if self.has_any_delayed_scaling:
             self.float8_post_forward()
-        return y
+        return output

     def scaling_repr(self):
         # add scaling settings without using too many characters
-        # example: "x:del,w:del,dldy:dyn"
-        return f"x:{self.scaling_type_input.short_str()},w:{self.scaling_type_weight.short_str()},dldy:{self.scaling_type_grad_output.short_str()}"
+        # example: "i:del,w:del,go:dyn"
+        return f"i:{self.scaling_type_input.short_str()},w:{self.scaling_type_weight.short_str()},go:{self.scaling_type_grad_output.short_str()}"

     def extra_repr(self):
         s = f'{super().extra_repr()}, scaling="{self.scaling_repr()}"'
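
For reference, a minimal plain-PyTorch sketch of the arithmetic the renamed forward performs, using the new input/weight/output naming; this is an illustration only, not code from this diff, and it omits the float8 casting, scaling, and delayed-amax bookkeeping that Float8Linear adds around the matmul.

import torch

def linear_forward_sketch(input: torch.Tensor, weight: torch.Tensor, bias=None):
    # Same computation shape as the renamed Float8Linear.forward:
    # matmul against the transposed weight, then an optional bias add
    # performed in the output dtype. The real module first casts input
    # and weight to float8 (e4m3) and tags the backward cast of
    # grad_output to float8 (e5m2).
    output = torch.matmul(input, weight.t())
    if bias is not None:
        output = output + bias.to(output.dtype)
    return output

# example usage: a (16, 32) input against a (64, 32) weight -> (16, 64) output
out = linear_forward_sketch(torch.randn(16, 32), torch.randn(64, 32))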