This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Updates to sync_float_amax_history #211

Closed · wants to merge 14 commits
13 changes: 7 additions & 6 deletions float8_experimental/float8_linear.py
@@ -138,23 +138,24 @@ def __init__(self, *args, **kwargs):
         self.recipe = delayed_scaling_recipe
         history_len = self.recipe.history_len
 
-        self.register_always_float32_buffer("fp8_amax_x", torch.tensor(E4M3_MAX_POS))
+        self.register_always_float32_buffer("fp8_amax_x", torch.tensor([E4M3_MAX_POS]))
         self.register_always_float32_buffer(
             "fp8_amax_history_x", torch.zeros(history_len)
         )
-        self.register_always_float32_buffer("fp8_scale_x", torch.tensor(1.0))
-        self.register_always_float32_buffer("fp8_amax_w", torch.tensor(E4M3_MAX_POS))
+        self.register_always_float32_buffer("fp8_scale_x", torch.tensor([1.0]))
+        self.register_always_float32_buffer("fp8_amax_w", torch.tensor([E4M3_MAX_POS]))
         self.register_always_float32_buffer(
             "fp8_amax_history_w", torch.zeros(history_len)
         )
-        self.register_always_float32_buffer("fp8_scale_w", torch.tensor(1.0))
+        self.register_always_float32_buffer("fp8_scale_w", torch.tensor([1.0]))
         self.register_always_float32_buffer(
-            "fp8_amax_dL_dY", torch.tensor(E5M2_MAX_POS)
+            "fp8_amax_dL_dY", torch.tensor([E5M2_MAX_POS])
         )
         self.register_always_float32_buffer(
             "fp8_amax_history_dL_dY", torch.zeros(history_len)
         )
-        self.register_always_float32_buffer("fp8_scale_dL_dY", torch.tensor(1.0))
+        self.register_always_float32_buffer("fp8_scale_dL_dY", torch.tensor([1.0]))
 
         # Whether to emulate the fp8 matmul logic in float32
         self.emulate = False

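
Note on the change above: the new sync path concatenates per-layer amax buffers into a single tensor for one all-reduce, and torch.cat rejects 0-D tensors, which is why each buffer becomes 1-D. A minimal sketch of the constraint (illustrative, not from the PR):

```python
import torch

scalar = torch.tensor(1.0)    # 0-D buffer, as registered before this change
vector = torch.tensor([1.0])  # 1-D buffer, as registered after this change
# torch.cat([scalar, scalar])       # RuntimeError: zero-dimensional tensor(s) cannot be concatenated
flat = torch.cat([vector, vector])  # OK: shape (2,)
```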
226 changes: 136 additions & 90 deletions float8_experimental/float8_linear_utils.py
@@ -4,16 +4,23 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import copy
+import logging
 from enum import auto, Enum
 from typing import List, Optional, Type
 
+import float8_experimental.config as fp8_config
+
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
 from float8_experimental.float8_linear import Float8Linear
 
-from float8_experimental.float8_utils import amax_history_to_scale
+from float8_experimental.float8_utils import amax_history_to_scale_stack
+from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor
+
+log = logging.getLogger(__name__)
+log.addHandler(logging.NullHandler())


class LinearType(Enum):
@@ -57,14 +64,26 @@ def linear_requires_sync(linear_type: LinearType):
     return linear_type in REQUIRES_SYNC
 
 
-def _update_history_with_new_amax(new_amax, amax_history):
+def _update_history_stack(
+    new_amax: torch.Tensor, amax_history_stack: torch.Tensor
+) -> None:
     """
     Updates `amax_history` (the last N cur_amax values) inplace with the value
     of `new_amax`.
+
+    Args:
+        new_amax (torch.Tensor): The new amax value to add to the history. (n_amaxes, 1)
+        amax_history_stack (torch.Tensor): The history of amax values. (n_amaxes, history_length)
     """
-    new_amax_history = torch.roll(amax_history, 1)
-    new_amax_history[0] = new_amax
-    amax_history.copy_(new_amax_history)
+    assert (
+        amax_history_stack.dim() == 2
+    ), f"Expected amax_history_stack to be 2D, got {amax_history_stack.shape}"
+    assert new_amax.size(0) == amax_history_stack.size(
+        0
+    ), f"Expected new_amax to have the same size as the first dimension of amax_history_stack, got {new_amax.size(0)} and {amax_history_stack.size(0)}"
+    new_amax_history_stack = torch.roll(amax_history_stack, 1, dims=1)
+    new_amax_history_stack[:, 0] = new_amax.squeeze(-1)
+    amax_history_stack.copy_(new_amax_history_stack)
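
A quick illustration (not from the PR) of the batched in-place update above, assuming two layers and a history length of four:

```python
import torch

hist = torch.zeros(2, 4)              # (n_amaxes, history_length)
new = torch.tensor([[3.0], [5.0]])    # (n_amaxes, 1)
rolled = torch.roll(hist, 1, dims=1)  # shift each row right by one slot
rolled[:, 0] = new.squeeze(-1)        # newest amax lands in column 0
hist.copy_(rolled)                    # write back into the original buffer
print(hist[0])                        # tensor([3., 0., 0., 0.])
```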


def swap_linear_with_float8_linear(
@@ -121,21 +140,20 @@ def post_order_traversal(
     return root_module
 
 
-def get_float8_layers(model: torch.nn.Module, fp8_classes=None):
-    if fp8_classes is None:
-        fp8_classes = Float8Linear
+def get_float8_layers(model: torch.nn.Module):
+    """Iterates through the model and returns all the Float8Linear layers.
+
+    Args:
+        model (torch.nn.Module): The model to look for Float8Linear layers in.
+    """
     # Get all fp8 layers and tensors
-    fp8_layers = [
-        child for name, child in model.named_modules() if isinstance(child, fp8_classes)
-    ]
+    fp8_layers = [child for child in model.modules() if isinstance(child, Float8Linear)]
 
     return fp8_layers
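
A hypothetical usage sketch of the caching pattern recommended in the PERFORMANCE NOTE below (model, dataloader, and optimizer are assumed to exist; they are not part of the PR):

```python
# Collect the Float8Linear layers once, outside the training loop,
# so each sync call skips the module-tree traversal.
fp8_layers = get_float8_layers(model)
for batch in dataloader:
    loss = model(batch).sum()
    loss.backward()
    sync_float8_amax_and_scale_history(model, fp8_layers=fp8_layers)
    optimizer.step()
    optimizer.zero_grad()
```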


-def sync_float8_amax_and_scale_history(
-    model: torch.nn.Module, fp8_classes=None, fp8_layers=None
-) -> None:
+@torch.no_grad()
+def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) -> None:
"""
Manages the float8 amax and scale bookkeeping. In detail, it does the
following:
Expand All @@ -147,95 +165,123 @@ def sync_float8_amax_and_scale_history(

TODO(future): design the UX for this (context manager, etc)

PERFORMANCE NOTE:
When you can, it is much more efficient to call get_float8_layers once at
the beginning of the training loop and pass the result to this function.
Because of how this interacts with torch.compile

[Inline review thread]

Reviewer: For my understanding, how does passing float8_layers affect torch.compile?

Author: Honestly, I wrote this at the beginning of the PR and now I am not sure that it is still true, but I found that get_attr was causing graph breaks. That was a few iterations ago, though, and reading the graph_breaks logs can be a little hard to decipher sometimes.

[End thread]


     Args:
         model (torch.nn.Module): The model to track amaxes for
-        fp8_classes (optional): The fp8 classes to look for in the model.
-            The default is Float8Linear.
-            When using with TP, users can pass in the customized TP classes instead.
-        fp8_layers (optional): If fp8_layers are provided, fp8_classes are ignored,
-            and we loop over all fp8_layers to sync and update amax scale histories.
-            Users can use get_float8_layers to get all fp8 layers.
+        fp8_layers (optional): If fp8_layers are provided, we loop over them to
+            sync and update the amax and scale histories. Users can use
+            get_float8_layers to get all fp8 layers.
     """

-    # For now, this is written in a naive way to maximize code readability.
+    # TODO(future): benchmark and optimize as needed; we have combined all
+    # the reductions into one, and we can probably try other optimizations to
+    # make the history update faster.
 
     if fp8_layers is None:
-        fp8_layers = get_float8_layers(model, fp8_classes)
+        fp8_layers = get_float8_layers(model)

-    if dist.is_initialized():
-        fp8_amax_x_tensor = torch.tensor(
-            [child.fp8_amax_x for child in fp8_layers],
-            dtype=torch.float32,
-            device="cuda",
-            requires_grad=False,
-        )
-        fp8_amax_w_tensor = torch.tensor(
-            [child.fp8_amax_w for child in fp8_layers],
-            dtype=torch.float32,
-            device="cuda",
-            requires_grad=False,
-        )
-        fp8_amax_dL_dY_tensor = torch.tensor(
-            [child.fp8_amax_dL_dY for child in fp8_layers],
-            dtype=torch.float32,
-            device="cuda",
-            requires_grad=False,
-        )
-        dist.all_reduce(fp8_amax_x_tensor, op=dist.ReduceOp.MAX)
-        dist.all_reduce(fp8_amax_w_tensor, op=dist.ReduceOp.MAX)
-        dist.all_reduce(fp8_amax_dL_dY_tensor, op=dist.ReduceOp.MAX)
-
-    for idx in range(len(fp8_layers)):
-        child = fp8_layers[idx]
-
-        #
-        # 1. in distributed contexts, syncs amax values across workers
-        #
-        if dist.is_initialized():
-            child.fp8_amax_x = fp8_amax_x_tensor[idx].clone()
-            child.fp8_amax_w = fp8_amax_w_tensor[idx].clone()
-            child.fp8_amax_dL_dY = fp8_amax_dL_dY_tensor[idx].clone()
-
-        #
-        # 2. adds the `amax` values to history
-        #
-        _update_history_with_new_amax(child.fp8_amax_x, child.fp8_amax_history_x)
-        _update_history_with_new_amax(child.fp8_amax_w, child.fp8_amax_history_w)
-        _update_history_with_new_amax(
-            child.fp8_amax_dL_dY, child.fp8_amax_history_dL_dY
-        )
+    if len(fp8_layers) == 0:
+        log.warn(
+            "Calling sync_float8_amax_and_scale_history on a module with no Float8Linear layers"
+        )
+        return

+    # Loop over all fp8 layers and grab the needed tensors
+    fp8_amax_x_tensor_list = [None] * len(fp8_layers)
+    fp8_amax_w_tensor_list = [None] * len(fp8_layers)
+    fp8_amax_dL_dY_tensor_list = [None] * len(fp8_layers)
+
+    fp8_x_amax_history_stack = [None] * len(fp8_layers)
+    fp8_w_amax_history_stack = [None] * len(fp8_layers)
+    fp8_dL_dY_amax_history_stack = [None] * len(fp8_layers)
+
+    x_dtypes = set()
+    scale_fn_recipes = set()
+
+    for idx, child in enumerate(fp8_layers):
+        fp8_amax_x_tensor_list[idx] = child.fp8_amax_x
+        fp8_amax_w_tensor_list[idx] = child.fp8_amax_w
+        fp8_amax_dL_dY_tensor_list[idx] = child.fp8_amax_dL_dY
+
+        fp8_x_amax_history_stack[idx] = child.fp8_amax_history_x
+        fp8_w_amax_history_stack[idx] = child.fp8_amax_history_w
+        fp8_dL_dY_amax_history_stack[idx] = child.fp8_amax_history_dL_dY
+
+        x_dtypes.add(child.last_seen_input_dtype)
+        scale_fn_recipes.add(child.recipe.scale_fn_name)
+
+    # TODO This way to get the activation dtype is not ideal
+    if len(x_dtypes) != 1:
+        raise ValueError(
+            f"All layers must have the same last seen input_dtype, got {x_dtypes}"
+        )
+    x_dtype = next(iter(x_dtypes))

-        #
-        # 3. calculate the scales
-        #
-        # TODO what to do with x_dtype
-        x_dtype = child.last_seen_input_dtype
-        new_scale = amax_history_to_scale(
-            child.fp8_amax_history_x,
-            torch.float8_e4m3fn,
-            x_dtype,
-            child.recipe.scale_fn_name,
-        )
-        child.fp8_scale_x.copy_(new_scale)
-        new_scale = amax_history_to_scale(
-            child.fp8_amax_history_w,
-            torch.float8_e4m3fn,
-            x_dtype,
-            child.recipe.scale_fn_name,
-        )
-        child.fp8_scale_w.copy_(new_scale)
-        new_scale = amax_history_to_scale(
-            child.fp8_amax_history_dL_dY,
-            torch.float8_e5m2,
-            x_dtype,
-            child.recipe.scale_fn_name,
-        )
-        child.fp8_scale_dL_dY.copy_(new_scale)
+    if len(scale_fn_recipes) != 1:
+        raise ValueError(
+            f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}"
+        )
+    scale_fn_recipe = next(iter(scale_fn_recipes))
+
+    assert (
+        len(fp8_amax_x_tensor_list)
+        == len(fp8_amax_w_tensor_list)
+        == len(fp8_amax_dL_dY_tensor_list)
+    ), "Mismatched lengths of amax tensors."
+
+    if dist.is_initialized():
+        # Combine all the amax tensors into one tensor and reduce it
+        all_amax_tensors = torch.cat(
+            fp8_amax_x_tensor_list + fp8_amax_w_tensor_list + fp8_amax_dL_dY_tensor_list
+        )
+        all_reduced_amax_tensor = all_reduce(
+            all_amax_tensors, "MAX", list(range(dist.get_world_size()))
+        )
+        if isinstance(all_reduced_amax_tensor, AsyncCollectiveTensor):
+            all_reduced_amax_tensor = all_reduced_amax_tensor.wait()

[Inline review thread]

Reviewer: Aside: it seems like a bug on AsyncCollectiveTensor if it does not make sure to call wait() when you torch.split it. 🤔

bdhirsh (Contributor): Hmm... I actually tried to add an optimization to AsyncCollectiveTensor a while ago: pytorch/pytorch#105240. The thought behind it was:

(1) if we have a collective in-flight, we only need to sync the collective if the data needs to be used in an operation;

(2) view ops generally don't need to access the data of their inputs, only the metadata.

So that PR "delays" syncs when view ops (e.g. aten.split()) are called on an AsyncCollectiveTensor.

Author: Indeed, I was working through this with @wanchaol, and without the wait() the all_reduced_amax_tensor was garbage...

bdhirsh (Contributor, Feb 14, 2024): I'm curious why the manual sync here is needed - does it paper over some underlying bug? (Didn't see your last comment.) I wonder if this is related to some internal AsyncCollectiveTensor issues 🤔 Let me know if either of you want to stare at it together (I think fixing the underlying bug will help other areas too).

Reviewer: Interestingly, I was not able to repro the issue: P1184812077

_foreach_copy_ and functional all-reduce without .wait() seemed to be fine, but I might not be repro-ing correctly.

Author: Hmm, let me try again then.

Author: You are right... however:

    return self.__get_result()
  File "/home/drisspg/miniconda3/envs/dev/lib/python3.10/concurrent/futures/_base.py", line 403, in __get_result
    raise self._exception
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
CompilationError: at 10:30:def triton_(in_ptr0, in_ptr1, in_ptr2, in_ptr3, out_ptr0, out_ptr1, out_ptr3, out_ptr4, out_ptr5, out_ptr7, out_ptr8, out_ptr9, out_ptr11):
    xpid = tl.program_id(0)
    XBLOCK: tl.constexpr = 1024
    if xpid >= 0 and xpid < 1:
        xpid_offset = xpid - 0
        xnumel = 1
        xoffset = xpid_offset * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
        xmask = xindex < xnumel
        rindex = tl.arange(0, RBLOCK)[None, :]
                              ^
NameError('RBLOCK is not defined')

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Compiling the sync function now causes an inductor fault, so I will create another PR and investigate there; otherwise I think that's everything for this one.

[End thread]

+        (
+            reduced_fp8_amax_tensor,
+            reduced_fp8_amax_w_tensor,
+            reduced_fp8_amax_dL_dY_tensor,
+        ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list))
+
+        for idx, child in enumerate(fp8_layers):
+            child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx])
+            child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx])
+            child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx])
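
A single-process sketch of the flatten-reduce-split pattern above: three lists of (1,)-shaped amaxes are concatenated so one MAX all-reduce can stand in for three separate collectives, and torch.split then recovers the groups in order (the values here are made up):

```python
import torch

xs = [torch.tensor([1.0]), torch.tensor([2.0])]
ws = [torch.tensor([3.0]), torch.tensor([4.0])]
gs = [torch.tensor([5.0]), torch.tensor([6.0])]
flat = torch.cat(xs + ws + gs)              # shape (6,): one collective instead of three
x_r, w_r, g_r = torch.split(flat, len(xs))  # three (2,) chunks, in the original order
```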

+    # We create two stacked tensor groups: one for the current amaxes and one for the amax histories
+    fp8_amax_x_tensors = torch.vstack(fp8_amax_x_tensor_list)
+    fp8_amax_w_tensors = torch.vstack(fp8_amax_w_tensor_list)
+    fp8_amax_dL_dY_tensors = torch.vstack(fp8_amax_dL_dY_tensor_list)
+
+    fp8_x_amax_history_stack = torch.vstack(fp8_x_amax_history_stack)
+    fp8_w_amax_history_stack = torch.vstack(fp8_w_amax_history_stack)
+    fp8_dL_dY_amax_history_stack = torch.vstack(fp8_dL_dY_amax_history_stack)
+
+    # Update the history stacks with the new amax values
+    _update_history_stack(fp8_amax_x_tensors, fp8_x_amax_history_stack)
+    _update_history_stack(fp8_amax_w_tensors, fp8_w_amax_history_stack)
+    _update_history_stack(fp8_amax_dL_dY_tensors, fp8_dL_dY_amax_history_stack)

+    # Calculate the new scales from the updated history stacks
+    new_x_scales = amax_history_to_scale_stack(
+        fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
+    )
+    new_w_scales = amax_history_to_scale_stack(
+        fp8_w_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
+    )
+    new_dL_dY_scales = amax_history_to_scale_stack(
+        fp8_dL_dY_amax_history_stack, torch.float8_e5m2, x_dtype, scale_fn_recipe
+    )
+
+    # Iterate through the layers, update the scales, and set the flag to signal that the amaxes/scales are ready
+    for idx, child in enumerate(fp8_layers):
+        child.fp8_scale_x.copy_(new_x_scales[idx])
+        child.fp8_scale_w.copy_(new_w_scales[idx])
+        child.fp8_scale_dL_dY.copy_(new_dL_dY_scales[idx])

-        #
-        # 4. set a flag to signal amaxes/scales are ready
-        #
-        child.amax_and_scale_synced = True
+
+        # We only update the flag if we know it will be checked by the modules
+        if fp8_config.enable_amax_init:
+            child.amax_and_scale_synced = True
2 changes: 2 additions & 0 deletions float8_experimental/float8_python_api.py
@@ -12,6 +12,8 @@

 from typing import Optional, Tuple
 
+import float8_experimental.float8_aten_api  # noqa
+
 import torch
 from float8_experimental.float8_tensor import Float8Tensor

18 changes: 17 additions & 1 deletion float8_experimental/float8_utils.py
@@ -23,7 +23,7 @@

 @torch.no_grad()
 def amax_to_scale(amax, float8_dtype, orig_dtype):
-    scale = torch.empty((), device=amax.device, dtype=torch.float32)
+    scale = torch.empty_like(amax, dtype=torch.float32)
     if float8_dtype == torch.float8_e4m3fn:
         res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
     else:  # e5m2
@@ -51,6 +51,22 @@ def amax_history_to_scale(
         raise NotImplementedError()
 
 
+@torch.no_grad()
+def amax_history_to_scale_stack(
+    amax_history: torch.Tensor,
+    float8_dtype: torch.dtype,
+    orig_dtype: torch.dtype,
+    history_to_scale_fn_type: str,
+) -> torch.Tensor:
+    """Takes in a stack of amax_history tensors and returns a scale tensor."""
+    if history_to_scale_fn_type == "max":
+        amax_stack = torch.max(amax_history, dim=1).values
+        return amax_to_scale(amax_stack, float8_dtype, orig_dtype)
+    raise NotImplementedError(
+        f"Invalid history_to_scale_fn_type, only 'max' is supported. Got: {history_to_scale_fn_type}"
+    )
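
An illustrative call (shapes assumed: three layers, history length four); the row-wise max of each history becomes one scale per layer:

```python
import torch

hist_stack = torch.rand(3, 4) + 0.1  # (n_layers, history_length), positive amaxes
scales = amax_history_to_scale_stack(
    hist_stack, torch.float8_e4m3fn, torch.bfloat16, "max"
)
print(scales.shape)  # torch.Size([3])
```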


 @torch.no_grad()
 def tensor_to_amax(x, distributed_reduction=False):
     amax = torch.max(torch.abs(x))