Add float8 FakeQuantizeConfig and FakeQuantizer #2735


Merged: 1 commit merged into main from fp8-fake-quantizer on Aug 13, 2025
Conversation

andrewor14 (Contributor) commented Aug 11, 2025

Summary: This commit adds a QAT path for float8, using the same primitives as torchao.quantization.Float8Tensor targeting the following PTQ configs:

  • Float8DynamicActivationFloat8WeightConfig
  • Float8ActivationInt4WeightConfig

Usage:

```
import torch
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization.granularity import PerRow
from torchao.quantization.qat import quantize_, QATConfig

base_config = Float8DynamicActivationFloat8WeightConfig(
    torch.float8_e4m3fn, PerRow(),
)
quantize_(model, QATConfig(base_config, step="prepare"))
quantize_(model, QATConfig(base_config, step="convert"))
```

OR

```
import torch
from torchao.quantization.granularity import PerRow
from torchao.quantization.qat import (
    Float8FakeQuantizeConfig,
    QATConfig,
    quantize_,
)

dtype = torch.float8_e4m3fn
granularity = PerRow()
quantize_(model, QATConfig(
    activation_config=Float8FakeQuantizeConfig(dtype, granularity),
    weight_config=Float8FakeQuantizeConfig(dtype, granularity),
    step="prepare",
))

# convert (same as above, not shown)
```
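For intuition, the fake quantize step can be pictured as quantize-then-dequantize in high precision. Below is a minimal sketch of the forward numerics only, assuming per-row granularity and torch.float8_e4m3fn; the function name is illustrative and this is not the torchao FakeQuantizer implementation, which also handles gradients via a straight-through estimator:

```
import torch

def fake_quantize_fp8_rowwise(x: torch.Tensor) -> torch.Tensor:
    # Illustrative per-row float8 fake quantization (not the torchao implementation)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    # One scale per row, derived from the row's max absolute value
    scale = x.abs().amax(dim=-1, keepdim=True).float() / fp8_max
    scale = torch.clamp(scale, min=1e-12)
    # Quantize to float8, then immediately dequantize back to the original dtype
    x_fp8 = (x.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return (x_fp8.to(torch.float32) * scale).to(x.dtype)

w = torch.randn(16, 32, dtype=torch.bfloat16)
w_fq = fake_quantize_fp8_rowwise(w)  # same shape/dtype as w, values restricted to float8 grid
```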

Test Plan:

```
python test/quantization/test_qat.py -k test_float8_fake_quantize_config
python test/quantization/test_qat.py -k test_float8_fake_quantize
python test/quantization/test_qat.py -k test_quantize_api_fp8_fp8
python test/quantization/test_qat.py -k test_quantize_api_fp8_int4
```

Identical outputs between normal bf16 and QAT fine-tuning for both fp8-fp8 and fp8-int4, reproduced on Llama3.1 using this unsloth notebook. Loss curves also overlap almost exactly (not shown):

```
<|begin_of_text|>Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
Continue the fibonnaci sequence.

### Input:
1, 1, 2, 3, 5, 8

### Response:
13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597, 2584, 4181, 6765, 10946, 17711, 28657, 46368, 75025, 121393, 196418, 317811, 514229, 832040, 1346269, 2178309, 3524578, 5702887, 9227465, 14930352, 24157817, 39088169, 632459
```

pytorch-bot bot commented Aug 11, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2735

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 2 Pending

As of commit 6b85230 with merge base a1a9632:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed label Aug 11, 2025
@andrewor14 andrewor14 marked this pull request as draft August 11, 2025 17:28
@andrewor14 andrewor14 added the topic: new feature label Aug 11, 2025
@andrewor14 andrewor14 requested a review from jerryzh168 August 11, 2025 23:31
@andrewor14 andrewor14 changed the title from "[draft] Add float8 FakeQuantizeConfig and FakeQuantizer" to "Add float8 FakeQuantizeConfig and FakeQuantizer" Aug 11, 2025
@andrewor14 andrewor14 requested review from drisspg and vkuzo August 11, 2025 23:32
@andrewor14 andrewor14 marked this pull request as ready for review August 11, 2025 23:32
@andrewor14 andrewor14 force-pushed the fp8-fake-quantizer branch 3 times, most recently from 73bca60 to 7460a2d on August 11, 2025 23:39
```
    dtype=base_config.weight_dtype,
    granularity=weight_granularity,
)
elif isinstance(base_config, Float8ActivationInt4WeightConfig):
```
andrewor14 (author):

@jerryzh168 can you confirm these config settings?



```
# TODO: don't register as custom op?
@_register_custom_op(quant_lib, False)
def _dequantize_affine_float8(
```
andrewor14 (author):

@jerryzh168 I'm seeing this warning. Maybe we should also skip registering this custom op?

```
/home/andrewor/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/autograd/graph.py:824:
UserWarning: torchao::dequantize_affine_float8: an autograd kernel was not registered to the
Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior.
This behavior is deprecated and will be removed in a future version of PyTorch.
If your operator is differentiable, please ensure you have registered an autograd kernel to
the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd).
If your operator is not differentiable, or to squash this warning and use the previous behavior,
please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd.
(Triggered internally at /pytorch/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:62.)
```
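For reference, one way to silence this (as the warning itself suggests) is to register a fallthrough for the Autograd key. A hedged sketch using the Python torch.library API follows; the library handle is an assumption, and the op name and namespace are taken from the warning text rather than from torchao's actual registration code:

```
import torch

# Sketch only: registering a fallthrough for the Autograd key tells autograd to treat the
# op as a plain passthrough instead of warning about a missing autograd kernel.
lib = torch.library.Library("torchao", "IMPL")  # assumes the op lives in the "torchao" namespace
lib.impl("dequantize_affine_float8", torch.library.fallthrough_kernel, "Autograd")
```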

Contributor:

oh, we didn't know this would be a problem, we can do `@register_custom_op` for now I think, this is added from intel for #2565 I believe and it's broken now anyways

andrewor14 (author):

ok, will relax in a separate PR

```
)
else:
    # targeting tinygemm kernel
    assert base_config.VERSION == 1
```
Contributor:

just support version 2, to minimize complexity?

vkuzo (Contributor) commented Aug 12, 2025:

can the test plan include training a real model and verifying loss converges?

@andrewor14 andrewor14 force-pushed the fp8-fake-quantizer branch 2 times, most recently from 7f19d27 to 1162ac3 on August 12, 2025 14:11
andrewor14 (author):

> can the test plan include training a real model and verifying loss converges?

Yes, this is in progress.

@andrewor14 andrewor14 force-pushed the fp8-fake-quantizer branch 2 times, most recently from 3129101 to 204e99b on August 12, 2025 20:53
```
    _get_qmin_qmax,
)
from torchao.quantization.quant_api import (
    Float8ActivationInt4WeightConfig,
```
Contributor:

nit: we just renamed this one

```
sqnr = compute_error(out, out_expected)
self.assertGreater(sqnr, 16)

@parameterized.expand([(PerRow(),), (PerTensor(),)])
```
Contributor:

nit: why not use `@common_utils.parametrize("granularity", [PerTensor(), PerRow()])`
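For illustration, a sketch of how the suggested decorator is typically wired up; the class and test names here are made up, and parametrized tests also need `instantiate_parametrized_tests` to take effect:

```
from torch.testing._internal import common_utils
from torchao.quantization.granularity import PerRow, PerTensor

class TestFloat8FakeQuantizeSketch(common_utils.TestCase):
    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
    def test_float8_fake_quantize(self, granularity):
        ...  # test body elided

# parametrized tests must be instantiated for the decorator to generate test cases
common_utils.instantiate_parametrized_tests(TestFloat8FakeQuantizeSketch)
```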

Comment on lines 1970 to 1971
```
if "fbgemm-gpu-genai" in str(e):
    self.skipTest("fbgemm-gpu-genai not available")
```
Contributor:

nit: we can skip the test when fbgemm-gpu-genai is not installed:
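A minimal sketch of that kind of upfront skip follows; the importable module name below is an assumption and may differ from the actual fbgemm-gpu-genai package layout:

```
import importlib.util
import unittest

# Assumption: the package is importable as "fbgemm_gpu"; the actual module name may differ.
_HAS_FBGEMM_GPU_GENAI = importlib.util.find_spec("fbgemm_gpu") is not None

class TestQuantizeApiFp8Int4Sketch(unittest.TestCase):
    @unittest.skipIf(not _HAS_FBGEMM_GPU_GENAI, "fbgemm-gpu-genai not available")
    def test_quantize_api_fp8_int4(self):
        ...  # test body elided
```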

```
try:
    quantize_(m, QATConfig(base_config, step="prepare"))
    quantize_(m, QATConfig(base_config, step="convert"))
    m(*example_inputs)
```
Contributor:

should this happen between prepare and convert as well?

```
        granularity=weight_granularity,
    )
elif isinstance(base_config, Float8DynamicActivationInt4WeightConfig):
    act_config = Float8FakeQuantizeConfig(
```
Contributor:

one thing we should pay extra attention to here is whether the simulation works for the int4 preshuffled tensor as well, I think; we need some numerics testing to make sure

andrewor14 (author):

added an fp8-int4 numerics test
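Roughly, the kind of numerics check being discussed compares the fake-quantized (QAT-prepared) output against the PTQ output on the same input and asserts a minimum SQNR. Below is a hedged sketch, not the actual added test: `m`, `base_config`, and `example_inputs` are placeholders, and the threshold mirrors the `assertGreater(sqnr, 16)` seen earlier in this review:

```
import copy
from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig
from torchao.quantization.utils import compute_error

def check_qat_vs_ptq_sqnr(m, base_config, example_inputs, min_sqnr=16.0):
    # QAT path: insert fake quantizers and run the model
    m_qat = copy.deepcopy(m)
    quantize_(m_qat, QATConfig(base_config, step="prepare"))
    out_qat = m_qat(*example_inputs)

    # PTQ path: quantize directly with the base config and run the model
    m_ptq = copy.deepcopy(m)
    quantize_(m_ptq, base_config)
    out_ptq = m_ptq(*example_inputs)

    # require the fake-quantized output to track the PTQ output closely
    sqnr = compute_error(out_qat, out_ptq)
    assert sqnr > min_sqnr, f"SQNR too low: {sqnr}"
```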

```
    max_abs = tensor.abs().max()
    if hp_value_lb is not None or hp_value_ub is not None:
        max_abs = torch.clamp(max_abs, min=hp_value_lb, max=hp_value_ub)
    scale = max_abs / quant_max
else:
    # rowwise
```
Contributor:

this is not necessarily rowwise I think; I believe this branch covers all granularities, and the len(block_size) == 0 case is more of a special case for tensorwise quant. I'm not sure where it comes from or whether it's needed; we could try to trace it and see if it can be removed as well to reduce complexity
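To make the tensorwise vs. other-granularity distinction concrete, here is a small sketch (not torchao's implementation) of how a per-tensor vs. per-row float8_e4m3fn scale could be derived:

```
import torch

def choose_float8_scale(x: torch.Tensor, per_row: bool) -> torch.Tensor:
    # Illustrative scale selection for float8_e4m3fn; not the torchao implementation
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    if per_row:
        # one scale per row of the last dimension
        max_abs = x.abs().amax(dim=-1, keepdim=True)
    else:
        # tensorwise: a single scale for the whole tensor
        max_abs = x.abs().max()
    return (max_abs.float() / fp8_max).clamp(min=1e-12)
```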

@andrewor14 andrewor14 force-pushed the fp8-fake-quantizer branch 2 times, most recently from f002586 to e1823c6 on August 13, 2025 14:23
```
    Float8FakeQuantizeConfig(granularity=PerToken())

@parametrize("granularity", [PerTensor(), PerRow()])
def test_float8_fake_quantize(self, granularity: Granularity):
```
jerryzh168 (Contributor) commented Aug 13, 2025:

can you add the same test for fp8_int4?

andrewor14 (author):

added some sqnr comparison against PTQ fp8_int4

**Summary:** This commit adds a QAT path for float8, using the
same primitives as `torchao.quantization.Float8Tensor` targeting
the following PTQ configs:
- `Float8DynamicActivationFloat8WeightConfig`
- `Float8DynamicActivationInt4WeightConfig`

Usage:
```
import torch
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization.granularity import PerRow
from torchao.quantization.qat import quantize_, QATConfig

base_config = Float8DynamicActivationFloat8WeightConfig(
    torch.float8_e4m3fn, PerRow(),
)
quantize_(model, QATConfig(base_config, step="prepare"))
quantize_(model, QATConfig(base_config, step="convert"))
```

OR

```
import torch
from torchao.quantization.granularity import PerRow
from torchao.quantization.qat import (
    Float8FakeQuantizeConfig,
    QATConfig,
    quantize_,
)

dtype = torch.float8_e4m3fn
granularity = PerRow()
quantize_(model, QATConfig(
    activation_config=Float8FakeQuantizeConfig(dtype, granularity),
    weight_config=Float8FakeQuantizeConfig(dtype, granularity),
    step="prepare",
))

# convert (same as above, not shown)
```

**Test Plan:**
```
python test/quantization/test_qat.py -k test_float8_fake_quantize_config
python test/quantization/test_qat.py -k test_float8_fake_quantize
python test/quantization/test_qat.py -k test_quantize_api_fp8_fp8
python test/quantization/test_qat.py -k test_quantize_api_fp8_int4
```
@andrewor14 andrewor14 merged commit 715ea9f into main Aug 13, 2025
21 checks passed
Labels: CLA Signed, topic: new feature
3 participants