Add float8 FakeQuantizeConfig and FakeQuantizer #2735

Conversation
Dr. CI: No Failures, 2 Pending as of commit 6b85230 with merge base a1a9632. See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2735. (Note: links to docs will display an error until the docs builds have completed.)
    dtype=base_config.weight_dtype,
    granularity=weight_granularity,
)
elif isinstance(base_config, Float8ActivationInt4WeightConfig):
@jerryzh168 can you confirm these config settings?
# TODO: don't register as custom op?
@_register_custom_op(quant_lib, False)
def _dequantize_affine_float8(
@jerryzh168 I'm seeing this warning. Maybe should also skip registering this custom op?
/home/andrewor/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/torch/autograd/graph.py:824:
UserWarning: torchao::dequantize_affine_float8: an autograd kernel was not registered to the
Autograd key(s) but we are trying to backprop through it. This may lead to silently incorrect behavior.
This behavior is deprecated and will be removed in a future version of PyTorch.
If your operator is differentiable, please ensure you have registered an autograd kernel to
the correct Autograd key (e.g. DispatchKey::Autograd, DispatchKey::CompositeImplicitAutograd).
If your operator is not differentiable, or to squash this warning and use the previous behavior,
please register torch::CppFunction::makeFallthrough() to DispatchKey::Autograd.
(Triggered internally at /pytorch/torch/csrc/autograd/autograd_not_implemented_fallback.cpp:62.)
oh, we didn't know this would be a problem, we can do
ao/torchao/quantization/quant_primitives.py (line 361 in 1dca638):
@register_custom_op
ok, will relax in a separate PR
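For reference, a minimal sketch of the fallthrough option the warning itself suggests, as an alternative to skipping registration. This is an assumption, not the fix the PR takes: it assumes importing torchao registers `torchao::dequantize_affine_float8`, and it registers an Autograd fallthrough for the non-differentiable op.

```python
import torch
import torchao  # assumption: importing torchao registers torchao::dequantize_affine_float8

# Register an Autograd fallthrough so backprop passes through the op without the
# "no autograd kernel registered" warning (the Python analogue of the C++
# makeFallthrough() suggestion in the warning text).
_lib = torch.library.Library("torchao", "FRAGMENT")
_lib.impl(
    "dequantize_affine_float8",
    torch.library.fallthrough_kernel,
    "Autograd",
)
```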
)
else:
    # targeting tinygemm kernel
    assert base_config.VERSION == 1
just support version 2, to minimize complexity?
can the test plan include training a real model and verifying that the loss converges?
Yes, this is in progress.
test/quantization/test_qat.py (outdated):
    _get_qmin_qmax,
)
from torchao.quantization.quant_api import (
    Float8ActivationInt4WeightConfig,
nit: we just renamed this one
test/quantization/test_qat.py (outdated):
sqnr = compute_error(out, out_expected)
self.assertGreater(sqnr, 16)

@parameterized.expand([(PerRow(),), (PerTensor(),)])
nit: why not use
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
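A small sketch of that parametrization style, using `parametrize` and `instantiate_parametrized_tests` from `torch.testing._internal.common_utils`; the test class name and body here are placeholders, not the PR's test code.

```python
import unittest

from torch.testing._internal import common_utils
from torchao.quantization.granularity import PerRow, PerTensor


class TestFloat8FakeQuantize(common_utils.TestCase):
    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
    def test_float8_fake_quantize(self, granularity):
        # placeholder body; the real test compares fake-quantized outputs via SQNR
        self.assertIn(type(granularity), (PerTensor, PerRow))


# generates one test per granularity value
common_utils.instantiate_parametrized_tests(TestFloat8FakeQuantize)

if __name__ == "__main__":
    unittest.main()
```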
test/quantization/test_qat.py (outdated):
if "fbgemm-gpu-genai" in str(e): | ||
self.skipTest("fbgemm-gpu-genai not available") |
nit: we can skip the test when fbgemm-gpu-genai is not installed:
@unittest.skipIf(
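A possible sketch of that guard; the importable module name for fbgemm-gpu-genai is an assumption here, and the test body is a placeholder.

```python
import importlib.util
import unittest

# Assumption: the fbgemm-gpu-genai wheel is importable as "fbgemm_gpu"; adjust if needed.
_HAS_FBGEMM_GPU_GENAI = importlib.util.find_spec("fbgemm_gpu") is not None


class TestQATFp8Int4(unittest.TestCase):
    @unittest.skipIf(not _HAS_FBGEMM_GPU_GENAI, "fbgemm-gpu-genai not available")
    def test_quantize_api_fp8_int4(self):
        # placeholder body; the real test runs prepare/convert and checks numerics
        pass
```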
test/quantization/test_qat.py (outdated):
try:
    quantize_(m, QATConfig(base_config, step="prepare"))
    quantize_(m, QATConfig(base_config, step="convert"))
    m(*example_inputs)
should this happen between prepare and convert as well?
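A sketch of the suggested flow, running the model after prepare (fake-quantized) as well as after convert. The toy model, input shapes, and keyword-only config construction are illustrative assumptions, not the PR's test code.

```python
import torch

from torchao.quantization import quantize_
from torchao.quantization.granularity import PerRow
from torchao.quantization.qat import QATConfig
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig

# toy model and inputs (assumed shapes; fp8 kernels require a recent CUDA GPU)
m = torch.nn.Sequential(torch.nn.Linear(256, 256)).to(torch.bfloat16).cuda()
example_inputs = (torch.randn(16, 256, dtype=torch.bfloat16, device="cuda"),)
base_config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

quantize_(m, QATConfig(base_config, step="prepare"))
m(*example_inputs)  # exercise the fake-quantized (QAT) forward as well
quantize_(m, QATConfig(base_config, step="convert"))
m(*example_inputs)  # exercise the converted forward
```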
    granularity=weight_granularity,
)
elif isinstance(base_config, Float8DynamicActivationInt4WeightConfig):
    act_config = Float8FakeQuantizeConfig(
one thing we should pay extra attention to here is whether the simulation also works for the int4 preshuffled tensor; I think we need some numerics testing to make sure
added an fp8-int4 numerics test
    max_abs = tensor.abs().max()
    if hp_value_lb is not None or hp_value_ub is not None:
        max_abs = torch.clamp(max_abs, min=hp_value_lb, max=hp_value_ub)
    scale = max_abs / quant_max
else:
    # rowwise
this is not necessarily rowwise, I think; I believe this branch covers all granularities, and len(block_size) == 0 is more of a special case for tensorwise quant. I'm not sure where that comes from or whether it's needed; we could try to trace it and see if it can be removed as well to reduce complexity
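For context, a minimal illustration (not the PR's code) of a purely per-row float8 scale computation, in contrast to the branch above that reduces over the whole tensor.

```python
import torch


def rowwise_float8_scale(
    tensor: torch.Tensor,
    float8_dtype: torch.dtype = torch.float8_e4m3fn,
) -> torch.Tensor:
    # One scale per row: reduce the abs max over the last dimension only,
    # then divide by the float8 dtype's largest representable value.
    quant_max = torch.finfo(float8_dtype).max
    max_abs = tensor.abs().amax(dim=-1, keepdim=True)
    return (max_abs / quant_max).to(torch.float32)


# Example: a (4, 8) bf16 tensor yields a (4, 1) tensor of scales.
scales = rowwise_float8_scale(torch.randn(4, 8, dtype=torch.bfloat16))
```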
Float8FakeQuantizeConfig(granularity=PerToken())

@parametrize("granularity", [PerTensor(), PerRow()])
def test_float8_fake_quantize(self, granularity: Granularity):
can you add the same test for fp8_int4?
added an SQNR comparison against PTQ fp8_int4
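For readers unfamiliar with the metric, a tiny self-contained example of the kind of SQNR check these comparisons use; the tensors and noise level here are illustrative, only the threshold of 16 dB mirrors the test above.

```python
import torch

from torchao.quantization.utils import compute_error

# compute_error returns the signal-to-quantization-noise ratio (SQNR) in dB;
# higher means the approximation is closer to the reference.
ref = torch.randn(32, 128)
approx = ref + 0.01 * torch.randn_like(ref)  # assumed small perturbation
sqnr = compute_error(ref, approx)
assert sqnr > 16, f"SQNR too low: {sqnr:.1f} dB"
```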
**Summary:** This commit adds a QAT path for float8, using the same primitives as `torchao.quantization.Float8Tensor`, targeting the following PTQ configs:

- `Float8DynamicActivationFloat8WeightConfig`
- `Float8DynamicActivationInt4WeightConfig`

Usage:

```
from torchao.quantization.granularity import PerRow
from torchao.quantization.qat import quantize_, QATConfig

base_config = Float8DynamicActivationFloat8WeightConfig(
    torch.float8_e4m3fn,
    PerRow(),
)
quantize_(model, QATConfig(base_config, step="prepare"))
quantize_(model, QATConfig(base_config, step="convert"))
```

OR

```
from torchao.quantization.granularity import PerRow
from torchao.quantization.qat import (
    Float8FakeQuantizeConfig,
    QATConfig,
    quantize_,
)

dtype = torch.float8_e4m3fn
granularity = PerRow()
quantize_(model, QATConfig(
    activation_config=Float8FakeQuantizeConfig(dtype, granularity),
    weight_config=Float8FakeQuantizeConfig(dtype, granularity),
    step="prepare",
))
# convert (same as above, not shown)
```

**Test Plan:**

```
python test/quantization/test_qat.py -k test_float8_fake_quantize_config
python test/quantization/test_qat.py -k test_float8_fake_quantize
python test/quantization/test_qat.py -k test_quantize_api_fp8_fp8
python test/quantization/test_qat.py -k test_quantize_api_fp8_int4
```
Summary: This commit adds a QAT path for float8, using the same primitives as `torchao.quantization.Float8Tensor`, targeting the following PTQ configs:

- `Float8DynamicActivationFloat8WeightConfig`
- `Float8ActivationInt4WeightConfig`

Usage: same as above.

Test Plan: Identical outputs between normal bf16 and QAT fine-tuning for both fp8-fp8 and fp8-int4, reproduced on Llama3.1 using this unsloth notebook. Loss curves also overlap almost exactly (not shown).