diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index af8d89b5..0fc868d4 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -138,23 +138,24 @@ def __init__(self, *args, **kwargs):
         self.recipe = delayed_scaling_recipe
         history_len = self.recipe.history_len

-        self.register_always_float32_buffer("fp8_amax_x", torch.tensor(E4M3_MAX_POS))
+        self.register_always_float32_buffer("fp8_amax_x", torch.tensor([E4M3_MAX_POS]))
         self.register_always_float32_buffer(
             "fp8_amax_history_x", torch.zeros(history_len)
         )
-        self.register_always_float32_buffer("fp8_scale_x", torch.tensor(1.0))
-        self.register_always_float32_buffer("fp8_amax_w", torch.tensor(E4M3_MAX_POS))
+        self.register_always_float32_buffer("fp8_scale_x", torch.tensor([1.0]))
+        self.register_always_float32_buffer("fp8_amax_w", torch.tensor([E4M3_MAX_POS]))
         self.register_always_float32_buffer(
             "fp8_amax_history_w", torch.zeros(history_len)
         )
-        self.register_always_float32_buffer("fp8_scale_w", torch.tensor(1.0))
+        self.register_always_float32_buffer("fp8_scale_w", torch.tensor([1.0]))
         self.register_always_float32_buffer(
-            "fp8_amax_dL_dY", torch.tensor(E5M2_MAX_POS)
+            "fp8_amax_dL_dY", torch.tensor([E5M2_MAX_POS])
         )
         self.register_always_float32_buffer(
             "fp8_amax_history_dL_dY", torch.zeros(history_len)
         )
-        self.register_always_float32_buffer("fp8_scale_dL_dY", torch.tensor(1.0))
+        self.register_always_float32_buffer("fp8_scale_dL_dY", torch.tensor([1.0]))
+
         # Whether to emulate the fp8 matmul logic in float32
         self.emulate = False
diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py
index 3450f154..b96a4a1c 100644
--- a/float8_experimental/float8_linear_utils.py
+++ b/float8_experimental/float8_linear_utils.py
@@ -4,16 +4,23 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

 import copy
+import logging
 from enum import auto, Enum
 from typing import List, Optional, Type

+import float8_experimental.config as fp8_config
+
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
 from float8_experimental.float8_linear import Float8Linear
-from float8_experimental.float8_utils import amax_history_to_scale
+from float8_experimental.float8_utils import amax_history_to_scale_stack
+from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor
+
+log = logging.getLogger(__name__)
+log.addHandler(logging.NullHandler())


 class LinearType(Enum):
@@ -57,14 +64,26 @@ def linear_requires_sync(linear_type: LinearType):
     return linear_type in REQUIRES_SYNC


-def _update_history_with_new_amax(new_amax, amax_history):
+def _update_history_stack(
+    new_amax: torch.Tensor, amax_history_stack: torch.Tensor
+) -> torch.Tensor:
     """
     Updates `amax_history` (the last N cur_amax values) inplace with the value
     of `new_amax`.
+
+    Args:
+        new_amax (torch.Tensor): The new amax value to add to the history. (n_amaxes, 1)
+        amax_history_stack (torch.Tensor): The history of amax values. (n_amaxes, history_length)
     """
-    new_amax_history = torch.roll(amax_history, 1)
-    new_amax_history[0] = new_amax
-    amax_history.copy_(new_amax_history)
+    assert (
+        amax_history_stack.dim() == 2
+    ), f"Expected amax_history_stack to be 2D, got {amax_history_stack.shape}"
+    assert new_amax.size(0) == amax_history_stack.size(
+        0
+    ), f"Expected new_amax to have the same size as the first dimension of amax_history_stack, got {new_amax.size(0)} and {amax_history_stack.size(0)}"
+    new_amax_history_stack = torch.roll(amax_history_stack, 1, dims=1)
+    new_amax_history_stack[:, 0] = new_amax.squeeze(-1)
+    amax_history_stack.copy_(new_amax_history_stack)


 def swap_linear_with_float8_linear(
@@ -121,21 +140,20 @@ def post_order_traversal(
     return root_module


-def get_float8_layers(model: torch.nn.Module, fp8_classes=None):
-    if fp8_classes is None:
-        fp8_classes = Float8Linear
+def get_float8_layers(model: torch.nn.Module):
+    """Iterates through the model and returns all the Float8Linear layers.
+    Args:
+        model (torch.nn.Module): The model to look for Float8Linear layers in.
+    """

     # Get all fp8 layers and tensors
-    fp8_layers = [
-        child for name, child in model.named_modules() if isinstance(child, fp8_classes)
-    ]
+    fp8_layers = [child for child in model.modules() if isinstance(child, Float8Linear)]

     return fp8_layers


-def sync_float8_amax_and_scale_history(
-    model: torch.nn.Module, fp8_classes=None, fp8_layers=None
-) -> None:
+@torch.no_grad()
+def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) -> None:
     """
     Manages the float8 amax and scale bookkeeping. In detail, it does
     the following:
@@ -147,95 +165,123 @@
     TODO(future): design the UX for this (context manager, etc)

+    PERFORMANCE NOTE:
+        When you can, it is much more efficient to call get_float8_layers once at
+        the beginning of the training loop and pass the result to this function,
+        because of how this function interacts with torch.compile.
+
     Args:
         model (torch.nn.Module): The model to track amaxes for
-        fp8_classes (optional): The fp8 classes to look for in the model.
-            The default is Float8Linear.
-            When using with TP, users can pass in the customized TP classes instead.
         fp8_layers (optional): If fp8_layers are provided, fp8_classes are ignored,
             and we loop over all fp8_layers to sync and update amax scale histories.
             Users can use get_float8_layers to get all fp8 layers.
     """
-
-    # For now, this is written in a naive way to maximize code readability.
-    # TODO(future): benchmark and optimize as needed, we have combined all
-    #   the reductions into one and we can probably try other optimizatons to
-    #   make the history update faster.
-
     if fp8_layers is None:
-        fp8_layers = get_float8_layers(model, fp8_classes)
+        fp8_layers = get_float8_layers(model)

-    if dist.is_initialized():
-        fp8_amax_x_tensor = torch.tensor(
-            [child.fp8_amax_x for child in fp8_layers],
-            dtype=torch.float32,
-            device="cuda",
-            requires_grad=False,
+    if len(fp8_layers) == 0:
+        log.warn(
+            "Calling sync_float8_amax_and_scale_history on a module with no Float8Linear layers"
         )
-        fp8_amax_w_tensor = torch.tensor(
-            [child.fp8_amax_w for child in fp8_layers],
-            dtype=torch.float32,
-            device="cuda",
-            requires_grad=False,
-        )
-        fp8_amax_dL_dY_tensor = torch.tensor(
-            [child.fp8_amax_dL_dY for child in fp8_layers],
-            dtype=torch.float32,
-            device="cuda",
-            requires_grad=False,
-        )
-        dist.all_reduce(fp8_amax_x_tensor, op=dist.ReduceOp.MAX)
-        dist.all_reduce(fp8_amax_w_tensor, op=dist.ReduceOp.MAX)
-        dist.all_reduce(fp8_amax_dL_dY_tensor, op=dist.ReduceOp.MAX)
-
-    for idx in range(len(fp8_layers)):
-        child = fp8_layers[idx]
-
-        #
-        # 1. in distributed contexts, syncs amax values across workers
-        #
-        if dist.is_initialized():
-            child.fp8_amax_x = fp8_amax_x_tensor[idx].clone()
-            child.fp8_amax_w = fp8_amax_w_tensor[idx].clone()
-            child.fp8_amax_dL_dY = fp8_amax_dL_dY_tensor[idx].clone()
-
-        #
-        # 2. adds the `amax` values to history
-        #
-        _update_history_with_new_amax(child.fp8_amax_x, child.fp8_amax_history_x)
-        _update_history_with_new_amax(child.fp8_amax_w, child.fp8_amax_history_w)
-        _update_history_with_new_amax(
-            child.fp8_amax_dL_dY, child.fp8_amax_history_dL_dY
+        return
+
+    # Loop over all fp8 layers and grab the needed tensors
+    fp8_amax_x_tensor_list = [None] * len(fp8_layers)
+    fp8_amax_w_tensor_list = [None] * len(fp8_layers)
+    fp8_amax_dL_dY_tensor_list = [None] * len(fp8_layers)
+
+    fp8_x_amax_history_stack = [None] * len(fp8_layers)
+    fp8_w_amax_history_stack = [None] * len(fp8_layers)
+    fp8_dL_dY_amax_history_stack = [None] * len(fp8_layers)
+
+    x_dtypes = set()
+    scale_fn_recipes = set()
+
+    for idx, child in enumerate(fp8_layers):
+        fp8_amax_x_tensor_list[idx] = child.fp8_amax_x
+        fp8_amax_w_tensor_list[idx] = child.fp8_amax_w
+        fp8_amax_dL_dY_tensor_list[idx] = child.fp8_amax_dL_dY
+
+        fp8_x_amax_history_stack[idx] = child.fp8_amax_history_x
+        fp8_w_amax_history_stack[idx] = child.fp8_amax_history_w
+        fp8_dL_dY_amax_history_stack[idx] = child.fp8_amax_history_dL_dY
+
+        x_dtypes.add(child.last_seen_input_dtype)
+        scale_fn_recipes.add(child.recipe.scale_fn_name)
+
+    # TODO This way to get the activation dtype is not ideal
+    if len(x_dtypes) != 1:
+        raise ValueError(
+            f"All layers must have the same last seen input_dtype, got {x_dtypes}"
         )
+    x_dtype = next(iter(x_dtypes))

-        #
-        # 3. calculate the scales
-        #
-        # TODO what to do with x_dtype
-        x_dtype = child.last_seen_input_dtype
-        new_scale = amax_history_to_scale(
-            child.fp8_amax_history_x,
-            torch.float8_e4m3fn,
-            x_dtype,
-            child.recipe.scale_fn_name,
+    if len(scale_fn_recipes) != 1:
+        raise ValueError(
+            f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}"
         )
-        child.fp8_scale_x.copy_(new_scale)
-        new_scale = amax_history_to_scale(
-            child.fp8_amax_history_w,
-            torch.float8_e4m3fn,
-            x_dtype,
-            child.recipe.scale_fn_name,
+    scale_fn_recipe = next(iter(scale_fn_recipes))
+
+    assert (
+        len(fp8_amax_x_tensor_list)
+        == len(fp8_amax_w_tensor_list)
+        == len(fp8_amax_dL_dY_tensor_list)
+    ), "Mismatched lengths of amax tensors."
+
+    if dist.is_initialized():
+        # Combine all the amax tensors into one tensor and reduce it
+        all_amax_tensors = torch.cat(
+            fp8_amax_x_tensor_list + fp8_amax_w_tensor_list + fp8_amax_dL_dY_tensor_list
         )
-        child.fp8_scale_w.copy_(new_scale)
-        new_scale = amax_history_to_scale(
-            child.fp8_amax_history_dL_dY,
-            torch.float8_e5m2,
-            x_dtype,
-            child.recipe.scale_fn_name,
+        all_reduced_amax_tensor = all_reduce(
+            all_amax_tensors, "MAX", list(range(dist.get_world_size()))
         )
-        child.fp8_scale_dL_dY.copy_(new_scale)
+        if isinstance(all_reduced_amax_tensor, AsyncCollectiveTensor):
+            all_reduced_amax_tensor = all_reduced_amax_tensor.wait()
+
+        (
+            reduced_fp8_amax_tensor,
+            reduced_fp8_amax_w_tensor,
+            reduced_fp8_amax_dL_dY_tensor,
+        ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list))
+
+        for idx, child in enumerate(fp8_layers):
+            child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx])
+            child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx])
+            child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx])
+
+    # We create two stacked tensor groups, one for the current amaxes and one for the amax histories
+    fp8_amax_x_tensors = torch.vstack(fp8_amax_x_tensor_list)
+    fp8_amax_w_tensors = torch.vstack(fp8_amax_w_tensor_list)
+    fp8_amax_dL_dY_tensors = torch.vstack(fp8_amax_dL_dY_tensor_list)
+
+    fp8_x_amax_history_stack = torch.vstack(fp8_x_amax_history_stack)
+    fp8_w_amax_history_stack = torch.vstack(fp8_w_amax_history_stack)
+    fp8_dL_dY_amax_history_stack = torch.vstack(fp8_dL_dY_amax_history_stack)
+
+    # Update the history stacks with the new amax values
+    _update_history_stack(fp8_amax_x_tensors, fp8_x_amax_history_stack)
+    _update_history_stack(fp8_amax_w_tensors, fp8_w_amax_history_stack)
+    _update_history_stack(fp8_amax_dL_dY_tensors, fp8_dL_dY_amax_history_stack)
+
+    # Calculate the new scales from the updated history stacks
+    new_x_scales = amax_history_to_scale_stack(
+        fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
+    )
+    new_w_scales = amax_history_to_scale_stack(
+        fp8_w_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
+    )
+    new_dL_dY_scales = amax_history_to_scale_stack(
+        fp8_dL_dY_amax_history_stack, torch.float8_e5m2, x_dtype, scale_fn_recipe
+    )
+
+    # Iterate through the layers and update the scales, and set the flag to signal that the amaxes/scales are ready
+    for idx, child in enumerate(fp8_layers):
+        child.fp8_scale_x.copy_(new_x_scales[idx])
+        child.fp8_scale_w.copy_(new_w_scales[idx])
+        child.fp8_scale_dL_dY.copy_(new_dL_dY_scales[idx])

-        #
-        # 4. set a flag to signal amaxes/scales are ready
-        #
-        child.amax_and_scale_synced = True
+        # We only update the flag if we know it will be checked by the modules
+        if fp8_config.enable_amax_init:
+            child.amax_and_scale_synced = True
diff --git a/float8_experimental/float8_python_api.py b/float8_experimental/float8_python_api.py
index 4f670e17..9182f626 100644
--- a/float8_experimental/float8_python_api.py
+++ b/float8_experimental/float8_python_api.py
@@ -12,6 +12,8 @@

 from typing import Optional, Tuple

+import float8_experimental.float8_aten_api  # noqa
+
 import torch

 from float8_experimental.float8_tensor import Float8Tensor
diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py
index ed6de1c6..8661d3ee 100644
--- a/float8_experimental/float8_utils.py
+++ b/float8_experimental/float8_utils.py
@@ -23,7 +23,7 @@

 @torch.no_grad()
 def amax_to_scale(amax, float8_dtype, orig_dtype):
-    scale = torch.empty((), device=amax.device, dtype=torch.float32)
+    scale = torch.empty_like(amax, dtype=torch.float32)
     if float8_dtype == torch.float8_e4m3fn:
         res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
     else:  # e5m2
@@ -51,6 +51,22 @@
     raise NotImplementedError()


+@torch.no_grad()
+def amax_history_to_scale_stack(
+    amax_history: torch.Tensor,
+    float8_dtype: torch.dtype,
+    orig_dtype: torch.dtype,
+    history_to_scale_fn_type: str,
+) -> torch.Tensor:
+    """Takes in a stack of amax_history tensors and returns a scale tensor."""
+    if history_to_scale_fn_type == "max":
+        amax_stack = torch.max(amax_history, dim=1).values
+        return amax_to_scale(amax_stack, float8_dtype, orig_dtype)
+    raise NotImplementedError(
+        f"Invalid history_to_scale_fn_type, only 'max' is supported. Got: {history_to_scale_fn_type}"
+    )
+
+
 @torch.no_grad()
 def tensor_to_amax(x, distributed_reduction=False):
     amax = torch.max(torch.abs(x))
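
A minimal usage sketch of the refactored API, under stated assumptions: the swap call's signature is the one this library version appears to use and may differ in yours; the model, sizes, optimizer, and step count are placeholders; and the non-emulated fp8 matmul additionally needs a GPU with float8 support. Per the PERFORMANCE NOTE above, get_float8_layers is called once before the loop and the cached list is passed to every sync call:

import torch

from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
    get_float8_layers,
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)

# Placeholder model/optimizer; any module containing nn.Linear layers works.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64)).cuda()
# Assumed signature (module, module_cls); adjust for your version of the library.
swap_linear_with_float8_linear(model, Float8Linear)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Collect the Float8Linear layers once, outside the training loop, so the
# per-step sync does not have to re-traverse the module tree.
fp8_layers = get_float8_layers(model)

for _ in range(10):
    x = torch.randn(32, 64, device="cuda")
    loss = model(x).sum()
    loss.backward()
    # One sync point per step: all-reduces the amaxes across ranks (when
    # torch.distributed is initialized), pushes them into the per-layer
    # histories, and recomputes fp8_scale_x / _w / _dL_dY for every layer.
    sync_float8_amax_and_scale_history(model, fp8_layers=fp8_layers)
    optimizer.step()
    optimizer.zero_grad()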
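
To make the stacked bookkeeping concrete, here is a small self-contained sketch of what _update_history_stack and the "max" branch of amax_history_to_scale_stack compute. The numbers are illustrative; E4M3_MAX_POS = 448.0 matches torch.float8_e4m3fn, and the EPS below merely stands in for the library's clamp constant:

import torch

n_layers, history_len = 3, 4

# One amax-history row per Float8Linear layer, plus the newest amax per layer
# with shape (n_layers, 1), matching the buffer shapes registered above.
amax_history_stack = torch.zeros(n_layers, history_len)
new_amax = torch.tensor([[0.5], [2.0], [8.0]])

# The same inplace update that _update_history_stack performs:
rolled = torch.roll(amax_history_stack, 1, dims=1)
rolled[:, 0] = new_amax.squeeze(-1)
amax_history_stack.copy_(rolled)

# The "max" recipe from amax_history_to_scale_stack: reduce each history row,
# then convert amax -> scale for every layer at once (e4m3 path shown).
E4M3_MAX_POS, EPS = 448.0, 1e-12
amax_per_layer = torch.max(amax_history_stack, dim=1).values
scales = E4M3_MAX_POS / torch.clamp(amax_per_layer, min=EPS)
print(scales)  # -> [896., 224., 56.]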
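
Finally, a shape-only sketch of why the per-layer amax buffers became one-element tensors (torch.tensor([E4M3_MAX_POS]) instead of 0-d scalars): torch.cat cannot concatenate 0-d tensors, and the new layout lets the x/w/dL_dY amaxes of all layers be flattened into one tensor, reduced with a single collective, and split back per group. The collective itself is omitted here and the values are made up:

import torch

n_layers = 3

# Stand-ins for child.fp8_amax_x / _w / _dL_dY: one 1-element buffer per layer.
amax_x = [torch.tensor([1.0]), torch.tensor([2.0]), torch.tensor([3.0])]
amax_w = [torch.tensor([4.0]), torch.tensor([5.0]), torch.tensor([6.0])]
amax_dL_dY = [torch.tensor([7.0]), torch.tensor([8.0]), torch.tensor([9.0])]

# One flat tensor -> a single all_reduce(..., "MAX", ...) in the real code.
all_amax_tensors = torch.cat(amax_x + amax_w + amax_dL_dY)

# torch.split with an int yields equal chunks of that size, so splitting by
# len(amax_x) == n_layers recovers the x / w / dL_dY groups in order.
reduced_x, reduced_w, reduced_dL_dY = torch.split(all_amax_tensors, n_layers)
print(reduced_w)  # -> [4., 5., 6.]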