
Commit 17d162c

Update

[ghstack-poisoned]

2 parents: 23f4a62 + 32d9b0b

22 files changed: +923 -113 lines

.github/workflows/float8_test.yml (+3)

@@ -29,6 +29,9 @@ jobs:
           gpu-arch-type: "cuda"
           gpu-arch-version: "12.1"
 
+    permissions:
+      id-token: write
+      contents: read
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     with:
       timeout: 60

.github/workflows/nightly_smoke_test.yml (+4 -2)

@@ -11,7 +11,7 @@ concurrency:
   cancel-in-progress: true
 
 env:
-    HF_TOKEN: ${{ secrets.HF_TOKEN }}
+  HF_TOKEN: ${{ secrets.HF_TOKEN }}
 
 jobs:
   test:
@@ -25,7 +25,9 @@ jobs:
           gpu-arch-type: "cuda"
           gpu-arch-version: "12.1"
 
-
+    permissions:
+      id-token: write
+      contents: read
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     with:
       runner: ${{ matrix.runs-on }}

.github/workflows/regression_test.yml (+3)

@@ -34,6 +34,9 @@ jobs:
           gpu-arch-type: "cpu"
           gpu-arch-version: ""
 
+    permissions:
+      id-token: write
+      contents: read
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     with:
       timeout: 120

benchmarks/float8/profile_linear_float8.py (+9 -1)

@@ -37,6 +37,7 @@
     update_triton_kernels_in_prof_chome_trace_with_torch_logs,
 )
 
+from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes
 from torchao.float8.config import (
     Float8LinearRecipeName,
     ScalingType,
@@ -206,7 +207,7 @@ def profile_function(
     # by default torch.compile appends to log_file_name, so we delete it
     # if it exists
     if os.path.isfile(config.logs_file_path):
-        pathlib.Path.unlink(config.logs_file_path)
+        pathlib.Path(config.logs_file_path).unlink()
     torch._logging._init_logs(log_file_name=config.logs_file_path)
 
     activities = [ProfilerActivity.CPU]
@@ -288,6 +289,7 @@ def main(
     add_inductor_metadata_to_trace: bool = True,
     enable_sync_amax_history: bool = True,
     enable_activation_checkpointing: bool = False,
+    enable_float8_delayed_scaling_inductor_passes: bool = False,
 ):
     assert model_type in (
         "linear",
@@ -325,6 +327,12 @@ def main(
     print(
         f"enable_activation_checkpointing is set to {enable_activation_checkpointing}"
     )
+    print(
+        f"enable_float8_delayed_scaling_inductor_passes is set to {enable_float8_delayed_scaling_inductor_passes}"
+    )
+
+    if enable_float8_delayed_scaling_inductor_passes:
+        _prototype_register_float8_delayed_scaling_inductor_passes()
 
     device = "cuda"
     ref_dtype = torch.bfloat16
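
Note on the pathlib fix above: Path.unlink is an instance method, so calling it on the class with a plain string only works by accident on Python versions where it forwards straight to os.unlink, and raises on others; constructing a Path first is the supported idiom. A minimal sketch (the file path here is hypothetical, for illustration only):

    import pathlib

    logs_file_path = "/tmp/profile_logs.txt"  # hypothetical path

    # Fragile: unlink() is an instance method; passing a str as `self`
    # raises AttributeError on several Python versions.
    # pathlib.Path.unlink(logs_file_path)

    # Supported idiom: build a Path, then unlink the file it points to.
    pathlib.Path(logs_file_path).unlink()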

test/float8/test_compile.py (+68)

@@ -7,6 +7,7 @@
 import random
 import sys
 import unittest
+from dataclasses import replace
 from io import StringIO
 
 import pytest
@@ -25,6 +26,7 @@
 from torch._dynamo.test_case import TestCase as DynamoTestCase
 from torch._dynamo.testing import CompileCounterWithBackend
 
+from torchao.float8 import _prototype_register_float8_delayed_scaling_inductor_passes
 from torchao.float8.config import (
     CastConfig,
     Float8LinearConfig,
@@ -51,6 +53,7 @@
 from torchao.float8.float8_utils import config_has_stateful_scaling
 from torchao.float8.stateful_float8_linear import StatefulFloat8Linear
 from torchao.testing.float8.test_utils import get_test_float8_linear_config
+from torchao.utils import is_fbcode
 
 
 def _test_compile_base(
@@ -465,5 +468,70 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
     assert torch.equal(float8_eager._data, float8_compile._data)
 
 
+@unittest.skipIf(
+    not is_sm_at_least_89() or not is_fbcode(),
+    "CUDA with float8 support not available, or not running in fbcode (the test needs to be run with the latest pytorch package)",
+)
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
+def test_delayed_scaling_pattern_replacement(dtype: torch.dtype):
+    from torch._inductor import config as inductor_config
+    from torch._inductor import metrics
+
+    inductor_config.loop_ordering_after_fusion = True
+
+    def clear_all():
+        metrics.reset()
+        from torch._inductor.fx_passes.post_grad import (
+            pass_patterns as post_grad_patterns_all,
+        )
+
+        post_grad_patterns_all[1].clear()
+        post_grad_patterns_all[1].seen_patterns.clear()
+
+    def compile_and_run_single_layer():
+        random.seed(0)
+        torch.manual_seed(0)
+        x_shape = (2048, 3072)
+        linear_dtype = dtype
+
+        x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype).requires_grad_()
+        m_ref = nn.Linear(3072, 2048, bias=True, device="cuda", dtype=linear_dtype)
+
+        config = get_test_float8_linear_config(
+            ScalingType.DELAYED,
+            ScalingType.DELAYED,
+            ScalingType.DELAYED,
+            False,
+        )
+
+        config = replace(config, enable_amax_init=False)
+
+        m_fp8 = StatefulFloat8Linear.from_float(
+            copy.deepcopy(m_ref),
+            config,
+        )
+
+        m_fp8 = torch.compile(m_fp8, backend="inductor", fullgraph=True)
+        m_ref = torch.compile(m_ref, backend="inductor", fullgraph=True)
+
+        y_fp8 = m_fp8(x)
+        y_fp8.sum().backward()
+
+        return m_fp8.weight.grad
+
+    clear_all()
+    ref_output = compile_and_run_single_layer()
+    ref_count_kernel = metrics.generated_kernel_count
+
+    clear_all()
+    _prototype_register_float8_delayed_scaling_inductor_passes()
+    new_output = compile_and_run_single_layer()
+    new_count_kernel = metrics.generated_kernel_count
+
+    assert torch.equal(ref_output, new_output)
+    # With the pattern replacement workaround, the amax reduction kernels for the
+    # 3 tensors (weight, activation, gradient) are fused, so 3 fewer kernels are generated.
+    assert ref_count_kernel == new_count_kernel + 3
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
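
The new test above verifies the fusion by comparing torch._inductor.metrics.generated_kernel_count with and without the registered passes. A minimal sketch of that measurement pattern, independent of float8 (the module and shapes below are illustrative, not from the commit):

    import torch
    import torch.nn as nn
    from torch._inductor import metrics

    def generated_kernels(fn, *args):
        # Reset inductor's counter, trigger one compile + forward + backward,
        # and report how many kernels inductor generated.
        metrics.reset()
        fn(*args).sum().backward()
        return metrics.generated_kernel_count

    m = torch.compile(nn.Linear(64, 64, device="cuda"), backend="inductor")
    x = torch.randn(32, 64, device="cuda", requires_grad=True)
    print(generated_kernels(m, x))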

test/integration/test_integration.py (+100 -25)

@@ -25,6 +25,9 @@
     AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
     AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight,
     AQFloat8WeightOnlyQuantizedLinearWeight,
+    AQGemliteInt4G64WeightOnlyQuantizedLinearWeight,
+    AQInt4G32WeightOnlyQuantizedLinearWeight,
+    AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight,
     AQInt8DynamicallyQuantizedLinearWeight,
     AQInt8WeightOnlyQuantizedLinearWeight,
     AQInt8WeightOnlyQuantizedLinearWeight2,
@@ -1751,37 +1754,109 @@ def test_autoquant_min_sqnr(self, device, dtype):
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "autoquant float option requires 2.4+."
     )
-    def test_autoquant_float(self):
+    def test_autoquant_hp_float(self):
         device = "cuda"
         dtype = torch.float32
         m, k, n = 128, 128, 128
         example_input = torch.randn(m, k, device=device, dtype=dtype)
-        model = (
-            torch.nn.Sequential(
-                torch.nn.ReLU(),
-                torch.nn.Linear(k, n),
-                torch.nn.ReLU(),
+        for qclass in torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST:
+            model = (
+                torch.nn.Sequential(
+                    torch.nn.ReLU(),
+                    torch.nn.Linear(k, n, bias=True),
+                    torch.nn.ReLU(),
+                )
+                .to(device)
+                .to(dtype)
             )
-            .to(device)
-            .to(dtype)
-        )
-        ref = model(example_input)
-        torchao.autoquant(
-            model,
-            qtensor_class_list=torchao.quantization.DEFAULT_FLOAT_AUTOQUANT_CLASS_LIST,
-        )
-        out = model(example_input)
-        from torchao.quantization.autoquant import (
-            BFloat16Tensor,
-            Float16Tensor,
-            Float32Tensor,
-        )
+            ref = model(example_input)
+            qtensor_class_list = [qclass]
+            torchao.autoquant(
+                model,
+                qtensor_class_list=qtensor_class_list,
+            )
+            out = model(example_input)
+            self.assertIn(
+                type(model[1].weight),
+                qtensor_class_list,
+            )
+            self.assertGreater(compute_error(out, ref), 40)
 
-        self.assertIn(
-            type(model[1].weight), [Float32Tensor, Float16Tensor, BFloat16Tensor]
-        )
-        print(compute_error(out, ref))
-        self.assertGreater(compute_error(out, ref), 60)
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_5, "autoquant int4 option requires 2.5+."
+    )
+    @unittest.skipIf(not has_gemlite, "gemlite not available")
+    def test_autoquant_int4wo(self, device, dtype):
+        if device == "cpu":
+            self.skipTest(f"int4wo is for cuda, not {device}")
+
+        m, k, n = 128, 128, 128
+        example_input = torch.randn(m, k, device=device, dtype=dtype)
+
+        for qclass in [
+            AQGemliteInt4G64WeightOnlyQuantizedLinearWeight,
+            AQInt4G32WeightOnlyQuantizedLinearWeight,
+            AQInt4G128WeightOnlyQuantizedMarlinSparseLinearWeight,
+        ]:
+            model = (
+                torch.nn.Sequential(
+                    torch.nn.ReLU(),
+                    torch.nn.Linear(k, n, bias=True),
+                    torch.nn.ReLU(),
+                )
+                .to(device)
+                .to(dtype)
+            )
+            ref = model(example_input)
+            qtensor_class_list = [qclass]
+            torchao.autoquant(
+                model,
+                qtensor_class_list=qtensor_class_list,
+            )
+            out = model(example_input)
+
+            self.assertIn(type(model[1].weight), qtensor_class_list)
+            self.assertGreater(compute_error(ref, out), 20)
+
+    @parameterized.expand(COMMON_DEVICE_DTYPE)
+    @unittest.skipIf(not is_sm_at_least_90(), "Need cuda arch greater than SM90")
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_5, "autoquant float8 option requires 2.5+."
+    )
+    def test_autoquant_float8(self, device, dtype):
+        if device == "cpu":
+            self.skipTest(f"float8 is for cuda, not {device}")
+
+        # note: the marlin sparse layout failed when scale_t was 1-dimensional
+        m, k, n = 128, 128, 128
+        example_input = torch.randn(m, k, device=device, dtype=dtype)
+
+        for qclass in [
+            AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
+            AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight,
+            AQFloat8WeightOnlyQuantizedLinearWeight,
+        ]:
+            model = (
+                torch.nn.Sequential(
+                    torch.nn.ReLU(),
+                    torch.nn.Linear(k, n, bias=True),
+                    torch.nn.ReLU(),
+                )
+                .to(device)
+                .to(dtype)
+            )
+            ref = model(example_input)
+            qtensor_class_list = [qclass]
+            torchao.autoquant(
+                model,
+                qtensor_class_list=qtensor_class_list,
+            )
+            out = model(example_input)
+
+            self.assertIn(type(model[1].weight), qtensor_class_list)
+            self.assertGreater(compute_error(ref, out), 20)
 
 
 @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.")
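
The reworked autoquant tests above iterate over one weight subclass at a time via a single-element qtensor_class_list, then check that the linear weight was swapped to that class and that the output's SQNR against the unquantized reference clears a threshold. A hedged sketch of the same usage outside the test harness (the import path and class availability depend on your torchao build; shapes mirror the tests):

    import torch
    import torchao
    from torchao.quantization.autoquant import AQFloat8WeightOnlyQuantizedLinearWeight

    model = torch.nn.Sequential(torch.nn.Linear(128, 128)).to("cuda").to(torch.bfloat16)
    example_input = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)

    ref = model(example_input)
    # Restrict autoquant to a single candidate so the chosen class is deterministic.
    torchao.autoquant(model, qtensor_class_list=[AQFloat8WeightOnlyQuantizedLinearWeight])
    out = model(example_input)  # the first call finalizes the quantization decision
    assert isinstance(model[0].weight, AQFloat8WeightOnlyQuantizedLinearWeight)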

torchao/_models/llama/eval.py (+1 -1)

@@ -345,7 +345,7 @@ def run_evaluation(
         args.device,
         args.precision,
         args.quantization,
-        args.sparstiy,
+        args.sparsity,
         args.compile,
         args.max_length,
         args.calibration_tasks,

torchao/dtypes/uintx/marlin_sparse_layout.py (+5)

@@ -227,6 +227,11 @@ def from_plain(
         # Linear layers are (in_features, out_features) but the int_data that is reaching this point
         # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code.
         q_w_24 = int_data.t()
+        # Handle the case where scale is 1-D, which happens when
+        # weight_shape[-1] == group_size == 128.
+        if scale.ndim == 1:
+            scale = scale.reshape(scale.shape[0], -1)
+
         scale_t = scale.t()
 
         if not torch.cuda.get_device_capability()[0] >= 8:
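
Why the guard above matters: Tensor.t() returns 0-D and 1-D tensors unchanged, so a rank-1 scale would reach the marlin packing untransposed; reshaping it to two dimensions first makes the transpose behave as intended. A small illustration (the shape is hypothetical):

    import torch

    scale = torch.ones(256)                    # 1-D scale: one group per row
    assert scale.t().shape == (256,)           # t() is a no-op on 1-D tensors

    scale = scale.reshape(scale.shape[0], -1)  # -> shape (256, 1)
    assert scale.t().shape == (1, 256)         # now transposes as intended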

torchao/float8/README.md (+4 -1)

@@ -82,6 +82,9 @@ from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 if not TORCH_VERSION_AT_LEAST_2_5:
     raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")
 
+# Recommended: enable additional torchinductor passes to improve the performance of delayed scaling
+torchao.float8._prototype_register_float8_delayed_scaling_inductor_passes()
+
 # create model and sample input
 m = nn.Sequential(
     nn.Linear(2048, 4096),
@@ -172,7 +175,7 @@ For small shapes, a combination of (2) and (3) leads to speedup < 1. For medium
 
 ## Scaling type vs speedup
 
-Delayed scaling is theoretically faster than dynamic scaling because of reduced read/write traffic requirements. Today, torch.compile has a couple of limitations (see the performance section of https://github.com/pytorch/ao/issues/556) which prevent us from reaching the optimal behavior for delayed scaling, so the observed performance of delayed scaling is close to that of dynamic scaling. As the torch.compile limitations are fixed, we expect delayed scaling to eventually become more performant compared to dynamic scaling.
+Delayed scaling is theoretically faster than dynamic scaling because of reduced read/write traffic requirements. Today, torch.compile has a couple of limitations (see the performance section of https://github.com/pytorch/ao/issues/556) which prevent us from reaching the optimal behavior for delayed scaling without workarounds. We provide a prototype workaround (API subject to change), `torchao.float8._prototype_register_float8_delayed_scaling_inductor_passes()`, to improve delayed scaling performance.
 
 ## torch.compile behavior vs speedup
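
Putting the README additions together: the prototype registration installs inductor post-grad pattern replacements, so it should run before the model is compiled (as in the README snippet above). A hedged end-to-end sketch, with the float8 conversion step elided (configure delayed scaling per the rest of the README):

    import torch
    import torch.nn as nn
    import torchao.float8

    # Prototype API from this commit; name and behavior subject to change.
    torchao.float8._prototype_register_float8_delayed_scaling_inductor_passes()

    m = nn.Sequential(nn.Linear(2048, 4096), nn.Linear(4096, 128)).cuda().bfloat16()
    # ... convert m to float8 training with delayed scaling here ...
    m = torch.compile(m)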

torchao/float8/__init__.py (+4)

@@ -23,6 +23,9 @@
     ScaledMMConfig,
 )
 from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
+from torchao.float8.inductor_utils import (
+    _prototype_register_float8_delayed_scaling_inductor_passes,
+)
 from torchao.float8.inference import Float8MMConfig
 from torchao.float8.stateful_float8_linear import WeightWithDelayedFloat8CastTensor
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
@@ -54,5 +57,6 @@
     "linear_requires_sync",
     "sync_float8_amax_and_scale_history",
     "precompute_float8_dynamic_scale_for_fsdp",
+    "_prototype_register_float8_delayed_scaling_inductor_passes",
     # note: Float8Tensor and Float8Linear are not public APIs
 ]

0 commit comments

Comments
 (0)