This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 4ca6ddf

comments
1 parent bfbc868 commit 4ca6ddf

File tree

1 file changed: +9 −19 lines changed

float8_experimental/float8_linear_utils.py (+9 −19)
@@ -64,16 +64,6 @@ def linear_requires_sync(linear_type: LinearType):
     return linear_type in REQUIRES_SYNC
 
 
-def _update_history_with_new_amax(new_amax, amax_history):
-    """
-    Updates `amax_history` (the last N cur_amax values) inplace with the value
-    of `new_amax`.
-    """
-    new_amax_history = torch.roll(amax_history, 1)
-    new_amax_history[0] = new_amax
-    amax_history.copy_(new_amax_history)
-
-
 def _update_history_stack(
     new_amax: torch.Tensor, amax_history_stack: torch.Tensor
 ) -> torch.Tensor:
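For context, a tiny standalone demo of the rolling-window update the removed single-tensor helper performed (only torch is assumed; the values are made up for illustration). The same shift-and-overwrite pattern is applied per row by _update_history_stack in the next hunk.

import torch

amax_history = torch.tensor([2.0, 7.0, 4.0])  # most recent amax lives at index 0

# Shift everything one slot to the right, then write the new amax into slot 0,
# exactly as the removed _update_history_with_new_amax did.
new_amax_history = torch.roll(amax_history, 1)
new_amax_history[0] = 9.0
amax_history.copy_(new_amax_history)

print(amax_history)  # tensor([9., 2., 7.]) -- the oldest value (4.0) drops off the end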
@@ -85,10 +75,12 @@ def _update_history_stack(
         new_amax (torch.Tensor): The new amax value to add to the history. (n_amaxes, 1)
         amax_history_stack (torch.Tensor): The history of amax values. (n_amaxes, history_length)
     """
-    assert amax_history_stack.dim() == 2, "amax_history_stack must be 2D"
+    assert (
+        amax_history_stack.dim() == 2
+    ), f"Expected amax_history_stack to be 2D, got {amax_history_stack.shape}"
     assert new_amax.size(0) == amax_history_stack.size(
         0
-    ), "new_amax must have the same size as the second dimension of amax_history_stack"
+    ), f"Expected new_amax to have the same size as the first dimension of amax_history_stack, got {new_amax.size(0)} and {amax_history_stack.size(0)}"
     new_amax_history_stack = torch.roll(amax_history_stack, 1, dims=1)
     new_amax_history_stack[:, 0] = new_amax.squeeze(-1)
     amax_history_stack.copy_(new_amax_history_stack)
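A minimal sketch of the stacked version on toy shapes that satisfy the new asserts (a (2, 4) history stack and a (2, 1) batch of new amaxes, both invented for illustration):

import torch

amax_history_stack = torch.tensor([[2.0, 7.0, 4.0, 1.0],
                                   [5.0, 3.0, 8.0, 6.0]])  # (n_amaxes, history_length)
new_amax = torch.tensor([[9.0], [0.5]])                    # (n_amaxes, 1)

# Same body as _update_history_stack: roll every row right by one column,
# write the new amax into column 0, then copy the result back in place.
rolled = torch.roll(amax_history_stack, 1, dims=1)
rolled[:, 0] = new_amax.squeeze(-1)
amax_history_stack.copy_(rolled)

print(amax_history_stack)
# tensor([[9.0000, 2.0000, 7.0000, 4.0000],
#         [0.5000, 5.0000, 3.0000, 8.0000]])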
@@ -155,9 +147,7 @@ def get_float8_layers(model: torch.nn.Module):
     """
 
     # Get all fp8 layers and tensors
-    fp8_layers = [
-        child for _, child in model.named_modules() if isinstance(child, Float8Linear)
-    ]
+    fp8_layers = [child for child in model.modules() if isinstance(child, Float8Linear)]
 
     return fp8_layers
 
@@ -176,7 +166,7 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
     TODO(future): design the UX for this (context manager, etc)
 
     PERFORMANCE NOTE:
-        When you can it is much more efficient to call te get_float8_layers once a
+        When you can, it is much more efficient to call get_float8_layers once at
        the beginning of the training loop and pass the result to this function.
        Because of how this interacts with torch.compile
 
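A sketch of the usage this note recommends, with model, dataloader, loss_fn, and optimizer as placeholders (none of them come from this diff): gather the Float8Linear layers once, then pass the list to every sync call instead of re-walking the module tree each iteration. Placing the sync after backward and before the optimizer step is one reasonable choice, not something this diff specifies.

from float8_experimental.float8_linear_utils import (
    get_float8_layers,
    sync_float8_amax_and_scale_history,
)

fp8_layers = get_float8_layers(model)  # walk the module tree once, up front

for batch_x, batch_y in dataloader:
    optimizer.zero_grad()
    loss = loss_fn(model(batch_x), batch_y)
    loss.backward()
    # Reuse the precomputed layer list on every iteration.
    sync_float8_amax_and_scale_history(model, fp8_layers=fp8_layers)
    optimizer.step()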
@@ -249,13 +239,12 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
         reduced_fp8_amax_dL_dY_tensor,
     ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list))
 
-    # TODO foreach is not supported with AsyncCollectiveTensor
     for idx, child in enumerate(fp8_layers):
         child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx])
         child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx])
         child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx])
 
-    # We create two stacked tensors, one for the amax history and one for the current scales
+    # We create two stacked tensor groups, one for the amax history and one for the current scales
     fp8_amax_x_tensors = torch.vstack(fp8_amax_x_tensor_list)
     fp8_amax_w_tensors = torch.vstack(fp8_amax_w_tensor_list)
     fp8_amax_dL_dY_tensors = torch.vstack(fp8_amax_dL_dY_tensor_list)
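The torch.split call above carves one all-reduced tensor back into the x / w / dL_dY groups. A toy standalone illustration (with a made-up layer count) of why len(fp8_amax_x_tensor_list) is the split size:

import torch

n_layers = 3  # hypothetical: 3 Float8Linear layers, so each group holds 3 amax values
all_reduced_amax_tensor = torch.arange(9.0)  # stand-in for the concatenated, all-reduced amaxes

x_part, w_part, dl_dy_part = torch.split(all_reduced_amax_tensor, n_layers)
print(x_part, w_part, dl_dy_part)
# tensor([0., 1., 2.]) tensor([3., 4., 5.]) tensor([6., 7., 8.])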
@@ -264,11 +253,12 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
     fp8_w_amax_history_stack = torch.vstack(fp8_w_amax_history_stack)
     fp8_dL_dY_amax_history_stack = torch.vstack(fp8_dL_dY_amax_history_stack)
 
+    # Update the history stacks with the new amax values
     _update_history_stack(fp8_amax_x_tensors, fp8_x_amax_history_stack)
     _update_history_stack(fp8_amax_w_tensors, fp8_w_amax_history_stack)
     _update_history_stack(fp8_amax_dL_dY_tensors, fp8_dL_dY_amax_history_stack)
 
-    # We are not reading the
+    # Calculate the new scales from the updated history stacks
     new_x_scales = amax_history_to_scale_stack(
         fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
     )
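For orientation only: amax_history_to_scale_stack is not shown in this diff, and its behavior depends on scale_fn_recipe, but a common delayed-scaling recipe maps each history row to a scale via its max. A rough standalone sketch of that idea, not the library's implementation:

import torch

def history_to_scale_max_recipe(amax_history_stack: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    # Illustrative "max" recipe: one scale per tracked tensor (per row of the stack).
    amax = amax_history_stack.max(dim=1).values   # (n_amaxes,)
    fp8_max = torch.finfo(float8_dtype).max       # 448.0 for torch.float8_e4m3fn
    return fp8_max / torch.clamp(amax, min=1e-12) # clamp keeps the scale finite

history = torch.tensor([[3.0, 1.0, 2.0],
                        [0.5, 4.0, 1.0]])
print(history_to_scale_max_recipe(history, torch.float8_e4m3fn))
# tensor([149.3333, 112.0000])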
