Add a register_replacement to fix float8 delayed scaling kernel fusion issues in torchao/float8

Differential Revision: D67758184

Pull Request resolved: #1469
y-sq authored Jan 16, 2025
1 parent 522f5b8 commit 74a15f1
Showing 5 changed files with 211 additions and 2 deletions.
10 changes: 9 additions & 1 deletion benchmarks/float8/profile_linear_float8.py
@@ -37,6 +37,7 @@
update_triton_kernels_in_prof_chome_trace_with_torch_logs,
)

from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes
from torchao.float8.config import (
Float8LinearRecipeName,
ScalingType,
@@ -206,7 +207,7 @@ def profile_function(
# by default torch.compile appends to log_file_name, so we delete it
# if it exists
if os.path.isfile(config.logs_file_path):
pathlib.Path.unlink(config.logs_file_path)
pathlib.Path(config.logs_file_path).unlink()
torch._logging._init_logs(log_file_name=config.logs_file_path)

activities = [ProfilerActivity.CPU]
@@ -288,6 +289,7 @@ def main(
add_inductor_metadata_to_trace: bool = True,
enable_sync_amax_history: bool = True,
enable_activation_checkpointing: bool = False,
enable_float8_delayed_scaling_inductor_passes: bool = False,
):
assert model_type in (
"linear",
@@ -325,6 +327,12 @@
print(
f"enable_activation_checkpointing is set to {enable_activation_checkpointing}"
)
print(
f"enable_float8_delayed_scaling_inductor_passes is set to {enable_float8_delayed_scaling_inductor_passes}"
)

if enable_float8_delayed_scaling_inductor_passes:
_prototype_register_float8_delayed_scaling_inductor_passes()

device = "cuda"
ref_dtype = torch.bfloat16
68 changes: 68 additions & 0 deletions test/float8/test_compile.py
@@ -7,6 +7,7 @@
import random
import sys
import unittest
from dataclasses import replace
from io import StringIO

import pytest
@@ -25,6 +26,7 @@
from torch._dynamo.test_case import TestCase as DynamoTestCase
from torch._dynamo.testing import CompileCounterWithBackend

from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes
from torchao.float8.config import (
CastConfig,
Float8LinearConfig,
@@ -51,6 +53,7 @@
from torchao.float8.float8_utils import config_has_stateful_scaling
from torchao.float8.stateful_float8_linear import StatefulFloat8Linear
from torchao.testing.float8.test_utils import get_test_float8_linear_config
from torchao.utils import is_fbcode


def _test_compile_base(
@@ -465,5 +468,70 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
assert torch.equal(float8_eager._data, float8_compile._data)


@unittest.skipIf(
not is_sm_at_least_89() or not is_fbcode(),
"CUDA with float8 support not available; or not on fbcode (the test needs be run with the latest pytorch package)",
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_delayed_scaling_pattern_replacement(dtype: torch.dtype):
from torch._inductor import config as inductor_config
from torch._inductor import metrics

inductor_config.loop_ordering_after_fusion = True

def clear_all():
metrics.reset()
from torch._inductor.fx_passes.post_grad import (
pass_patterns as post_grad_patterns_all,
)

post_grad_patterns_all[1].clear()
post_grad_patterns_all[1].seen_patterns.clear()

def compile_and_run_single_layer():
random.seed(0)
torch.manual_seed(0)
x_shape = (2048, 3072)
linear_dtype = dtype

x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype).requires_grad_()
m_ref = nn.Linear(3072, 2048, bias=True, device="cuda", dtype=linear_dtype)

config = get_test_float8_linear_config(
ScalingType.DELAYED,
ScalingType.DELAYED,
ScalingType.DELAYED,
False,
)

config = replace(config, enable_amax_init=False)

m_fp8 = StatefulFloat8Linear.from_float(
copy.deepcopy(m_ref),
config,
)

m_fp8 = torch.compile(m_fp8, backend="inductor", fullgraph=True)
m_ref = torch.compile(m_ref, backend="inductor", fullgraph=True)

y_fp8 = m_fp8(x)
y_fp8.sum().backward()

return m_fp8.weight.grad

clear_all()
ref_output = compile_and_run_single_layer()
ref_count_kernel = metrics.generated_kernel_count

clear_all()
_prototype_register_float8_delayed_scaling_inductor_passes()
new_output = compile_and_run_single_layer()
new_count_kernel = metrics.generated_kernel_count

assert torch.equal(ref_output, new_output)
# With the pattern replacement workaround, amax reduction kernels for the 3 tensors (weight, activation, gradient) are fused.
assert ref_count_kernel == new_count_kernel + 3


if __name__ == "__main__":
pytest.main([__file__])
5 changes: 4 additions & 1 deletion torchao/float8/README.md
@@ -82,6 +82,9 @@ from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
if not TORCH_VERSION_AT_LEAST_2_5:
raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")

# Recommended: enable additional torchinductor passes to improve the performance of delayed scaling
torchao.float8._prototype_register_float8_delayed_scaling_inductor_passes()

# create model and sample input
m = nn.Sequential(
nn.Linear(2048, 4096),
@@ -172,7 +175,7 @@ For small shapes, a combination of (2) and (3) leads to speedup < 1. For medium

## Scaling type vs speedup

Delayed scaling is theoretically faster than dynamic scaling because of reduced read/write traffic requirements. Today, torch.compile has a couple of limitations (see the performance section of https://github.com/pytorch/ao/issues/556) which prevent us from reaching the optimal behavior for delayed scaling, so the observed performance of delayed scaling is close to that of dynamic scaling. As the torch.compile limitations are fixed, we expect delayed scaling to eventually become more performant compared to dynamic scaling.
Delayed scaling is theoretically faster than dynamic scaling because of reduced read/write traffic requirements. Today, torch.compile has a couple of limitations (see the performance section of https://github.com/pytorch/ao/issues/556) which prevent us from reaching the optimal behavior for delayed scaling without workarounds. We have a prototype workaround (API subject to change) with the `torchao.float8._prototype_register_float8_delayed_scaling_inductor_passes()` API to improve delayed scaling performance.
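
For reference, here is a minimal end-to-end sketch of exercising the prototype passes, closely mirroring the new test in `test/float8/test_compile.py` from this PR (it relies on the internal `get_test_float8_linear_config` and `StatefulFloat8Linear` test helpers, so treat it as an illustration rather than the recommended training flow):

import copy
from dataclasses import replace

import torch
import torch.nn as nn

from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes
from torchao.float8.config import ScalingType
from torchao.float8.stateful_float8_linear import StatefulFloat8Linear
from torchao.testing.float8.test_utils import get_test_float8_linear_config

# register the prototype inductor passes once, before the first torch.compile call
_prototype_register_float8_delayed_scaling_inductor_passes()

# build a delayed-scaling linear and compile it
m_ref = nn.Linear(3072, 2048, bias=True, device="cuda", dtype=torch.bfloat16)
config = get_test_float8_linear_config(
    ScalingType.DELAYED, ScalingType.DELAYED, ScalingType.DELAYED, False
)
config = replace(config, enable_amax_init=False)
m_fp8 = StatefulFloat8Linear.from_float(copy.deepcopy(m_ref), config)
m_fp8 = torch.compile(m_fp8, backend="inductor", fullgraph=True)

x = torch.randn(2048, 3072, device="cuda", dtype=torch.bfloat16, requires_grad=True)
m_fp8(x).sum().backward()

The registration only needs to happen once per process, before any delayed-scaling module is compiled.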

## torch.compile behavior vs speedup

4 changes: 4 additions & 0 deletions torchao/float8/__init__.py
@@ -23,6 +23,9 @@
ScaledMMConfig,
)
from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
from torchao.float8.inductor_utils import (
_prototype_register_float8_delayed_scaling_inductor_passes,
)
from torchao.float8.inference import Float8MMConfig
from torchao.float8.stateful_float8_linear import WeightWithDelayedFloat8CastTensor
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
@@ -54,5 +57,6 @@
"linear_requires_sync",
"sync_float8_amax_and_scale_history",
"precompute_float8_dynamic_scale_for_fsdp",
"_prototype_register_float8_delayed_scaling_inductor_passes",
# note: Float8Tensor and Float8Linear are not public APIs
]
126 changes: 126 additions & 0 deletions torchao/float8/inductor_utils.py
@@ -0,0 +1,126 @@
import functools
import inspect
import traceback
from collections import deque

import torch


def amax_with_scaling_pattern(tensor_x_inp, scale_x, fp8_dtype, fp8_max):
tensor_x = tensor_x_inp.to(torch.float32) * scale_x
tensor_x = tensor_x.clamp(min=-1 * fp8_max, max=fp8_max)
tensor_x = tensor_x.to(fp8_dtype)
amax = torch.max(torch.abs(tensor_x_inp))
return (tensor_x, amax)


def amax_with_scaling_tiled_replacement(tensor_x_inp, scale_x, fp8_dtype, fp8_max):
tensor_x = tensor_x_inp.to(torch.float32) * scale_x
tensor_x = tensor_x.clamp(min=-1 * fp8_max, max=fp8_max)
tensor_x = tensor_x.to(fp8_dtype)
amax_1 = torch.max(torch.abs(tensor_x_inp), dim=-1).values
amax = torch.max(amax_1)
return (tensor_x, amax)


# The amax_with_scaling_pattern will also match dynamic scaling cases, which we want to avoid.
# `scale_x` of delayed scaling comes from the previous iteration, instead of from `tensor_x_inp`.
# We check that `scale_x` is not a dependency of `tensor_x_inp`.
def fp8_delayed_scaling_extra_check(match):
scale_x_inputs = deque([match.kwargs["scale_x"]])
max_num_node_to_check = 20 # Don't traverse too many nodes
current_num_node = 0
while len(scale_x_inputs) > 0 and current_num_node < max_num_node_to_check:
current_node = scale_x_inputs.popleft()
for n in current_node.all_input_nodes:
if n == match.kwargs["tensor_x_inp"]:
return False
scale_x_inputs.append(n)
current_num_node += 1
return True
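
# Illustrative note (an assumption for exposition, not part of the pass): with dynamic
# scaling the graph derives the scale from the input itself, roughly
#   scale_x = fp8_max / torch.max(torch.abs(tensor_x_inp))
# so the breadth-first walk above reaches `tensor_x_inp` among the inputs of
# `scale_x` and rejects the match; with delayed scaling `scale_x` is read from
# state produced in a previous iteration, so the walk never reaches `tensor_x_inp`.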


def partialize_and_update_signature(func, **kwargs):
"""
Equivalent to functools.partial but also updates the signature on the returned function
"""
original_sig = inspect.signature(func)
parameters = original_sig.parameters

new_parameters = {
key: value for key, value in parameters.items() if key not in kwargs
}
new_sig = inspect.Signature(parameters=list(new_parameters.values()))

partial_func = functools.partial(func, **kwargs)

def wrapper(*args, **kwargs):
return partial_func(*args, **kwargs)

wrapper.__signature__ = new_sig # type: ignore[attr-defined]
wrapper.__name__ = func.__name__

return wrapper
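
# Illustrative usage (a sketch with a hypothetical `f`): for `def f(a, b, fp8_dtype, fp8_max): ...`,
# `partialize_and_update_signature(f, fp8_dtype=torch.float8_e4m3fn, fp8_max=448.0)`
# returns a wrapper whose inspect.signature() reports only (a, b), so
# register_replacement below only needs example inputs for the remaining tensor
# arguments.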


def register_fp8_delayed_scaling_patterns_inner():
from torch._inductor.fx_passes.post_grad import (
pass_patterns as post_grad_patterns_all,
)
from torch._inductor.pattern_matcher import fwd_only, register_replacement

post_grad_patterns = post_grad_patterns_all[1] # medium priority

if torch.cuda.is_available():
for fp8_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.float8_e4m3fnuz,
torch.float8_e5m2fnuz,
]:
# torch.float16 has the same pattern as torch.bfloat16, because they both need `tensor_x_inp.to(torch.float32)`
for dtype in [torch.float32, torch.bfloat16]:
device = "cuda"
register_replacement(
partialize_and_update_signature(
amax_with_scaling_pattern,
fp8_dtype=fp8_dtype,
fp8_max=torch.finfo(fp8_dtype).max,
),
partialize_and_update_signature(
amax_with_scaling_tiled_replacement,
fp8_dtype=fp8_dtype,
fp8_max=torch.finfo(fp8_dtype).max,
),
[
torch.tensor((16, 16), device=device, dtype=dtype),
torch.tensor(2.0, device=device, dtype=torch.float32),
],
fwd_only,
post_grad_patterns,
extra_check=fp8_delayed_scaling_extra_check,
)


"""
This is a short-term workaround for the delayed scaling performance issue.
It explicitly replaces `max(x)` with `max(max(x, dim=-1))`, enabling the fusion of amax scaling factor calculation and fp8 casting.
Usage:
To use this solution, add the following line at the beginning of your user code:
torchao.float8._prototype_register_float8_delayed_scaling_inductor_passes()
"""


def _prototype_register_float8_delayed_scaling_inductor_passes() -> None:
# To make the fp8 delayed scaling pattern work, we need a fix PR from inductor: https://github.com/pytorch/pytorch/pull/139321
# This re-raises the error if the pattern registration did not work; it is up to the user to decide what to do with it
try:
register_fp8_delayed_scaling_patterns_inner()
except AssertionError as e:
if "assert pattern_repr not in _seen_patterns" in traceback.format_exc():
print(
f"Caught duplicated patterns in register_fp8_delayed_scaling_patterns: {traceback.format_exc()}",
"\nPlease update your pytorch dependency to the latest main branch to fix it.\n",
)
raise e
