[bc-breaking] enable direct configuration in quantize_ #1595
```diff
@@ -8,13 +8,15 @@
     run_tests,
 )
 
+from torchao.core.config import AOBaseWorkflowConfig
+
 from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout
 from torchao.quantization import (
     float8_weight_only,
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
     int8_weight_only,
     quantize_,
 )
 from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
 from torchao.utils import (
```
```diff
@@ -76,7 +78,8 @@ def test_tensor_core_layout_transpose(self):
         t = linear.weight
         shape = t.shape
         apply_int4_weight_only_quant = int4_weight_only(group_size=32)
-        ql = apply_int4_weight_only_quant(linear)
+        quantize_(linear, apply_int4_weight_only_quant)
+        ql = linear
         aqt = ql.weight
         aqt_shape = aqt.shape
         self.assertEqual(aqt_shape, shape)
```
```diff
@@ -95,7 +98,11 @@ def test_tensor_core_layout_transpose(self):
     )
     def test_weights_only(self, apply_quant):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        ql = apply_quant(linear)
+        if isinstance(apply_quant, AOBaseWorkflowConfig):
+            quantize_(linear, apply_quant)
+            ql = linear
+        else:
+            ql = apply_quant(linear)
         with tempfile.NamedTemporaryFile() as f:
             torch.save(ql.state_dict(), f)
             f.seek(0)
```
```diff
@@ -171,8 +178,13 @@ def apply_uint6_weight_only_quant(linear):
     @common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_print_quantized_module(self, apply_quant):
+        print(apply_quant)
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-        ql = apply_quant(linear)
+        if isinstance(apply_quant, AOBaseWorkflowConfig):
+            quantize_(linear, apply_quant)
+            ql = linear
+        else:
+            ql = apply_quant(linear)
 
         assert "AffineQuantizedTensor" in str(ql)
```
```diff
@@ -186,7 +198,10 @@ def test_flatten_unflatten(self, device, dtype):
         apply_quant_list = get_quantization_functions(False, True, device)
         for apply_quant in apply_quant_list:
             linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
-            ql = apply_quant(linear)
+            if isinstance(apply_quant, AOBaseWorkflowConfig):
+                quantize_(linear, apply_quant)
+            else:
+                ql = apply_quant(linear)
 
             lp_tensor = ql.weight
             tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
             tensor_data_dict = {
```

Review comment: we have a few partners where we need to forward-fix BC issues, including HuggingFace transformers, Optimum, SGLang and Diffusers

Reply: @msaroufim do you have a link? I don't expect any BC breakages of people using the

Reply: SGLang callsite: https://github.com/sgl-project/sglang/blob/2f47d710ae9cb1bdbbe0fe2392a0634827d257b3/python/sglang/srt/layers/torchao_utils.py#L39
Diffusers callsite: https://github.com/huggingface/diffusers/blob/7fb481f840b5d73982cafd1affe89f21a5c0b20b/src/diffusers/quantizers/torchao/torchao_quantizer.py#L234
we should definitely test these, but they look like they will be unaffected to me
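For readers skimming the test changes above, here is the calling-convention change in isolation. This is a minimal sketch that mirrors the test code; it assumes a CUDA device and that `int4_weight_only(...)` now returns an `AOBaseWorkflowConfig` rather than a plain module transform:

```python
import torch

from torchao.core.config import AOBaseWorkflowConfig
from torchao.quantization import int4_weight_only, quantize_

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
apply_quant = int4_weight_only(group_size=32)

if isinstance(apply_quant, AOBaseWorkflowConfig):
    # new convention: pass the config to quantize_, which modifies
    # the module in place
    quantize_(linear, apply_quant)
    ql = linear
else:
    # old convention: the returned object was itself a callable that
    # transformed the module directly
    ql = apply_quant(linear)

assert "AffineQuantizedTensor" in str(ql)
```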
```diff
@@ -1185,7 +1185,7 @@ def test_qat_prototype_bc(self):
     @unittest.skipIf(
         not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
     )
-    def test_quantize_api(self):
+    def test_quantize_api_standalone(self):
         """
         Test that the following:
```

Review comment: do we need this change?

Reply: it's convenient for being able to filter for only this test from the command line. I can remove it if you'd like.
torchao/core/config.py (new file; the path is implied by the import in the module below):

```diff
@@ -0,0 +1,13 @@
+import abc
+
+
+# directory location for this might need more polish
+class AOBaseWorkflowConfig(abc.ABC):
+    """
+    If a workflow config inherits from this then `quantize_` knows
+    what to do with it.
+
+    TODO write a better docblock.
+    """
+
+    pass
```

Review comment: I feel we can just add this to

Reply: slightly stronger preference is I feel "core" shouldn't appear in the import, so users should be able to do this:

but we can do that by adding this to
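To make the intent of the base class concrete, here is a hypothetical subclass; the name and field are illustrative and not part of this PR:

```python
from dataclasses import dataclass

from torchao.core.config import AOBaseWorkflowConfig


@dataclass
class MyInt4WeightOnlyConfig(AOBaseWorkflowConfig):
    # hypothetical field: whatever arguments the workflow needs
    group_size: int = 128


# Because the config inherits from AOBaseWorkflowConfig, quantize_ can
# recognize it and route it to the handler registered for this type:
#     quantize_(model, MyInt4WeightOnlyConfig(group_size=32))
```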
A second new file adds the registry that maps config types to handlers:

```diff
@@ -0,0 +1,17 @@
+from typing import Callable, Dict
+
+import torch
+
+from torchao.core.config import AOBaseWorkflowConfig
+
+_QUANTIZE_CONFIG_HANDLER: Dict[
+    AOBaseWorkflowConfig,
+    Callable[[torch.nn.Module, AOBaseWorkflowConfig], torch.nn.Module],
+] = {}
+
+
+def register_quantize_module_handler(config_type):
+    def decorator(func):
+        _QUANTIZE_CONFIG_HANDLER[config_type] = func
+
+    return decorator
```

Review comment: remove?
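A self-contained sketch of how this registry plugs together with `AOBaseWorkflowConfig`. The config and handler here are hypothetical, the registry mirrors the diff, and one detail deviates: the inner decorator returns `func` so the `@` form leaves the decorated name intact (as written in the diff it would bind it to `None`):

```python
from dataclasses import dataclass
from typing import Callable, Dict

import torch

from torchao.core.config import AOBaseWorkflowConfig

# registry mapping a config type to the function that transforms a module,
# mirroring the new module in this PR
_QUANTIZE_CONFIG_HANDLER: Dict[
    AOBaseWorkflowConfig,
    Callable[[torch.nn.Module, AOBaseWorkflowConfig], torch.nn.Module],
] = {}


def register_quantize_module_handler(config_type):
    def decorator(func):
        _QUANTIZE_CONFIG_HANDLER[config_type] = func
        return func  # deviation from the diff: keeps @-decorator usage working

    return decorator


@dataclass
class MyWeightOnlyConfig(AOBaseWorkflowConfig):
    # hypothetical config, not part of this PR
    bits: int = 8


@register_quantize_module_handler(MyWeightOnlyConfig)
def _my_handler(module: torch.nn.Module, config: MyWeightOnlyConfig) -> torch.nn.Module:
    # a real handler would swap module.weight for a quantized tensor subclass;
    # this stub just returns the module unchanged
    return module


# dispatch on the config's type, which is presumably what quantize_ does internally
config = MyWeightOnlyConfig(bits=4)
handler = _QUANTIZE_CONFIG_HANDLER[type(config)]
linear = handler(torch.nn.Linear(16, 16), config)
```

With this in place, `quantize_` only needs an `isinstance(config, AOBaseWorkflowConfig)` check plus a dict lookup to support new workflows, without growing new top-level APIs.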