Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

config migration: float8* #1694

Merged
merged 19 commits into from
Feb 14, 2025
39 changes: 32 additions & 7 deletions test/dtypes/test_affine_quantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
run_tests,
)

from torchao.core.config import AOBaseConfig
from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout
from torchao.quantization import (
float8_weight_only,
Expand All @@ -16,6 +17,7 @@
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_weight_only,
quantize_,
)
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
from torchao.utils import (
Expand Down Expand Up @@ -82,7 +84,8 @@ def test_tensor_core_layout_transpose(self):
t = linear.weight
shape = t.shape
apply_int4_weight_only_quant = int4_weight_only(group_size=32)
ql = apply_int4_weight_only_quant(linear)
quantize_(linear, apply_int4_weight_only_quant)
ql = linear
aqt = ql.weight
aqt_shape = aqt.shape
self.assertEqual(aqt_shape, shape)
Expand All @@ -102,7 +105,12 @@ def test_tensor_core_layout_transpose(self):
)
def test_weights_only(self, apply_quant):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(linear)
if isinstance(apply_quant, AOBaseConfig):
quantize_(linear, apply_quant)
ql = linear
else:
# TODO(#1690): delete this once config migration is done
ql = apply_quant(linear)
with tempfile.NamedTemporaryFile() as f:
torch.save(ql.state_dict(), f)
f.seek(0)
Expand All @@ -115,16 +123,24 @@ def test_weights_only(self, apply_quant):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.parametrize("apply_quant", get_quantization_functions(False, False))
def test_to_device(self, apply_quant):
def _apply(module, config_or_subclass_inserter):
if isinstance(config_or_subclass_inserter, AOBaseConfig):
quantize_(module, config_or_subclass_inserter)
else:
# TODO(#1690): delete this once config migration is done
module = config_or_subclass_inserter(module)
return module

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(linear)
ql = _apply(linear, apply_quant)
ql.to("cuda")

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(linear)
ql = _apply(linear, apply_quant)
ql.to(device="cuda")

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
ql = apply_quant(linear)
ql = _apply(linear, apply_quant)
ql.cuda()

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
Expand Down Expand Up @@ -181,7 +197,12 @@ def apply_uint6_weight_only_quant(linear):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_print_quantized_module(self, apply_quant):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
ql = apply_quant(linear)
if isinstance(apply_quant, AOBaseConfig):
quantize_(linear, apply_quant)
ql = linear
else:
# TODO(#1690): delete this once config migration is done
ql = apply_quant(linear)
assert "AffineQuantizedTensor" in str(ql)


Expand All @@ -195,7 +216,11 @@ def test_flatten_unflatten(self, device, dtype):
apply_quant_list = get_quantization_functions(False, True, device)
for apply_quant in apply_quant_list:
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
ql = apply_quant(linear)
if isinstance(apply_quant, AOBaseConfig):
quantize_(linear, apply_quant)
else:
# TODO(#1690): delete this once config migration is done
ql = apply_quant(linear)
lp_tensor = ql.weight
tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
tensor_data_dict = {
Expand Down
7 changes: 4 additions & 3 deletions test/hqq/test_hqq_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
MappingType,
ZeroPointDomain,
int4_weight_only,
quantize_,
uintx_weight_only,
)
from torchao.utils import (
Expand Down Expand Up @@ -51,9 +52,9 @@ def _eval_hqq(dtype):
)
dummy_linear.weight.data = W
if dtype == torch.uint4:
q_tensor_hqq = int4_weight_only(group_size=max(block_size), use_hqq=True)(
dummy_linear
).weight
config = int4_weight_only(group_size=max(block_size), use_hqq=True)
quantize_(dummy_linear, config)
q_tensor_hqq = dummy_linear.weight
else:
q_tensor_hqq = uintx_weight_only(
dtype, group_size=max(block_size), use_hqq=True
Expand Down
2 changes: 1 addition & 1 deletion test/quantization/test_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -1185,7 +1185,7 @@ def test_qat_prototype_bc(self):
@unittest.skipIf(
not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch version is 2.4 or lower"
)
def test_quantize_api(self):
def test_quantize_api_standalone(self):
"""
Test that the following:

Expand Down
56 changes: 56 additions & 0 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
Quantizer,
TwoStepQuantizer,
_replace_with_custom_fn_if_matches_filter,
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
float8_weight_only,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
Expand All @@ -40,11 +43,13 @@
Int4WeightOnlyQuantizedLinearWeight,
Int8WeightOnlyQuantizedLinearWeight,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
is_sm_at_least_89,
unwrap_tensor_subclass,
)

Expand Down Expand Up @@ -783,6 +788,57 @@ def test_int4wo_cpu(self, dtype, x_dim):
assert "_weight_int4pack_mm_for_cpu" in code[0]
assert "aten.mm.default" not in code[0]

# TODO(#1690): move to new config names
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.parametrize(
"config",
[
int4_weight_only(),
float8_weight_only(),
float8_dynamic_activation_float8_weight(),
float8_static_activation_float8_weight(scale=torch.tensor([1.0])),
],
)
def test_workflow_e2e_numerics(self, config):
"""
Simple test of e2e int4_weight_only workflow, comparing numerics
to a bfloat16 baseline.
"""
if (
isinstance(
config,
(
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
),
)
and not is_sm_at_least_89()
):
return unittest.skip("requires CUDA capability 8.9 or greater")

# scale has to be moved to cuda here because the parametrization init
# code happens before gating for cuda availability
if isinstance(config, float8_static_activation_float8_weight):
config.scale = config.scale.to("cuda")

# set up inputs
x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
# TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
# is that expected?
m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
m_q = copy.deepcopy(m_ref)

# quantize
quantize_(m_q, config)

with torch.no_grad():
y_ref = m_ref(x)
y_q = m_q(x)

sqnr = compute_error(y_ref, y_q)
assert sqnr >= 20, f"SQNR {sqnr} is too low"


class TestMultiTensorFlow(TestCase):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
Expand Down
Empty file added torchao/core/__init__.py
Empty file.
29 changes: 29 additions & 0 deletions torchao/core/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import abc


class AOBaseConfig(abc.ABC):
"""
If a workflow config inherits from this then `quantize_` knows
how to a apply it to a model. For example::

# user facing code
class WorkflowFooConfig(AOBaseConfig): ...
# configuration for workflow `Foo` is defined here
bar = 'baz'

# non user facing code
@register_quantize_module_handler(WorkflowFooConfig)
def _transform(
mod: torch.nn.Module,
config: WorkflowFooConfig,
) -> torch.nn.Module:
# the transform is implemented here, usually a tensor sublass
# weight swap or a module swap
...

# then, the user calls `quantize_` with a config, and `_transform` is called
# under the hood by `quantize_.

"""

pass
12 changes: 12 additions & 0 deletions torchao/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


from torchao.kernel import (
int_scaled_matmul,
safe_int_mm,
Expand Down Expand Up @@ -45,6 +46,10 @@
AffineQuantizedObserverBase,
)
from .quant_api import (
Float8DynamicActivationFloat8WeightConfig,
Float8StaticActivationFloat8WeightConfig,
Float8WeightOnlyConfig,
Int4WeightOnlyConfig,
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
float8_weight_only,
Expand Down Expand Up @@ -85,6 +90,7 @@
swap_linear_with_smooth_fq_linear,
)
from .subclass import * # noqa: F403
from .transform_module import register_quantize_module_handler
from .unified import Quantizer, TwoStepQuantizer
from .utils import (
compute_error,
Expand Down Expand Up @@ -117,6 +123,10 @@
"fpx_weight_only",
"gemlite_uintx_weight_only",
"swap_conv2d_1x1_to_linear",
"Int4WeightOnlyConfig",
"Float8WeightOnlyConfig",
"Float8DynamicActivationFloat8WeightConfig",
"Float8StaticActivationFloat8WeightConfig",
# smooth quant - subject to change
"get_scale",
"SmoothFakeDynQuantMixin",
Expand Down Expand Up @@ -144,6 +154,8 @@
# operators/kernels
"safe_int_mm",
"int_scaled_matmul",
# registration of module transforms for quantize_
"register_quantize_module_handler",
# dataclasses and types
"MappingType",
"ZeroPointDomain",
Expand Down
4 changes: 4 additions & 0 deletions torchao/quantization/qat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from .api import (
ComposableQATQuantizer,
FakeQuantizeConfig,
FromIntXQuantizationAwareTrainingConfig,
IntXQuantizationAwareTrainingConfig,
from_intx_quantization_aware_training,
intx_quantization_aware_training,
)
Expand All @@ -20,4 +22,6 @@
"Int8DynActInt4WeightQATQuantizer",
"intx_quantization_aware_training",
"from_intx_quantization_aware_training",
"FromIntXQuantizationAwareTrainingConfig",
"IntXQuantizationAwareTrainingConfig",
]
Loading
Loading