Updates to sync_float_amax_history #211
Changes from all commits
@@ -4,16 +4,23 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import copy
import logging
from enum import auto, Enum
from typing import List, Optional, Type

import float8_experimental.config as fp8_config

import torch
import torch.distributed as dist
import torch.nn as nn
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear

from float8_experimental.float8_utils import amax_history_to_scale
from float8_experimental.float8_utils import amax_history_to_scale_stack
from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor

log = logging.getLogger(__name__)
log.addHandler(logging.NullHandler())


class LinearType(Enum):
@@ -57,14 +64,26 @@ def linear_requires_sync(linear_type: LinearType):
    return linear_type in REQUIRES_SYNC


def _update_history_with_new_amax(new_amax, amax_history):
def _update_history_stack(
    new_amax: torch.Tensor, amax_history_stack: torch.Tensor
) -> torch.Tensor:
    """
    Updates `amax_history` (the last N cur_amax values) inplace with the value
    of `new_amax`.

    Args:
        new_amax (torch.Tensor): The new amax value to add to the history. (n_amaxes, 1)
        amax_history_stack (torch.Tensor): The history of amax values. (n_amaxes, history_length)
    """
    new_amax_history = torch.roll(amax_history, 1)
    new_amax_history[0] = new_amax
    amax_history.copy_(new_amax_history)
    assert (
        amax_history_stack.dim() == 2
    ), f"Expected amax_history_stack to be 2D, got {amax_history_stack.shape}"
    assert new_amax.size(0) == amax_history_stack.size(
        0
    ), f"Expected new_amax to have the same size as the first dimension of amax_history_stack, got {new_amax.size(0)} and {amax_history_stack.size(0)}"
    new_amax_history_stack = torch.roll(amax_history_stack, 1, dims=1)
    new_amax_history_stack[:, 0] = new_amax.squeeze(-1)
    amax_history_stack.copy_(new_amax_history_stack)
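A minimal standalone sketch of the stacked rolling-window update above (the tensor shapes are toy values chosen for illustration, not values from this PR):

    import torch

    # history rows hold the last 3 amaxes for 2 layers, newest value in column 0
    history = torch.zeros(2, 3)
    new_amax = torch.tensor([[1.5], [0.25]])

    rolled = torch.roll(history, 1, dims=1)   # shift the older entries right by one slot
    rolled[:, 0] = new_amax.squeeze(-1)       # the newest amax lands in column 0
    history.copy_(rolled)                     # in-place copy, so module buffers observe the update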


def swap_linear_with_float8_linear(
@@ -121,21 +140,20 @@ def post_order_traversal(
    return root_module


def get_float8_layers(model: torch.nn.Module, fp8_classes=None):
    if fp8_classes is None:
        fp8_classes = Float8Linear
def get_float8_layers(model: torch.nn.Module):
    """Iterates through the model and returns all the Float8Linear layers.
    Args:
        model (torch.nn.Module): The model to look for Float8Linear layers in.
    """
    # Get all fp8 layers and tensors
    fp8_layers = [
        child for name, child in model.named_modules() if isinstance(child, fp8_classes)
    ]
    fp8_layers = [child for child in model.modules() if isinstance(child, Float8Linear)]

    return fp8_layers


def sync_float8_amax_and_scale_history(
    model: torch.nn.Module, fp8_classes=None, fp8_layers=None
) -> None:
@torch.no_grad()
def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) -> None:
    """
    Manages the float8 amax and scale bookkeeping. In detail, it does the
    following:
@@ -147,95 +165,123 @@ def sync_float8_amax_and_scale_history(

    TODO(future): design the UX for this (context manager, etc)

    PERFORMANCE NOTE:
        When you can, it is much more efficient to call get_float8_layers once at
        the beginning of the training loop and pass the result to this function,
        because of how this interacts with torch.compile.

    Args:
        model (torch.nn.Module): The model to track amaxes for
        fp8_classes (optional): The fp8 classes to look for in the model.
            The default is Float8Linear.
            When using with TP, users can pass in the customized TP classes instead.
        fp8_layers (optional): If fp8_layers are provided, fp8_classes are ignored,
            and we loop over all fp8_layers to sync and update amax scale histories.
            Users can use get_float8_layers to get all fp8 layers.
    """
    # For now, this is written in a naive way to maximize code readability.
    # TODO(future): benchmark and optimize as needed, we have combined all
    # the reductions into one and we can probably try other optimizations to
    # make the history update faster.

    if fp8_layers is None:
        fp8_layers = get_float8_layers(model, fp8_classes)
        fp8_layers = get_float8_layers(model)

    if dist.is_initialized():
        fp8_amax_x_tensor = torch.tensor(
            [child.fp8_amax_x for child in fp8_layers],
            dtype=torch.float32,
            device="cuda",
            requires_grad=False,
    if len(fp8_layers) == 0:
        log.warn(
            "Calling sync_float8_amax_and_scale_history on a module with no Float8Linear layers"
        )
        fp8_amax_w_tensor = torch.tensor(
            [child.fp8_amax_w for child in fp8_layers],
            dtype=torch.float32,
            device="cuda",
            requires_grad=False,
        )
        fp8_amax_dL_dY_tensor = torch.tensor(
            [child.fp8_amax_dL_dY for child in fp8_layers],
            dtype=torch.float32,
            device="cuda",
            requires_grad=False,
        )
        dist.all_reduce(fp8_amax_x_tensor, op=dist.ReduceOp.MAX)
        dist.all_reduce(fp8_amax_w_tensor, op=dist.ReduceOp.MAX)
        dist.all_reduce(fp8_amax_dL_dY_tensor, op=dist.ReduceOp.MAX)

    for idx in range(len(fp8_layers)):
        child = fp8_layers[idx]

        #
        # 1. in distributed contexts, syncs amax values across workers
        #
        if dist.is_initialized():
            child.fp8_amax_x = fp8_amax_x_tensor[idx].clone()
            child.fp8_amax_w = fp8_amax_w_tensor[idx].clone()
            child.fp8_amax_dL_dY = fp8_amax_dL_dY_tensor[idx].clone()

        #
        # 2. adds the `amax` values to history
        #
        _update_history_with_new_amax(child.fp8_amax_x, child.fp8_amax_history_x)
        _update_history_with_new_amax(child.fp8_amax_w, child.fp8_amax_history_w)
        _update_history_with_new_amax(
            child.fp8_amax_dL_dY, child.fp8_amax_history_dL_dY
        return

    # Loop over all fp8 layers and grab the needed tensors
    fp8_amax_x_tensor_list = [None] * len(fp8_layers)
    fp8_amax_w_tensor_list = [None] * len(fp8_layers)
    fp8_amax_dL_dY_tensor_list = [None] * len(fp8_layers)

    fp8_x_amax_history_stack = [None] * len(fp8_layers)
    fp8_w_amax_history_stack = [None] * len(fp8_layers)
    fp8_dL_dY_amax_history_stack = [None] * len(fp8_layers)

    x_dtypes = set()
    scale_fn_recipes = set()

    for idx, child in enumerate(fp8_layers):
        fp8_amax_x_tensor_list[idx] = child.fp8_amax_x
        fp8_amax_w_tensor_list[idx] = child.fp8_amax_w
        fp8_amax_dL_dY_tensor_list[idx] = child.fp8_amax_dL_dY

        fp8_x_amax_history_stack[idx] = child.fp8_amax_history_x
        fp8_w_amax_history_stack[idx] = child.fp8_amax_history_w
        fp8_dL_dY_amax_history_stack[idx] = child.fp8_amax_history_dL_dY

        x_dtypes.add(child.last_seen_input_dtype)
        scale_fn_recipes.add(child.recipe.scale_fn_name)

    # TODO This way to get the activation dtype is not ideal
    if len(x_dtypes) != 1:
        raise ValueError(
            f"All layers must have the same last seen input_dtype, got {x_dtypes}"
        )
    x_dtype = next(iter(x_dtypes))

        #
        # 3. calculate the scales
        #
        # TODO what to do with x_dtype
        x_dtype = child.last_seen_input_dtype
        new_scale = amax_history_to_scale(
            child.fp8_amax_history_x,
            torch.float8_e4m3fn,
            x_dtype,
            child.recipe.scale_fn_name,
    if len(scale_fn_recipes) != 1:
        raise ValueError(
            f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}"
        )
        child.fp8_scale_x.copy_(new_scale)
        new_scale = amax_history_to_scale(
            child.fp8_amax_history_w,
            torch.float8_e4m3fn,
            x_dtype,
            child.recipe.scale_fn_name,
    scale_fn_recipe = next(iter(scale_fn_recipes))

    assert (
        len(fp8_amax_x_tensor_list)
        == len(fp8_amax_w_tensor_list)
        == len(fp8_amax_dL_dY_tensor_list)
    ), "Mismatched lengths of amax tensors."

    if dist.is_initialized():
        # Combine all the amax tensors into one tensor and reduce it
        all_amax_tensors = torch.cat(
            fp8_amax_x_tensor_list + fp8_amax_w_tensor_list + fp8_amax_dL_dY_tensor_list
        )
        child.fp8_scale_w.copy_(new_scale)
        new_scale = amax_history_to_scale(
            child.fp8_amax_history_dL_dY,
            torch.float8_e5m2,
            x_dtype,
            child.recipe.scale_fn_name,
        all_reduced_amax_tensor = all_reduce(
            all_amax_tensors, "MAX", list(range(dist.get_world_size()))
        )
        child.fp8_scale_dL_dY.copy_(new_scale)
        if isinstance(all_reduced_amax_tensor, AsyncCollectiveTensor):
            all_reduced_amax_tensor = all_reduced_amax_tensor.wait()
Review discussion on the AsyncCollectiveTensor wait() above:

    Aside: It seems like a bug on …

    Hmm... I actually tried to add an optimization to AsyncCollectiveTensor a while ago: pytorch/pytorch#105240. The thought behind it was: (1) if we have a collective in-flight, we only need to sync the collective if the data needs to be used in an operation; (2) view ops generally don't need to access the data of their inputs, only the metadata. So that PR "delays" syncs when there are view ops called on AsyncCollectiveTensor (e.g. …).

    Indeed, I was working through this with @wanchaol and without the wait() the all_reduced_amax_tensor was garbage...

    Let me know if either of you want to stare at it together (I think fixing the underlying bug will help other areas too).

    Interestingly, I was not able to repro the issue: P1184812077

    hmmm let me try again then

    You are right... however:

            return self.__get_result()
          File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
            raise self._exception
        torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
        CompilationError: at 10:30:def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, out_ptr1, out_ptr3, out_ptr4, out_ptr5, out_ptr7, out_ptr8, out_ptr9, out_ptr11):
            xpid = tl.program_id(0)
            XBLOCK: tl.constexpr = 1024
            if xpid >= 0 and xpid < 1:
                xpid_offset = xpid - 0
                xnumel = 1
                xoffset = xpid_offset * XBLOCK
                xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
                xmask = xindex < xnumel
                rindex = tl.arange(0, RBLOCK)[None, :]
                ^
        NameError('RBLOCK is not defined')

        Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
        You can suppress this exception and fall back to eager by setting:
            import torch._dynamo
            torch._dynamo.config.suppress_errors = True

    Compiling the sync function now causes an inductor fault, so I will create another PR and investigate there, but otherwise I think that's everything for this one.
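A minimal standalone sketch of the collective pattern under discussion (single-process "world" only for illustration; the gloo backend and env settings below are placeholders, not part of this PR):

    import os
    import torch
    import torch.distributed as dist
    from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    # concatenate per-layer amaxes, reduce them in one collective, then split back
    amaxes = torch.cat([torch.tensor([3.0]), torch.tensor([1.0]), torch.tensor([2.0])])
    reduced = all_reduce(amaxes, "MAX", list(range(dist.get_world_size())))
    # functional collectives may hand back an async wrapper; materialize it
    # before reading the data, mirroring the wait() call in the diff above
    if isinstance(reduced, AsyncCollectiveTensor):
        reduced = reduced.wait()
    x_amax, w_amax, dL_dY_amax = torch.split(reduced, 1)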
        (
            reduced_fp8_amax_tensor,
            reduced_fp8_amax_w_tensor,
            reduced_fp8_amax_dL_dY_tensor,
        ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list))

        for idx, child in enumerate(fp8_layers):
            child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx])
            child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx])
            child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx])

    # We create two stacked tensor groups, one for the amax history and one for the current scales
    fp8_amax_x_tensors = torch.vstack(fp8_amax_x_tensor_list)
    fp8_amax_w_tensors = torch.vstack(fp8_amax_w_tensor_list)
    fp8_amax_dL_dY_tensors = torch.vstack(fp8_amax_dL_dY_tensor_list)

    fp8_x_amax_history_stack = torch.vstack(fp8_x_amax_history_stack)
    fp8_w_amax_history_stack = torch.vstack(fp8_w_amax_history_stack)
    fp8_dL_dY_amax_history_stack = torch.vstack(fp8_dL_dY_amax_history_stack)

    # Update the history stacks with the new amax values
    _update_history_stack(fp8_amax_x_tensors, fp8_x_amax_history_stack)
    _update_history_stack(fp8_amax_w_tensors, fp8_w_amax_history_stack)
    _update_history_stack(fp8_amax_dL_dY_tensors, fp8_dL_dY_amax_history_stack)

    # Calculate the new scales from the updated history stacks
    new_x_scales = amax_history_to_scale_stack(
        fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
    )
    new_w_scales = amax_history_to_scale_stack(
        fp8_w_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
    )
    new_dL_dY_scales = amax_history_to_scale_stack(
        fp8_dL_dY_amax_history_stack, torch.float8_e5m2, x_dtype, scale_fn_recipe
    )
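An illustrative sketch of what "amax history to scale" means for a stacked history. This is an assumption about the "max" recipe, not the real amax_history_to_scale_stack from float8_utils.py, which may differ in details such as eps handling and original-dtype clamping:

    import torch

    def toy_amax_history_to_scale_stack(history_stack, float8_dtype):
        # assumed "max" recipe: take the running max over each row's history,
        # then scale so that amax maps to the fp8 dtype's largest finite value
        amax = torch.max(history_stack, dim=1).values
        return torch.finfo(float8_dtype).max / torch.clamp(amax, min=1e-12)

    history = torch.tensor([[2.0, 4.0, 1.0], [0.5, 0.25, 0.125]])
    print(toy_amax_history_to_scale_stack(history, torch.float8_e4m3fn))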
    # Iterate through the layers and update the scales, and set the flag to signal that the amaxes/scales are ready
    for idx, child in enumerate(fp8_layers):
        child.fp8_scale_x.copy_(new_x_scales[idx])
        child.fp8_scale_w.copy_(new_w_scales[idx])
        child.fp8_scale_dL_dY.copy_(new_dL_dY_scales[idx])

        #
        # 4. set a flag to signal amaxes/scales are ready
        #
        child.amax_and_scale_synced = True
        # We only update the flag if we know it will be checked by the modules
        if fp8_config.enable_amax_init:
            child.amax_and_scale_synced = True
Review discussion on the performance note:

    For my understanding, how does passing float8_layers affect torch.compile?

    Honestly I wrote this at the beginning of the PR and now I am not sure that this is true, but I found the get_attr was causing graph breaks. That was a few iterations ago, though, and reading the graph_breaks logs can be a little hard to decipher sometimes.
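A hedged sketch of one way to count graph breaks when compiling a function like the sync step (torch._dynamo.explain is a standard helper; the toy function below is made up for illustration and is not the real sync function):

    import torch

    def toy_step(t: torch.Tensor) -> torch.Tensor:
        # stand-in for the amax bookkeeping work
        return t.abs().amax()

    explanation = torch._dynamo.explain(toy_step)(torch.randn(8))
    print(explanation)  # reports graph count, graph break count, and break reasons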