From e46b137d687e5265d4b7844bc3261df1ab4f89d9 Mon Sep 17 00:00:00 2001 From: drisspg Date: Mon, 12 Feb 2024 13:42:05 -0800 Subject: [PATCH 01/14] update signature and add not about usage --- float8_experimental/float8_linear_utils.py | 33 +++++++++++++--------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index 3450f154..13875ae0 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -7,6 +7,8 @@ from enum import auto, Enum from typing import List, Optional, Type +import float8_experimental.config as fp8_config + import torch import torch.distributed as dist import torch.nn as nn @@ -121,21 +123,23 @@ def post_order_traversal( return root_module -def get_float8_layers(model: torch.nn.Module, fp8_classes=None): - if fp8_classes is None: - fp8_classes = Float8Linear +def get_float8_layers(model: torch.nn.Module): + """Iterates through the model and returns all the Float8Linear layers. + Args: + model (torch.nn.Module): The model to look for Float8Linear layers in. + """ # Get all fp8 layers and tensors fp8_layers = [ - child for name, child in model.named_modules() if isinstance(child, fp8_classes) + child + for name, child in model.named_modules() + if isinstance(child, Float8Linear) ] return fp8_layers -def sync_float8_amax_and_scale_history( - model: torch.nn.Module, fp8_classes=None, fp8_layers=None -) -> None: +def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) -> None: """ Manages the float8 amax and scale bookkeeping. In detail, it does the following: @@ -147,11 +151,13 @@ def sync_float8_amax_and_scale_history( TODO(future): design the UX for this (context manager, etc) + PERFORMANCE NOTE: + When you can it is much more efficient to call te get_float8_layers once a + the beginning of the training loop and pass the result to this function. + Because of how this interacts with torch.compile + Args: model (torch.nn.Module): The model to track amaxes for - fp8_classes (optional): The fp8 classes to look for in the model. - The default is Float8Linear. - When using with TP, users can pass in the customized TP classes instead. fp8_layers (optional): If fp8_layers are provided, fp8_classes are ignored, and we loop over all fp8_layers to sync and update amax scale histories. Users can use get_float8_layers to get all fp8 layers. @@ -163,7 +169,7 @@ def sync_float8_amax_and_scale_history( # make the history update faster. if fp8_layers is None: - fp8_layers = get_float8_layers(model, fp8_classes) + fp8_layers = get_float8_layers(model) if dist.is_initialized(): fp8_amax_x_tensor = torch.tensor( @@ -237,5 +243,6 @@ def sync_float8_amax_and_scale_history( # # 4. 
set a flag to signal amaxes/scales are ready - # - child.amax_and_scale_synced = True + # We only update the flag if we know it will be checked by the modules + if fp8_config.enable_amax_init: + child.amax_and_scale_synced = True From 5d479362a058b57638a8e353c991e4f0ff1abf13 Mon Sep 17 00:00:00 2001 From: drisspg Date: Mon, 12 Feb 2024 16:14:27 -0800 Subject: [PATCH 02/14] update to make more comple friendly --- float8_experimental/float8_linear.py | 13 ++--- float8_experimental/float8_linear_utils.py | 58 +++++++++++----------- float8_experimental/float8_python_api.py | 2 + float8_experimental/float8_tensor.py | 4 +- float8_experimental/float8_utils.py | 2 +- 5 files changed, 41 insertions(+), 38 deletions(-) diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index af8d89b5..0fc868d4 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -138,23 +138,24 @@ def __init__(self, *args, **kwargs): self.recipe = delayed_scaling_recipe history_len = self.recipe.history_len - self.register_always_float32_buffer("fp8_amax_x", torch.tensor(E4M3_MAX_POS)) + self.register_always_float32_buffer("fp8_amax_x", torch.tensor([E4M3_MAX_POS])) self.register_always_float32_buffer( "fp8_amax_history_x", torch.zeros(history_len) ) - self.register_always_float32_buffer("fp8_scale_x", torch.tensor(1.0)) - self.register_always_float32_buffer("fp8_amax_w", torch.tensor(E4M3_MAX_POS)) + self.register_always_float32_buffer("fp8_scale_x", torch.tensor([1.0])) + self.register_always_float32_buffer("fp8_amax_w", torch.tensor([E4M3_MAX_POS])) self.register_always_float32_buffer( "fp8_amax_history_w", torch.zeros(history_len) ) - self.register_always_float32_buffer("fp8_scale_w", torch.tensor(1.0)) + self.register_always_float32_buffer("fp8_scale_w", torch.tensor([1.0])) self.register_always_float32_buffer( - "fp8_amax_dL_dY", torch.tensor(E5M2_MAX_POS) + "fp8_amax_dL_dY", torch.tensor([E5M2_MAX_POS]) ) self.register_always_float32_buffer( "fp8_amax_history_dL_dY", torch.zeros(history_len) ) - self.register_always_float32_buffer("fp8_scale_dL_dY", torch.tensor(1.0)) + self.register_always_float32_buffer("fp8_scale_dL_dY", torch.tensor([1.0])) + # Whether to emulate the fp8 matmul logic in float32 self.emulate = False diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index 13875ae0..9f33ee29 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -172,39 +172,37 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) fp8_layers = get_float8_layers(model) if dist.is_initialized(): - fp8_amax_x_tensor = torch.tensor( - [child.fp8_amax_x for child in fp8_layers], - dtype=torch.float32, - device="cuda", - requires_grad=False, - ) - fp8_amax_w_tensor = torch.tensor( - [child.fp8_amax_w for child in fp8_layers], - dtype=torch.float32, - device="cuda", - requires_grad=False, - ) - fp8_amax_dL_dY_tensor = torch.tensor( - [child.fp8_amax_dL_dY for child in fp8_layers], - dtype=torch.float32, - device="cuda", - requires_grad=False, - ) - dist.all_reduce(fp8_amax_x_tensor, op=dist.ReduceOp.MAX) - dist.all_reduce(fp8_amax_w_tensor, op=dist.ReduceOp.MAX) - dist.all_reduce(fp8_amax_dL_dY_tensor, op=dist.ReduceOp.MAX) - + fp8_amax_x_tensors = [child.fp8_amax_x for child in fp8_layers] + fp8_amax_w_tensors = [child.fp8_amax_w for child in fp8_layers] + fp8_amax_dL_dY_tensors = [child.fp8_amax_dL_dY for child in fp8_layers] + + 
assert ( + len(fp8_amax_x_tensors) + == len(fp8_amax_w_tensors) + == len(fp8_amax_dL_dY_tensors) + ), "Mismatched lengths of amax tensors." + if len(fp8_amax_x_tensors) > 0: + # Combine all the amax tensors into one tensor and reduce it + fp8_amax_x_tensor = torch.cat(fp8_amax_x_tensors) + fp8_amax_w_tensor = torch.cat(fp8_amax_w_tensors) + fp8_amax_dL_dY_tensor = torch.cat(fp8_amax_dL_dY_tensors) + + dist.all_reduce(fp8_amax_x_tensor, op=dist.ReduceOp.MAX) + dist.all_reduce(fp8_amax_w_tensor, op=dist.ReduceOp.MAX) + dist.all_reduce(fp8_amax_dL_dY_tensor, op=dist.ReduceOp.MAX) + + # Reassign the reduced amax values to the original tensors + + for idx in range(len(fp8_layers)): + child = fp8_layers[idx] + child.fp8_amax_x.copy_(fp8_amax_x_tensor[idx].clone()) + child.fp8_amax_w.copy_(fp8_amax_w_tensor[idx].clone()) + child.fp8_amax_dL_dY.copy_(fp8_amax_dL_dY_tensor[idx].clone()) + + # Itearte over all the layers and update the amax history and scales for idx in range(len(fp8_layers)): child = fp8_layers[idx] - # - # 1. in distributed contexts, syncs amax values across workers - # - if dist.is_initialized(): - child.fp8_amax_x = fp8_amax_x_tensor[idx].clone() - child.fp8_amax_w = fp8_amax_w_tensor[idx].clone() - child.fp8_amax_dL_dY = fp8_amax_dL_dY_tensor[idx].clone() - # # 2. adds the `amax` values to history # diff --git a/float8_experimental/float8_python_api.py b/float8_experimental/float8_python_api.py index 4f670e17..9182f626 100644 --- a/float8_experimental/float8_python_api.py +++ b/float8_experimental/float8_python_api.py @@ -12,6 +12,8 @@ from typing import Optional, Tuple +import float8_experimental.float8_aten_api # noqa + import torch from float8_experimental.float8_tensor import Float8Tensor diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index 6063f7c1..d88d3f73 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -54,7 +54,9 @@ class FromFloat8ConstrFunc(torch.autograd.Function): @staticmethod def forward(ctx, tensor): - return tensor._data.to(tensor._orig_dtype) / tensor._scale + return (tensor._data.to(tensor._orig_dtype) / tensor._scale).to( + tensor._orig_dtype + ) @staticmethod def backward(ctx, g): diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index ed6de1c6..5ecb7bba 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -23,7 +23,7 @@ @torch.no_grad() def amax_to_scale(amax, float8_dtype, orig_dtype): - scale = torch.empty((), device=amax.device, dtype=torch.float32) + scale = torch.empty((1,), device=amax.device, dtype=torch.float32) if float8_dtype == torch.float8_e4m3fn: res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) else: # e5m2 From ab5043e05416e44662d92e42f81e4158b029d0a6 Mon Sep 17 00:00:00 2001 From: drisspg Date: Mon, 12 Feb 2024 19:30:22 -0800 Subject: [PATCH 03/14] combine into larger tensor operations --- float8_experimental/float8_linear_utils.py | 173 ++++++++++++--------- float8_experimental/float8_utils.py | 18 ++- 2 files changed, 118 insertions(+), 73 deletions(-) diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index 9f33ee29..d2b76fdf 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. 
import copy +import logging from enum import auto, Enum from typing import List, Optional, Type @@ -15,7 +16,10 @@ from float8_experimental.float8_dynamic_linear import Float8DynamicLinear from float8_experimental.float8_linear import Float8Linear -from float8_experimental.float8_utils import amax_history_to_scale +from float8_experimental.float8_utils import amax_history_to_scale_stack + +log = logging.getLogger(__name__) +log.addHandler(logging.NullHandler()) class LinearType(Enum): @@ -69,6 +73,26 @@ def _update_history_with_new_amax(new_amax, amax_history): amax_history.copy_(new_amax_history) +def _update_history_stack( + new_amax: torch.Tensor, amax_history_stack: torch.Tensor +) -> torch.Tensor: + """ + Updates `amax_history` (the last N cur_amax values) inplace with the value + of `new_amax`. + + Args: + new_amax (torch.Tensor): The new amax value to add to the history. (n_amaxes, 1) + amax_history_stack (torch.Tensor): The history of amax values. (n_amaxes, history_length) + """ + assert amax_history_stack.dim() == 2, "amax_history_stack must be 2D" + assert new_amax.size(0) == amax_history_stack.size( + 0 + ), "new_amax must have the same size as the second dimension of amax_history_stack" + new_amax_history_stack = torch.roll(amax_history_stack, 1, dims=1) + new_amax_history_stack[:, 0] = new_amax.squeeze(-1) + amax_history_stack.copy_(new_amax_history_stack) + + def swap_linear_with_float8_linear( module: nn.Module, module_cls: Type[nn.Module], @@ -162,84 +186,89 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) and we loop over all fp8_layers to sync and update amax scale histories. Users can use get_float8_layers to get all fp8 layers. """ - - # For now, this is written in a naive way to maximize code readability. - # TODO(future): benchmark and optimize as needed, we have combined all - # the reductions into one and we can probably try other optimizatons to - # make the history update faster. - if fp8_layers is None: fp8_layers = get_float8_layers(model) + if len(fp8_layers) == 0: + log.warn( + "Calling sync_float8_amax_and_scale_history on a module with no Float8Linear layers" + ) + return + + fp8_amax_x_tensor_list = [child.fp8_amax_x for child in fp8_layers] + fp8_amax_w_tensor_list = [child.fp8_amax_w for child in fp8_layers] + fp8_amax_dL_dY_tensor_list = [child.fp8_amax_dL_dY for child in fp8_layers] + + assert ( + len(fp8_amax_x_tensor_list) + == len(fp8_amax_w_tensor_list) + == len(fp8_amax_dL_dY_tensor_list) + ), "Mismatched lengths of amax tensors." + if dist.is_initialized(): - fp8_amax_x_tensors = [child.fp8_amax_x for child in fp8_layers] - fp8_amax_w_tensors = [child.fp8_amax_w for child in fp8_layers] - fp8_amax_dL_dY_tensors = [child.fp8_amax_dL_dY for child in fp8_layers] - - assert ( - len(fp8_amax_x_tensors) - == len(fp8_amax_w_tensors) - == len(fp8_amax_dL_dY_tensors) - ), "Mismatched lengths of amax tensors." 
- if len(fp8_amax_x_tensors) > 0: - # Combine all the amax tensors into one tensor and reduce it - fp8_amax_x_tensor = torch.cat(fp8_amax_x_tensors) - fp8_amax_w_tensor = torch.cat(fp8_amax_w_tensors) - fp8_amax_dL_dY_tensor = torch.cat(fp8_amax_dL_dY_tensors) - - dist.all_reduce(fp8_amax_x_tensor, op=dist.ReduceOp.MAX) - dist.all_reduce(fp8_amax_w_tensor, op=dist.ReduceOp.MAX) - dist.all_reduce(fp8_amax_dL_dY_tensor, op=dist.ReduceOp.MAX) - - # Reassign the reduced amax values to the original tensors - - for idx in range(len(fp8_layers)): - child = fp8_layers[idx] - child.fp8_amax_x.copy_(fp8_amax_x_tensor[idx].clone()) - child.fp8_amax_w.copy_(fp8_amax_w_tensor[idx].clone()) - child.fp8_amax_dL_dY.copy_(fp8_amax_dL_dY_tensor[idx].clone()) - - # Itearte over all the layers and update the amax history and scales - for idx in range(len(fp8_layers)): - child = fp8_layers[idx] + # Combine all the amax tensors into one tensor and reduce it + fp8_amax_x_tensor = torch.cat(fp8_amax_x_tensor_list) + fp8_amax_w_tensor = torch.cat(fp8_amax_w_tensor_list) + fp8_amax_dL_dY_tensor = torch.cat(fp8_amax_dL_dY_tensor_list) + + dist.all_reduce(fp8_amax_x_tensor, op=dist.ReduceOp.MAX) + dist.all_reduce(fp8_amax_w_tensor, op=dist.ReduceOp.MAX) + dist.all_reduce(fp8_amax_dL_dY_tensor, op=dist.ReduceOp.MAX) + + # Reassign the reduced amax values to the original tensors + for idx in range(len(fp8_layers)): + child = fp8_layers[idx] + # Do we need this extra clone? + child.fp8_amax_x.copy_(fp8_amax_x_tensor[idx].clone()) + child.fp8_amax_w.copy_(fp8_amax_w_tensor[idx].clone()) + child.fp8_amax_dL_dY.copy_(fp8_amax_dL_dY_tensor[idx].clone()) + + # We create two stacked tensors, one for the amax history and one for the current scales + fp8_amax_x_tensors = torch.vstack(fp8_amax_x_tensor_list) + fp8_amax_w_tensors = torch.vstack(fp8_amax_w_tensor_list) + fp8_amax_dL_dY_tensors = torch.vstack(fp8_amax_dL_dY_tensor_list) + + fp8_x_amax_history_stack = torch.stack( + [child.fp8_amax_history_x for child in fp8_layers] + ) + fp8_w_amax_history_stack = torch.stack( + [child.fp8_amax_history_w for child in fp8_layers] + ) + fp8_dL_dY_amax_history_stack = torch.stack( + [child.fp8_amax_history_dL_dY for child in fp8_layers] + ) - # - # 2. adds the `amax` values to history - # - _update_history_with_new_amax(child.fp8_amax_x, child.fp8_amax_history_x) - _update_history_with_new_amax(child.fp8_amax_w, child.fp8_amax_history_w) - _update_history_with_new_amax( - child.fp8_amax_dL_dY, child.fp8_amax_history_dL_dY - ) + _update_history_stack(fp8_amax_x_tensors, fp8_x_amax_history_stack) + _update_history_stack(fp8_amax_w_tensors, fp8_w_amax_history_stack) + _update_history_stack(fp8_amax_dL_dY_tensors, fp8_dL_dY_amax_history_stack) - # - # 3. 
calculate the scales - # - # TODO what to do with x_dtype - x_dtype = child.last_seen_input_dtype - new_scale = amax_history_to_scale( - child.fp8_amax_history_x, - torch.float8_e4m3fn, - x_dtype, - child.recipe.scale_fn_name, - ) - child.fp8_scale_x.copy_(new_scale) - new_scale = amax_history_to_scale( - child.fp8_amax_history_w, - torch.float8_e4m3fn, - x_dtype, - child.recipe.scale_fn_name, - ) - child.fp8_scale_w.copy_(new_scale) - new_scale = amax_history_to_scale( - child.fp8_amax_history_dL_dY, - torch.float8_e5m2, - x_dtype, - child.recipe.scale_fn_name, - ) - child.fp8_scale_dL_dY.copy_(new_scale) + # TODO This way to get the activation dtype is not ideal + x_dtypes = {child.last_seen_input_dtype for child in fp8_layers} + assert len(x_dtypes) == 1, "All layers must have the same last seen input_dtype" + x_dtype = next(iter(x_dtypes)) + + scale_fn_recipes = {child.recipe.scale_fn_name for child in fp8_layers} + assert len(scale_fn_recipes) == 1, "All layers must have the same scale_fn recipe" + scale_fn_recipe = next(iter(scale_fn_recipes)) + + # We are not reading the + new_x_scales = amax_history_to_scale_stack( + fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe + ) + new_w_scales = amax_history_to_scale_stack( + fp8_w_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe + ) + new_dL_dY_scales = amax_history_to_scale_stack( + fp8_dL_dY_amax_history_stack, torch.float8_e5m2, x_dtype, scale_fn_recipe + ) + + # Iterate through the layers and update the scales, and set the flag to signal that the amaxes/scales are ready + for idx in range(len(fp8_layers)): + child = fp8_layers[idx] + child.fp8_scale_x.copy_(new_x_scales[idx]) + child.fp8_scale_w.copy_(new_w_scales[idx]) + child.fp8_scale_dL_dY.copy_(new_dL_dY_scales[idx]) - # # 4. set a flag to signal amaxes/scales are ready # We only update the flag if we know it will be checked by the modules if fp8_config.enable_amax_init: diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index 5ecb7bba..bd1d9f26 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -23,7 +23,7 @@ @torch.no_grad() def amax_to_scale(amax, float8_dtype, orig_dtype): - scale = torch.empty((1,), device=amax.device, dtype=torch.float32) + scale = torch.empty_like(amax, dtype=torch.float32) if float8_dtype == torch.float8_e4m3fn: res = E4M3_MAX_POS / torch.clamp(amax, min=EPS) else: # e5m2 @@ -51,6 +51,22 @@ def amax_history_to_scale( raise NotImplementedError() +@torch.no_grad() +def amax_history_to_scale_stack( + amax_history: torch.Tensor, + float8_dtype: torch.dtype, + orig_dtype: torch.dtype, + history_to_scale_fn_type: str, +) -> torch.Tensor: + """Takes in a stack of amax_history tensors and returns a scale tensor.""" + if history_to_scale_fn_type == "max": + amax_stack = torch.max(amax_history, dim=1).values + return amax_to_scale(amax_stack, float8_dtype, orig_dtype) + raise NotImplementedError( + "Invalid history_to_scale_fn_type, only 'max' is supported." 
+ ) + + @torch.no_grad() def tensor_to_amax(x, distributed_reduction=False): amax = torch.max(torch.abs(x)) From 745d73aa2371446315bff977dd3fd41e101be5ab Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 13 Feb 2024 09:29:17 -0800 Subject: [PATCH 04/14] Use less loops --- float8_experimental/float8_linear_utils.py | 71 ++++++++++++---------- 1 file changed, 39 insertions(+), 32 deletions(-) diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index d2b76fdf..66d73e03 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -155,9 +155,7 @@ def get_float8_layers(model: torch.nn.Module): # Get all fp8 layers and tensors fp8_layers = [ - child - for name, child in model.named_modules() - if isinstance(child, Float8Linear) + child for _, child in model.named_modules() if isinstance(child, Float8Linear) ] return fp8_layers @@ -195,9 +193,36 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) ) return - fp8_amax_x_tensor_list = [child.fp8_amax_x for child in fp8_layers] - fp8_amax_w_tensor_list = [child.fp8_amax_w for child in fp8_layers] - fp8_amax_dL_dY_tensor_list = [child.fp8_amax_dL_dY for child in fp8_layers] + # Loop over all fp8 layers and grab the needed tensors + fp8_amax_x_tensor_list = [] + fp8_amax_w_tensor_list = [] + fp8_amax_dL_dY_tensor_list = [] + + fp8_x_amax_history_stack = [] + fp8_w_amax_history_stack = [] + fp8_dL_dY_amax_history_stack = [] + + x_dtypes = set() + scale_fn_recipes = set() + + for child in fp8_layers: + fp8_amax_x_tensor_list.append(child.fp8_amax_x) + fp8_amax_w_tensor_list.append(child.fp8_amax_w) + fp8_amax_dL_dY_tensor_list.append(child.fp8_amax_dL_dY) + + fp8_x_amax_history_stack.append(child.fp8_amax_history_x) + fp8_w_amax_history_stack.append(child.fp8_amax_history_w) + fp8_dL_dY_amax_history_stack.append(child.fp8_amax_history_dL_dY) + + x_dtypes.add(child.last_seen_input_dtype) + scale_fn_recipes.add(child.recipe.scale_fn_name) + + # TODO This way to get the activation dtype is not ideal + assert len(x_dtypes) == 1, "All layers must have the same last seen input_dtype" + x_dtype = next(iter(x_dtypes)) + + assert len(scale_fn_recipes) == 1, "All layers must have the same scale_fn recipe" + scale_fn_recipe = next(iter(scale_fn_recipes)) assert ( len(fp8_amax_x_tensor_list) @@ -216,41 +241,24 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) dist.all_reduce(fp8_amax_dL_dY_tensor, op=dist.ReduceOp.MAX) # Reassign the reduced amax values to the original tensors - for idx in range(len(fp8_layers)): - child = fp8_layers[idx] - # Do we need this extra clone? 
- child.fp8_amax_x.copy_(fp8_amax_x_tensor[idx].clone()) - child.fp8_amax_w.copy_(fp8_amax_w_tensor[idx].clone()) - child.fp8_amax_dL_dY.copy_(fp8_amax_dL_dY_tensor[idx].clone()) + for idx, child in enumerate(fp8_layers): + child.fp8_amax_x.copy_(fp8_amax_x_tensor[idx]) + child.fp8_amax_w.copy_(fp8_amax_w_tensor[idx]) + child.fp8_amax_dL_dY.copy_(fp8_amax_dL_dY_tensor[idx]) # We create two stacked tensors, one for the amax history and one for the current scales fp8_amax_x_tensors = torch.vstack(fp8_amax_x_tensor_list) fp8_amax_w_tensors = torch.vstack(fp8_amax_w_tensor_list) fp8_amax_dL_dY_tensors = torch.vstack(fp8_amax_dL_dY_tensor_list) - fp8_x_amax_history_stack = torch.stack( - [child.fp8_amax_history_x for child in fp8_layers] - ) - fp8_w_amax_history_stack = torch.stack( - [child.fp8_amax_history_w for child in fp8_layers] - ) - fp8_dL_dY_amax_history_stack = torch.stack( - [child.fp8_amax_history_dL_dY for child in fp8_layers] - ) + fp8_x_amax_history_stack = torch.vstack(fp8_x_amax_history_stack) + fp8_w_amax_history_stack = torch.vstack(fp8_w_amax_history_stack) + fp8_dL_dY_amax_history_stack = torch.vstack(fp8_dL_dY_amax_history_stack) _update_history_stack(fp8_amax_x_tensors, fp8_x_amax_history_stack) _update_history_stack(fp8_amax_w_tensors, fp8_w_amax_history_stack) _update_history_stack(fp8_amax_dL_dY_tensors, fp8_dL_dY_amax_history_stack) - # TODO This way to get the activation dtype is not ideal - x_dtypes = {child.last_seen_input_dtype for child in fp8_layers} - assert len(x_dtypes) == 1, "All layers must have the same last seen input_dtype" - x_dtype = next(iter(x_dtypes)) - - scale_fn_recipes = {child.recipe.scale_fn_name for child in fp8_layers} - assert len(scale_fn_recipes) == 1, "All layers must have the same scale_fn recipe" - scale_fn_recipe = next(iter(scale_fn_recipes)) - # We are not reading the new_x_scales = amax_history_to_scale_stack( fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe @@ -263,8 +271,7 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) ) # Iterate through the layers and update the scales, and set the flag to signal that the amaxes/scales are ready - for idx in range(len(fp8_layers)): - child = fp8_layers[idx] + for idx, child in enumerate(fp8_layers): child.fp8_scale_x.copy_(new_x_scales[idx]) child.fp8_scale_w.copy_(new_w_scales[idx]) child.fp8_scale_dL_dY.copy_(new_dL_dY_scales[idx]) From 09a52a922f7ba6fb09da9e4e60b701da843697b1 Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 13 Feb 2024 10:00:54 -0800 Subject: [PATCH 05/14] use functional reduce --- float8_experimental/float8_linear_utils.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index 66d73e03..0f7f4504 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -17,6 +17,8 @@ from float8_experimental.float8_linear import Float8Linear from float8_experimental.float8_utils import amax_history_to_scale_stack +from torch.distributed._functional_collectives import all_reduce +from torch.distributed.distributed_c10d import _get_default_group log = logging.getLogger(__name__) log.addHandler(logging.NullHandler()) @@ -236,15 +238,15 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) fp8_amax_w_tensor = torch.cat(fp8_amax_w_tensor_list) fp8_amax_dL_dY_tensor = torch.cat(fp8_amax_dL_dY_tensor_list) - dist.all_reduce(fp8_amax_x_tensor, 
op=dist.ReduceOp.MAX) - dist.all_reduce(fp8_amax_w_tensor, op=dist.ReduceOp.MAX) - dist.all_reduce(fp8_amax_dL_dY_tensor, op=dist.ReduceOp.MAX) + reduced_fp8_amax_tensor = all_reduce(fp8_amax_x_tensor, "MAX", _get_default_group()) + reduced_fp8_amax_w_tensor = all_reduce(fp8_amax_w_tensor, "MAX", _get_default_group()) + reduced_fp8_amax_dL_dY_tensor = all_reduce(fp8_amax_dL_dY_tensor, "MAX", _get_default_group()) # Reassign the reduced amax values to the original tensors for idx, child in enumerate(fp8_layers): - child.fp8_amax_x.copy_(fp8_amax_x_tensor[idx]) - child.fp8_amax_w.copy_(fp8_amax_w_tensor[idx]) - child.fp8_amax_dL_dY.copy_(fp8_amax_dL_dY_tensor[idx]) + child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx]) + child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx]) + child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx]) # We create two stacked tensors, one for the amax history and one for the current scales fp8_amax_x_tensors = torch.vstack(fp8_amax_x_tensor_list) From bca21218494635b51eeb6aa4d0056babf622933c Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 13 Feb 2024 10:39:05 -0800 Subject: [PATCH 06/14] use one reduce intead of 3 --- float8_experimental/float8_linear_utils.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index 0f7f4504..ae306ac9 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -234,15 +234,17 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) if dist.is_initialized(): # Combine all the amax tensors into one tensor and reduce it - fp8_amax_x_tensor = torch.cat(fp8_amax_x_tensor_list) - fp8_amax_w_tensor = torch.cat(fp8_amax_w_tensor_list) - fp8_amax_dL_dY_tensor = torch.cat(fp8_amax_dL_dY_tensor_list) - - reduced_fp8_amax_tensor = all_reduce(fp8_amax_x_tensor, "MAX", _get_default_group()) - reduced_fp8_amax_w_tensor = all_reduce(fp8_amax_w_tensor, "MAX", _get_default_group()) - reduced_fp8_amax_dL_dY_tensor = all_reduce(fp8_amax_dL_dY_tensor, "MAX", _get_default_group()) - - # Reassign the reduced amax values to the original tensors + all_amax_tensors = torch.cat( + fp8_amax_x_tensor_list + fp8_amax_w_tensor_list + fp8_amax_dL_dY_tensor_list + ) + all_reduced_amax_tensor = all_reduce( + all_amax_tensors, "MAX", _get_default_group() + ) + ( + reduced_fp8_amax_tensor, + reduced_fp8_amax_w_tensor, + reduced_fp8_amax_dL_dY_tensor, + ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list)) for idx, child in enumerate(fp8_layers): child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx]) child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx]) From 2a306b8c5b29a19ffcc13637375607fcfec802b8 Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 13 Feb 2024 11:17:57 -0800 Subject: [PATCH 07/14] add comment on trying foreach --- float8_experimental/float8_linear_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index ae306ac9..5a6dce51 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -245,6 +245,8 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) reduced_fp8_amax_w_tensor, reduced_fp8_amax_dL_dY_tensor, ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list)) + + # TODO foreach is not supported with AsyncCollectiveTensor for idx, child in 
enumerate(fp8_layers): child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx]) child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx]) From 1eec9ff69552599e7c84858fd74c60e636738382 Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 13 Feb 2024 12:30:54 -0800 Subject: [PATCH 08/14] preallocate lists --- float8_experimental/float8_linear_utils.py | 26 +++++++++++----------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index 5a6dce51..eac859ae 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -196,25 +196,25 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) return # Loop over all fp8 layers and grab the needed tensors - fp8_amax_x_tensor_list = [] - fp8_amax_w_tensor_list = [] - fp8_amax_dL_dY_tensor_list = [] + fp8_amax_x_tensor_list = [None] * len(fp8_layers) + fp8_amax_w_tensor_list = [None] * len(fp8_layers) + fp8_amax_dL_dY_tensor_list = [None] * len(fp8_layers) - fp8_x_amax_history_stack = [] - fp8_w_amax_history_stack = [] - fp8_dL_dY_amax_history_stack = [] + fp8_x_amax_history_stack = [None] * len(fp8_layers) + fp8_w_amax_history_stack = [None] * len(fp8_layers) + fp8_dL_dY_amax_history_stack = [None] * len(fp8_layers) x_dtypes = set() scale_fn_recipes = set() - for child in fp8_layers: - fp8_amax_x_tensor_list.append(child.fp8_amax_x) - fp8_amax_w_tensor_list.append(child.fp8_amax_w) - fp8_amax_dL_dY_tensor_list.append(child.fp8_amax_dL_dY) + for idx, child in enumerate(fp8_layers): + fp8_amax_x_tensor_list[idx] = child.fp8_amax_x + fp8_amax_w_tensor_list[idx] = child.fp8_amax_w + fp8_amax_dL_dY_tensor_list[idx] = child.fp8_amax_dL_dY - fp8_x_amax_history_stack.append(child.fp8_amax_history_x) - fp8_w_amax_history_stack.append(child.fp8_amax_history_w) - fp8_dL_dY_amax_history_stack.append(child.fp8_amax_history_dL_dY) + fp8_x_amax_history_stack[idx] = child.fp8_amax_history_x + fp8_w_amax_history_stack[idx] = child.fp8_amax_history_w + fp8_dL_dY_amax_history_stack[idx] = child.fp8_amax_history_dL_dY x_dtypes.add(child.last_seen_input_dtype) scale_fn_recipes.add(child.recipe.scale_fn_name) From 66d22cbe24bea1804def3144e3b0b838527957d3 Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 13 Feb 2024 16:56:02 -0800 Subject: [PATCH 09/14] use functional tensor --- float8_experimental/float8_linear_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index eac859ae..add024ae 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -17,8 +17,7 @@ from float8_experimental.float8_linear import Float8Linear from float8_experimental.float8_utils import amax_history_to_scale_stack -from torch.distributed._functional_collectives import all_reduce -from torch.distributed.distributed_c10d import _get_default_group +from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor log = logging.getLogger(__name__) log.addHandler(logging.NullHandler()) @@ -238,8 +237,11 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) fp8_amax_x_tensor_list + fp8_amax_w_tensor_list + fp8_amax_dL_dY_tensor_list ) all_reduced_amax_tensor = all_reduce( - all_amax_tensors, "MAX", _get_default_group() + all_amax_tensors, "MAX", list(range(dist.get_world_size())) ) + if isinstance(all_reduced_amax_tensor, 
AsyncCollectiveTensor): + all_reduced_amax_tensor = all_reduced_amax_tensor.wait() + ( reduced_fp8_amax_tensor, reduced_fp8_amax_w_tensor, From d8015c7fd6a312a6ea19c214560b9237465f40e7 Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 13 Feb 2024 18:17:10 -0800 Subject: [PATCH 10/14] remove re-casting --- float8_experimental/float8_tensor.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index d88d3f73..6063f7c1 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -54,9 +54,7 @@ class FromFloat8ConstrFunc(torch.autograd.Function): @staticmethod def forward(ctx, tensor): - return (tensor._data.to(tensor._orig_dtype) / tensor._scale).to( - tensor._orig_dtype - ) + return tensor._data.to(tensor._orig_dtype) / tensor._scale @staticmethod def backward(ctx, g): From bfbc86846d2e74c1c4025c0855dc131dc31bfe2d Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 14 Feb 2024 09:36:07 -0800 Subject: [PATCH 11/14] add no grad --- float8_experimental/float8_linear_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index add024ae..f7256f25 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -162,6 +162,7 @@ def get_float8_layers(model: torch.nn.Module): return fp8_layers +@torch.no_grad() def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) -> None: """ Manages the float8 amax and scale bookkeeping. In detail, it does the From 4ca6ddf1da17948011565f55596f653d607452ad Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 14 Feb 2024 12:35:20 -0800 Subject: [PATCH 12/14] comments --- float8_experimental/float8_linear_utils.py | 28 +++++++--------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index f7256f25..1b2ef24e 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -64,16 +64,6 @@ def linear_requires_sync(linear_type: LinearType): return linear_type in REQUIRES_SYNC -def _update_history_with_new_amax(new_amax, amax_history): - """ - Updates `amax_history` (the last N cur_amax values) inplace with the value - of `new_amax`. - """ - new_amax_history = torch.roll(amax_history, 1) - new_amax_history[0] = new_amax - amax_history.copy_(new_amax_history) - - def _update_history_stack( new_amax: torch.Tensor, amax_history_stack: torch.Tensor ) -> torch.Tensor: @@ -85,10 +75,12 @@ def _update_history_stack( new_amax (torch.Tensor): The new amax value to add to the history. (n_amaxes, 1) amax_history_stack (torch.Tensor): The history of amax values. 
(n_amaxes, history_length) """ - assert amax_history_stack.dim() == 2, "amax_history_stack must be 2D" + assert ( + amax_history_stack.dim() == 2 + ), f"Expected amat_history_stack to be 2D, got {amax_history_stack.shape()}" assert new_amax.size(0) == amax_history_stack.size( 0 - ), "new_amax must have the same size as the second dimension of amax_history_stack" + ), f"Expected new_amax to have the same size as the first dimension of amax_history_stack, got {new_amax.size(0)} and {amax_history_stack.size(0)}" new_amax_history_stack = torch.roll(amax_history_stack, 1, dims=1) new_amax_history_stack[:, 0] = new_amax.squeeze(-1) amax_history_stack.copy_(new_amax_history_stack) @@ -155,9 +147,7 @@ def get_float8_layers(model: torch.nn.Module): """ # Get all fp8 layers and tensors - fp8_layers = [ - child for _, child in model.named_modules() if isinstance(child, Float8Linear) - ] + fp8_layers = [child for child in model.modules() if isinstance(child, Float8Linear)] return fp8_layers @@ -176,7 +166,7 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) TODO(future): design the UX for this (context manager, etc) PERFORMANCE NOTE: - When you can it is much more efficient to call te get_float8_layers once a + When you can, it is much more efficient to call get_float8_layers once at the beginning of the training loop and pass the result to this function. Because of how this interacts with torch.compile @@ -249,13 +239,12 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) reduced_fp8_amax_dL_dY_tensor, ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list)) - # TODO foreach is not supported with AsyncCollectiveTensor for idx, child in enumerate(fp8_layers): child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx]) child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx]) child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx]) - # We create two stacked tensors, one for the amax history and one for the current scales + # We create two stacked tensor groups, one for the amax history and one for the current scales fp8_amax_x_tensors = torch.vstack(fp8_amax_x_tensor_list) fp8_amax_w_tensors = torch.vstack(fp8_amax_w_tensor_list) fp8_amax_dL_dY_tensors = torch.vstack(fp8_amax_dL_dY_tensor_list) @@ -264,11 +253,12 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) fp8_w_amax_history_stack = torch.vstack(fp8_w_amax_history_stack) fp8_dL_dY_amax_history_stack = torch.vstack(fp8_dL_dY_amax_history_stack) + # Update the history stacks with the new amax values _update_history_stack(fp8_amax_x_tensors, fp8_x_amax_history_stack) _update_history_stack(fp8_amax_w_tensors, fp8_w_amax_history_stack) _update_history_stack(fp8_amax_dL_dY_tensors, fp8_dL_dY_amax_history_stack) - # We are not reading the + # Calculate the new scales from the updated history stacks new_x_scales = amax_history_to_scale_stack( fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe ) From 238e8f986fb13aa1101eb94bd49996b04ff0cf2f Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 14 Feb 2024 12:37:04 -0800 Subject: [PATCH 13/14] update asserts with more info --- float8_experimental/float8_linear_utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py index 1b2ef24e..b96a4a1c 100644 --- a/float8_experimental/float8_linear_utils.py +++ b/float8_experimental/float8_linear_utils.py @@ -210,10 +210,16 @@ def 
sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) scale_fn_recipes.add(child.recipe.scale_fn_name) # TODO This way to get the activation dtype is not ideal - assert len(x_dtypes) == 1, "All layers must have the same last seen input_dtype" + if len(x_dtypes) != 1: + raise ValueError( + f"All layers must have the same last seen input_dtype, got {x_dtypes}" + ) x_dtype = next(iter(x_dtypes)) - assert len(scale_fn_recipes) == 1, "All layers must have the same scale_fn recipe" + if len(scale_fn_recipes) != 1: + raise ValueError( + f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}" + ) scale_fn_recipe = next(iter(scale_fn_recipes)) assert ( From 5bee42b6c6a7a989831a4249dd568e5ba7fa0d8f Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 14 Feb 2024 13:23:48 -0800 Subject: [PATCH 14/14] make error message better --- float8_experimental/float8_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py index bd1d9f26..8661d3ee 100644 --- a/float8_experimental/float8_utils.py +++ b/float8_experimental/float8_utils.py @@ -63,7 +63,7 @@ def amax_history_to_scale_stack( amax_stack = torch.max(amax_history, dim=1).values return amax_to_scale(amax_stack, float8_dtype, orig_dtype) raise NotImplementedError( - "Invalid history_to_scale_fn_type, only 'max' is supported." + f"Invalid history_to_scale_fn_type, only 'max' is supported. Got: {history_to_scale_fn_type}" )
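
A minimal sketch of the usage pattern described by the PERFORMANCE NOTE added in PATCH 01 and reworded in PATCH 12: collect the Float8Linear layers once, then pass them to sync_float8_amax_and_scale_history each step. This is illustrative only and not part of the patch series; the train_loop, optimizer, and data_loader names are placeholders, and it assumes the model's nn.Linear modules have already been swapped for Float8Linear.

import torch
from float8_experimental.float8_linear_utils import (
    get_float8_layers,
    sync_float8_amax_and_scale_history,
)

def train_loop(model: torch.nn.Module, optimizer, data_loader):
    # Collect the Float8Linear layers once, outside the loop, so the module walk
    # is not repeated (and re-traced by torch.compile) on every step.
    fp8_layers = get_float8_layers(model)
    for batch, target in data_loader:
        optimizer.zero_grad()
        loss = torch.nn.functional.cross_entropy(model(batch), target)
        loss.backward()
        # Sync the amax/scale bookkeeping (a single all-reduce across ranks when
        # torch.distributed is initialized) before the optimizer step.
        sync_float8_amax_and_scale_history(model, fp8_layers)
        optimizer.step()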
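
The stacked bookkeeping introduced in PATCH 03 can also be summarized with a small standalone sketch. This is illustrative only, not code from the series: the history length, amax values, and EPS floor below are made up, and E4M3_MAX_POS is assumed to be 448.0 (the torch.float8_e4m3fn maximum); the division matches the amax_to_scale formula shown in the PATCH 02 diff of float8_utils.py.

import torch

E4M3_MAX_POS = 448.0   # assumed max representable value of torch.float8_e4m3fn
EPS = 1e-12            # clamp floor; the library's exact EPS may differ

# Two Float8Linear layers with history_len = 4: one row of amax history per layer.
amax_history_stack = torch.zeros(2, 4)
new_amax = torch.tensor([[3.0], [7.5]])  # freshest per-layer amax, shape (n_layers, 1)

# _update_history_stack: roll each row by one slot and write the new amax into column 0.
rolled = torch.roll(amax_history_stack, 1, dims=1)
rolled[:, 0] = new_amax.squeeze(-1)
amax_history_stack.copy_(rolled)

# amax_history_to_scale_stack with the "max" recipe: per-row max -> one scale per layer.
amax_per_layer = torch.max(amax_history_stack, dim=1).values
scales = E4M3_MAX_POS / torch.clamp(amax_per_layer, min=EPS)
# scales holds [E4M3_MAX_POS / 3.0, E4M3_MAX_POS / 7.5]; each entry would be copied
# into the corresponding layer's fp8_scale_* buffer.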