From 74a15f1dd72839264eb87adfaf986cdfcc9d6781 Mon Sep 17 00:00:00 2001 From: y-sq <58683402+y-sq@users.noreply.github.com> Date: Thu, 16 Jan 2025 10:06:03 -0800 Subject: [PATCH] Add a register_replacement to fix float8 delayed scaling kernel fusion issues in torchao/float8 Differential Revision: D67758184 Pull Request resolved: https://github.com/pytorch/ao/pull/1469 --- benchmarks/float8/profile_linear_float8.py | 10 +- test/float8/test_compile.py | 68 +++++++++++ torchao/float8/README.md | 5 +- torchao/float8/__init__.py | 4 + torchao/float8/inductor_utils.py | 126 +++++++++++++++++++++ 5 files changed, 211 insertions(+), 2 deletions(-) create mode 100644 torchao/float8/inductor_utils.py diff --git a/benchmarks/float8/profile_linear_float8.py b/benchmarks/float8/profile_linear_float8.py index 19fb492c32..5045956954 100644 --- a/benchmarks/float8/profile_linear_float8.py +++ b/benchmarks/float8/profile_linear_float8.py @@ -37,6 +37,7 @@ update_triton_kernels_in_prof_chome_trace_with_torch_logs, ) +from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes from torchao.float8.config import ( Float8LinearRecipeName, ScalingType, @@ -206,7 +207,7 @@ def profile_function( # by default torch.compile appends to log_file_name, so we delete it # if it exists if os.path.isfile(config.logs_file_path): - pathlib.Path.unlink(config.logs_file_path) + pathlib.Path(config.logs_file_path).unlink() torch._logging._init_logs(log_file_name=config.logs_file_path) activities = [ProfilerActivity.CPU] @@ -288,6 +289,7 @@ def main( add_inductor_metadata_to_trace: bool = True, enable_sync_amax_history: bool = True, enable_activation_checkpointing: bool = False, + enable_float8_delayed_scaling_inductor_passes: bool = False, ): assert model_type in ( "linear", @@ -325,6 +327,12 @@ def main( print( f"enable_activation_checkpointing is set to {enable_activation_checkpointing}" ) + print( + f"enable_float8_delayed_scaling_inductor_passes is set to {enable_float8_delayed_scaling_inductor_passes}" + ) + + if enable_float8_delayed_scaling_inductor_passes: + _prototype_register_float8_delayed_scaling_inductor_passes() device = "cuda" ref_dtype = torch.bfloat16 diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py index 32d6bdfbbd..c42ab8ee77 100644 --- a/test/float8/test_compile.py +++ b/test/float8/test_compile.py @@ -7,6 +7,7 @@ import random import sys import unittest +from dataclasses import replace from io import StringIO import pytest @@ -25,6 +26,7 @@ from torch._dynamo.test_case import TestCase as DynamoTestCase from torch._dynamo.testing import CompileCounterWithBackend +from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes from torchao.float8.config import ( CastConfig, Float8LinearConfig, @@ -51,6 +53,7 @@ from torchao.float8.float8_utils import config_has_stateful_scaling from torchao.float8.stateful_float8_linear import StatefulFloat8Linear from torchao.testing.float8.test_utils import get_test_float8_linear_config +from torchao.utils import is_fbcode def _test_compile_base( @@ -465,5 +468,70 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype): assert torch.equal(float8_eager._data, float8_compile._data) +@unittest.skipIf( + not is_sm_at_least_89() or not is_fbcode(), + "CUDA with float8 support not available; or not on fbcode (the test needs be run with the latest pytorch package)", +) +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) +def test_delayed_scaling_pattern_replacement(dtype: torch.dtype): + from torch._inductor import config as inductor_config + from torch._inductor import metrics + + inductor_config.loop_ordering_after_fusion = True + + def clear_all(): + metrics.reset() + from torch._inductor.fx_passes.post_grad import ( + pass_patterns as post_grad_patterns_all, + ) + + post_grad_patterns_all[1].clear() + post_grad_patterns_all[1].seen_patterns.clear() + + def compile_and_run_single_layer(): + random.seed(0) + torch.manual_seed(0) + x_shape = (2048, 3072) + linear_dtype = dtype + + x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype).requires_grad_() + m_ref = nn.Linear(3072, 2048, bias=True, device="cuda", dtype=linear_dtype) + + config = get_test_float8_linear_config( + ScalingType.DELAYED, + ScalingType.DELAYED, + ScalingType.DELAYED, + False, + ) + + config = replace(config, enable_amax_init=False) + + m_fp8 = StatefulFloat8Linear.from_float( + copy.deepcopy(m_ref), + config, + ) + + m_fp8 = torch.compile(m_fp8, backend="inductor", fullgraph=True) + m_ref = torch.compile(m_ref, backend="inductor", fullgraph=True) + + y_fp8 = m_fp8(x) + y_fp8.sum().backward() + + return m_fp8.weight.grad + + clear_all() + ref_output = compile_and_run_single_layer() + ref_count_kernel = metrics.generated_kernel_count + + clear_all() + _prototype_register_float8_delayed_scaling_inductor_passes() + new_output = compile_and_run_single_layer() + new_count_kernel = metrics.generated_kernel_count + + torch.equal(ref_output, new_output) + # With the pattern replacement workaround, amax reduction kernels for the 3 tensors (weight, activation, gradient) are fused. + assert ref_count_kernel == new_count_kernel + 3 + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/torchao/float8/README.md b/torchao/float8/README.md index 1a87770899..8487096e6c 100644 --- a/torchao/float8/README.md +++ b/torchao/float8/README.md @@ -82,6 +82,9 @@ from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 if not TORCH_VERSION_AT_LEAST_2_5: raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater") +# Recommended: enable additional torchinductor passes to improve the performance of delayed scaling +torchao.float8._prototype_register_float8_delayed_scaling_inductor_passes() + # create model and sample input m = nn.Sequential( nn.Linear(2048, 4096), @@ -172,7 +175,7 @@ For small shapes, a combination of (2) and (3) leads to speedup < 1. For medium ## Scaling type vs speedup -Delayed scaling is theoretically faster than dynamic scaling because of reduced read/write traffic requirements. Today, torch.compile has a couple of limitations (see the performance section of https://github.com/pytorch/ao/issues/556) which prevent us from reaching the optimal behavior for delayed scaling, so the observed performance of delayed scaling is close to that of dynamic scaling. As the torch.compile limitations are fixed, we expect delayed scaling to eventually become more performant compared to dynamic scaling. +Delayed scaling is theoretically faster than dynamic scaling because of reduced read/write traffic requirements. Today, torch.compile has a couple of limitations (see the performance section of https://github.com/pytorch/ao/issues/556) which prevent us from reaching the optimal behavior for delayed scaling without workarounds. We have a prototype workaround (API subject to change) with the `torchao.float8._prototype_register_float8_delayed_scaling_inductor_passes()` API to improve delayed scaling performance. ## torch.compile behavior vs speedup diff --git a/torchao/float8/__init__.py b/torchao/float8/__init__.py index 3336330361..258db53be0 100644 --- a/torchao/float8/__init__.py +++ b/torchao/float8/__init__.py @@ -23,6 +23,9 @@ ScaledMMConfig, ) from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp +from torchao.float8.inductor_utils import ( + _prototype_register_float8_delayed_scaling_inductor_passes, +) from torchao.float8.inference import Float8MMConfig from torchao.float8.stateful_float8_linear import WeightWithDelayedFloat8CastTensor from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 @@ -54,5 +57,6 @@ "linear_requires_sync", "sync_float8_amax_and_scale_history", "precompute_float8_dynamic_scale_for_fsdp", + "_prototype_register_float8_delayed_scaling_inductor_passes", # note: Float8Tensor and Float8Linear are not public APIs ] diff --git a/torchao/float8/inductor_utils.py b/torchao/float8/inductor_utils.py new file mode 100644 index 0000000000..3e86202536 --- /dev/null +++ b/torchao/float8/inductor_utils.py @@ -0,0 +1,126 @@ +import functools +import inspect +import traceback +from collections import deque + +import torch + + +def amax_with_scaling_pattern(tensor_x_inp, scale_x, fp8_dtype, fp8_max): + tensor_x = tensor_x_inp.to(torch.float32) * scale_x + tensor_x = tensor_x.clamp(min=-1 * fp8_max, max=fp8_max) + tensor_x = tensor_x.to(fp8_dtype) + amax = torch.max(torch.abs(tensor_x_inp)) + return (tensor_x, amax) + + +def amax_with_scaling_tiled_replacement(tensor_x_inp, scale_x, fp8_dtype, fp8_max): + tensor_x = tensor_x_inp.to(torch.float32) * scale_x + tensor_x = tensor_x.clamp(min=-1 * fp8_max, max=fp8_max) + tensor_x = tensor_x.to(fp8_dtype) + amax_1 = torch.max(torch.abs(tensor_x_inp), dim=-1).values + amax = torch.max(amax_1) + return (tensor_x, amax) + + +# The amax_with_scaling_pattern will also match dynamic scaling cases, we want to avoid that. +# `scale_x` of delayed scaling comes from the previous iteration, instead of from `tensor_x_inp`. +# We check that `scale_x` is not a dependency of `tensor_x_inp` +def fp8_delayed_scaling_extra_check(match): + scale_x_inputs = deque([match.kwargs["scale_x"]]) + max_num_node_to_check = 20 # Don't traverse too many nodes + current_num_node = 0 + while len(scale_x_inputs) > 0 and current_num_node < max_num_node_to_check: + current_node = scale_x_inputs.popleft() + for n in current_node.all_input_nodes: + if n == match.kwargs["tensor_x_inp"]: + return False + scale_x_inputs.append(n) + current_num_node += 1 + return True + + +def partialize_and_update_signature(func, **kwargs): + """ + Equivalent to functools.partial but also updates the signature on returned function + """ + original_sig = inspect.signature(func) + parameters = original_sig.parameters + + new_parameters = { + key: value for key, value in parameters.items() if key not in kwargs + } + new_sig = inspect.Signature(parameters=list(new_parameters.values())) + + partial_func = functools.partial(func, **kwargs) + + def wrapper(*args, **kwargs): + return partial_func(*args, **kwargs) + + wrapper.__signature__ = new_sig # type: ignore[attr-defined] + wrapper.__name__ = func.__name__ + + return wrapper + + +def register_fp8_delayed_scaling_patterns_inner(): + from torch._inductor.fx_passes.post_grad import ( + pass_patterns as post_grad_patterns_all, + ) + from torch._inductor.pattern_matcher import fwd_only, register_replacement + + post_grad_patterns = post_grad_patterns_all[1] # medium priority + + if torch.cuda.is_available(): + for fp8_dtype in [ + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float8_e4m3fnuz, + torch.float8_e5m2fnuz, + ]: + # torch.float16 has the same pattern as torch.bfloat16, because they both needs `tensor_x_inp.to(torch.float32)` + for dtype in [torch.float32, torch.bfloat16]: + device = "cuda" + register_replacement( + partialize_and_update_signature( + amax_with_scaling_pattern, + fp8_dtype=fp8_dtype, + fp8_max=torch.finfo(fp8_dtype).max, + ), + partialize_and_update_signature( + amax_with_scaling_tiled_replacement, + fp8_dtype=fp8_dtype, + fp8_max=torch.finfo(fp8_dtype).max, + ), + [ + torch.tensor((16, 16), device=device, dtype=dtype), + torch.tensor(2.0, device=device, dtype=torch.float32), + ], + fwd_only, + post_grad_patterns, + extra_check=fp8_delayed_scaling_extra_check, + ) + + +""" +This a short-term workaround of the delayed scaling performance issue. +It explicitly replaces `max(x)` with `max(max(x, dim=-1))`, enabling the fusion of amax scaling factor calculation and fp8 casting. + +Usage: + To use this solution, add the following line at the beginning of your user code: + torchao.float8._prototype_register_float8_delayed_scaling_inductor_passes() +""" + + +def _prototype_register_float8_delayed_scaling_inductor_passes() -> None: + # To make the fp8 delayed scaling pattern work, we need a fix pr from inductor, https://github.com/pytorch/pytorch/pull/139321 + # Will throw the error if the pattern registration did not work, up to user to decide what to do with it + try: + register_fp8_delayed_scaling_patterns_inner() + except AssertionError as e: + if "assert pattern_repr not in _seen_patterns" in traceback.format_exc(): + print( + f"Caught duplicated patterns in register_fp8_delayed_scaling_patterns: {traceback.format_exc()}", + "\nPlease update your pytorch dependency to the latest main branch to fix it.\n", + ) + raise e