@@ -64,16 +64,6 @@ def linear_requires_sync(linear_type: LinearType):
     return linear_type in REQUIRES_SYNC


-def _update_history_with_new_amax(new_amax, amax_history):
-    """
-    Updates `amax_history` (the last N cur_amax values) inplace with the value
-    of `new_amax`.
-    """
-    new_amax_history = torch.roll(amax_history, 1)
-    new_amax_history[0] = new_amax
-    amax_history.copy_(new_amax_history)
-
-
 def _update_history_stack(
     new_amax: torch.Tensor, amax_history_stack: torch.Tensor
 ) -> torch.Tensor:
@@ -85,10 +75,12 @@ def _update_history_stack(
         new_amax (torch.Tensor): The new amax value to add to the history. (n_amaxes, 1)
         amax_history_stack (torch.Tensor): The history of amax values. (n_amaxes, history_length)
     """
-    assert amax_history_stack.dim() == 2, "amax_history_stack must be 2D"
+    assert (
+        amax_history_stack.dim() == 2
+    ), f"Expected amax_history_stack to be 2D, got {amax_history_stack.shape}"
     assert new_amax.size(0) == amax_history_stack.size(
         0
-    ), "new_amax must have the same size as the second dimension of amax_history_stack"
+    ), f"Expected new_amax to have the same size as the first dimension of amax_history_stack, got {new_amax.size(0)} and {amax_history_stack.size(0)}"
     new_amax_history_stack = torch.roll(amax_history_stack, 1, dims=1)
     new_amax_history_stack[:, 0] = new_amax.squeeze(-1)
     amax_history_stack.copy_(new_amax_history_stack)
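
For intuition, a minimal sketch of what the stacked update above computes, using toy values and plain `torch` only:

```python
import torch

# Toy history: 3 amax values for each of 2 tensors (n_amaxes=2, history_length=3)
amax_history_stack = torch.tensor([[1.0, 2.0, 3.0],
                                   [4.0, 5.0, 6.0]])
new_amax = torch.tensor([[7.0], [8.0]])  # (n_amaxes, 1), as in the docstring

# Roll each row right by one slot, then write the new amax into column 0
rolled = torch.roll(amax_history_stack, 1, dims=1)
rolled[:, 0] = new_amax.squeeze(-1)
amax_history_stack.copy_(rolled)

print(amax_history_stack)
# tensor([[7., 1., 2.],
#         [8., 4., 5.]])
```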
@@ -155,9 +147,7 @@ def get_float8_layers(model: torch.nn.Module):
     """

     # Get all fp8 layers and tensors
-    fp8_layers = [
-        child for _, child in model.named_modules() if isinstance(child, Float8Linear)
-    ]
+    fp8_layers = [child for child in model.modules() if isinstance(child, Float8Linear)]

     return fp8_layers

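The rewritten comprehension behaves identically to the old one; `modules()` just skips the name bookkeeping that `named_modules()` does. A quick self-contained check, with `nn.Linear` standing in for `Float8Linear`:

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))

# named_modules() and modules() traverse the same modules in the same order
named = [child for _, child in model.named_modules() if isinstance(child, nn.Linear)]
plain = [child for child in model.modules() if isinstance(child, nn.Linear)]
assert named == plain  # both find the two Linear layers
```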
@@ -176,7 +166,7 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
     TODO(future): design the UX for this (context manager, etc)

     PERFORMANCE NOTE:
-        When you can it is much more efficient to call te get_float8_layers once a
+        When you can, it is much more efficient to call get_float8_layers once at
         the beginning of the training loop and pass the result to this function.
         Because of how this interacts with torch.compile
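A sketch of the usage pattern the note recommends; the training-loop scaffolding (`model`, `optimizer`, `dataloader`, the loss) is illustrative, only `get_float8_layers` and `sync_float8_amax_and_scale_history` come from this file:

```python
# Hoist the layer scan out of the training loop, as the PERFORMANCE NOTE suggests
fp8_layers = get_float8_layers(model)

for batch in dataloader:
    loss = model(batch).sum()
    loss.backward()
    # Pass the cached list so the sync does not re-walk the module tree each step
    sync_float8_amax_and_scale_history(model, fp8_layers)
    optimizer.step()
    optimizer.zero_grad()
```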
@@ -249,13 +239,12 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
         reduced_fp8_amax_dL_dY_tensor,
     ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list))

-    # TODO foreach is not supported with AsyncCollectiveTensor
     for idx, child in enumerate(fp8_layers):
         child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx])
         child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx])
         child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx])

-    # We create two stacked tensors, one for the amax history and one for the current scales
+    # We create two stacked tensor groups, one for the amax history and one for the current scales
     fp8_amax_x_tensors = torch.vstack(fp8_amax_x_tensor_list)
     fp8_amax_w_tensors = torch.vstack(fp8_amax_w_tensor_list)
     fp8_amax_dL_dY_tensors = torch.vstack(fp8_amax_dL_dY_tensor_list)
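
For reference, a small sketch of the stacking step, assuming each per-layer amax buffer is a (1,)-shaped tensor (the shape is an assumption for illustration):

```python
import torch

# One (1,)-shaped amax buffer per layer (shape assumed for illustration)
fp8_amax_x_tensor_list = [torch.tensor([0.5]), torch.tensor([2.0]), torch.tensor([1.5])]

# vstack groups them into a single (n_layers, 1) tensor, so the history update
# and scale computation run once over all layers instead of once per layer
fp8_amax_x_tensors = torch.vstack(fp8_amax_x_tensor_list)
print(fp8_amax_x_tensors.shape)  # torch.Size([3, 1])
```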
@@ -264,11 +253,12 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
     fp8_w_amax_history_stack = torch.vstack(fp8_w_amax_history_stack)
     fp8_dL_dY_amax_history_stack = torch.vstack(fp8_dL_dY_amax_history_stack)

+    # Update the history stacks with the new amax values
     _update_history_stack(fp8_amax_x_tensors, fp8_x_amax_history_stack)
     _update_history_stack(fp8_amax_w_tensors, fp8_w_amax_history_stack)
     _update_history_stack(fp8_amax_dL_dY_tensors, fp8_dL_dY_amax_history_stack)

-    # We are not reading the
+    # Calculate the new scales from the updated history stacks
     new_x_scales = amax_history_to_scale_stack(
         fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
     )
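
`amax_history_to_scale_stack` is defined elsewhere in the library. As a rough illustration only, assuming a "max" recipe that reduces each history row to its maximum and converts it to a scale as `float8_max / amax`, it might behave like:

```python
import torch

# Hypothetical stand-in for amax_history_to_scale_stack, assuming a "max"
# recipe; the real helper lives elsewhere in the library and may differ.
def scale_stack_sketch(history_stack: torch.Tensor, float8_dtype) -> torch.Tensor:
    amax = torch.max(history_stack, dim=1).values  # max over each row's history
    f8_max = torch.finfo(float8_dtype).max         # 448.0 for float8_e4m3fn
    return f8_max / torch.clamp(amax, min=1e-12)   # guard against amax == 0

history = torch.tensor([[7.0, 1.0, 2.0]])
print(scale_stack_sketch(history, torch.float8_e4m3fn))  # tensor([64.])
```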