README.md: 2 changes (1 addition & 1 deletion)
@@ -258,7 +258,7 @@ Our framework makes it straightforward to add tensor parallel support to your cu

We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()`. We have a few examples you can follow

- 1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))`
+ 1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, FPXWeightOnlyConfig(3, 2))`
2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256
3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference
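
For reference, a minimal before/after sketch of the renamed fp6 API from item 1, assuming a CUDA fp16 model and that `FPXWeightOnlyConfig` is exported from `torchao.quantization`:

```python
import torch.nn as nn

from torchao.quantization import FPXWeightOnlyConfig, quantize_

# fp6 weight-only quantization: 3 exponent bits + 2 mantissa bits (+ sign) = 6 bits.
model = nn.Sequential(nn.Linear(64, 64)).half().cuda()  # fpx kernels target fp16 CUDA
quantize_(model, FPXWeightOnlyConfig(3, 2))  # was: quantize_(model, fpx_weight_only(3, 2))
```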

test/quantization/test_quant_api.py: 51 changes (0 additions & 51 deletions)
@@ -10,7 +10,6 @@
import gc
import tempfile
import unittest
import warnings
from pathlib import Path

import torch
@@ -847,56 +846,6 @@ def test_int4wo_cuda_serialization(self):
# load state_dict in cuda
model.load_state_dict(sd, assign=True)

def test_config_deprecation(self):
"""
Test that old config functions like `int4_weight_only` trigger deprecation warnings.
"""
from torchao.quantization import (
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
float8_weight_only,
fpx_weight_only,
gemlite_uintx_weight_only,
int4_dynamic_activation_int4_weight,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
int8_weight_only,
uintx_weight_only,
)

# Reset deprecation warning state, otherwise we won't log warnings here
warnings.resetwarnings()

# Map from deprecated API to the args needed to instantiate it
deprecated_apis_to_args = {
float8_dynamic_activation_float8_weight: (),
float8_static_activation_float8_weight: (torch.randn(3),),
float8_weight_only: (),
fpx_weight_only: (3, 2),
gemlite_uintx_weight_only: (),
int4_dynamic_activation_int4_weight: (),
int4_weight_only: (),
int8_dynamic_activation_int4_weight: (),
int8_dynamic_activation_int8_weight: (),
int8_weight_only: (),
uintx_weight_only: (torch.uint4,),
}

with warnings.catch_warnings(record=True) as _warnings:
# Call each deprecated API twice
for cls, args in deprecated_apis_to_args.items():
cls(*args)
cls(*args)

# Each call should trigger the warning only once
self.assertEqual(len(_warnings), len(deprecated_apis_to_args))
for w in _warnings:
self.assertIn(
"is deprecated and will be removed in a future release",
str(w.message),
)


common_utils.instantiate_parametrized_tests(TestQuantFlow)

test/test_utils.py: 50 changes (0 additions & 50 deletions)
@@ -4,7 +4,6 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import warnings
from unittest.mock import patch

import torch
@@ -37,55 +36,6 @@ def test_torch_version_at_least(self):
f"Failed for torch.__version__={torch_version}, comparing with {compare_version}",
)

def test_torch_version_deprecation(self):
"""
Test that TORCH_VERSION_AT_LEAST* and TORCH_VERSION_AFTER*
trigger deprecation warnings on use, not on import.
"""
# Reset deprecation warning state, otherwise we won't log warnings here
warnings.resetwarnings()

# Importing and referencing should not trigger deprecation warning
with warnings.catch_warnings(record=True) as _warnings:
from torchao.utils import (
TORCH_VERSION_AFTER_2_2,
TORCH_VERSION_AFTER_2_3,
TORCH_VERSION_AFTER_2_4,
TORCH_VERSION_AFTER_2_5,
TORCH_VERSION_AT_LEAST_2_2,
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
TORCH_VERSION_AT_LEAST_2_7,
TORCH_VERSION_AT_LEAST_2_8,
)

deprecated_api_to_name = [
(TORCH_VERSION_AT_LEAST_2_8, "TORCH_VERSION_AT_LEAST_2_8"),
(TORCH_VERSION_AT_LEAST_2_7, "TORCH_VERSION_AT_LEAST_2_7"),
(TORCH_VERSION_AT_LEAST_2_6, "TORCH_VERSION_AT_LEAST_2_6"),
(TORCH_VERSION_AT_LEAST_2_5, "TORCH_VERSION_AT_LEAST_2_5"),
(TORCH_VERSION_AT_LEAST_2_4, "TORCH_VERSION_AT_LEAST_2_4"),
(TORCH_VERSION_AT_LEAST_2_3, "TORCH_VERSION_AT_LEAST_2_3"),
(TORCH_VERSION_AT_LEAST_2_2, "TORCH_VERSION_AT_LEAST_2_2"),
(TORCH_VERSION_AFTER_2_5, "TORCH_VERSION_AFTER_2_5"),
(TORCH_VERSION_AFTER_2_4, "TORCH_VERSION_AFTER_2_4"),
(TORCH_VERSION_AFTER_2_3, "TORCH_VERSION_AFTER_2_3"),
(TORCH_VERSION_AFTER_2_2, "TORCH_VERSION_AFTER_2_2"),
]
self.assertEqual(len(_warnings), 0)

# Accessing the boolean value should trigger deprecation warning
with warnings.catch_warnings(record=True) as _warnings:
for api, name in deprecated_api_to_name:
num_warnings_before = len(_warnings)
if api:
pass
regex = f"{name} is deprecated and will be removed"
self.assertEqual(len(_warnings), num_warnings_before + 1)
self.assertIn(regex, str(_warnings[-1].message))

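
The behavior this removed test exercised, silence at import and a warning on first truthiness check, can be sketched with a wrapper whose `__bool__` emits the warning. The class below is hypothetical and illustrative, not torchao's actual implementation:

```python
import warnings


class _DeprecatedVersionFlag:
    """Hypothetical: warns when the flag's truth value is read, not at import time."""

    def __init__(self, name: str, value: bool):
        self._name = name
        self._value = value

    def __bool__(self) -> bool:
        # Matches the message the removed test asserted on.
        warnings.warn(
            f"{self._name} is deprecated and will be removed",
            DeprecationWarning,
            stacklevel=2,
        )
        return self._value


# Import/assignment is silent; `if TORCH_VERSION_AT_LEAST_2_8:` emits the warning.
TORCH_VERSION_AT_LEAST_2_8 = _DeprecatedVersionFlag("TORCH_VERSION_AT_LEAST_2_8", True)
```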

class TestTorchAOBaseTensor(unittest.TestCase):
def test_print_arg_types(self):
torchao/quantization/__init__.py: 24 changes (0 additions & 24 deletions)
@@ -64,21 +64,9 @@
PlainLayout,
TensorCoreTiledLayout,
UIntXWeightOnlyConfig,
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
float8_weight_only,
fpx_weight_only,
gemlite_uintx_weight_only,
int4_dynamic_activation_int4_weight,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_semi_sparse_weight,
int8_dynamic_activation_int8_weight,
int8_weight_only,
intx_quantization_aware_training,
quantize_,
swap_conv2d_1x1_to_linear,
uintx_weight_only,
)
from .quant_primitives import (
MappingType,
@@ -131,19 +119,7 @@
"ALL_AUTOQUANT_CLASS_LIST",
# top level API - manual
"quantize_",
"int4_dynamic_activation_int4_weight",
"int8_dynamic_activation_int4_weight",
"int8_dynamic_activation_int8_weight",
"int8_dynamic_activation_int8_semi_sparse_weight",
"int4_weight_only",
"int8_weight_only",
"intx_quantization_aware_training",
"float8_weight_only",
"float8_dynamic_activation_float8_weight",
"float8_static_activation_float8_weight",
"uintx_weight_only",
"fpx_weight_only",
"gemlite_uintx_weight_only",
"swap_conv2d_1x1_to_linear",
"Int4DynamicActivationInt4WeightConfig",
"Int8DynamicActivationInt4WeightConfig",
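With the aliases gone from `__init__.py`, imports go through the Config classes only. A short sketch using names visible in the `__all__` entries kept above:

```python
# Post-cleanup import surface: the snake_case factory aliases no longer exist,
# only the Config classes remain exported.
from torchao.quantization import (
    Int4DynamicActivationInt4WeightConfig,
    Int8DynamicActivationInt4WeightConfig,
    quantize_,
)
```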
torchao/quantization/quant_api.py: 75 changes (1 addition & 74 deletions)
@@ -96,7 +96,6 @@
to_weight_tensor_with_linear_activation_quantization_metadata,
)
from torchao.utils import (
_ConfigDeprecationWrapper,
is_MI300,
is_sm_at_least_89,
is_sm_at_least_90,
@@ -148,18 +147,7 @@
"autoquant",
"_get_subclass_inserter",
"quantize_",
"int8_dynamic_activation_int4_weight",
"int8_dynamic_activation_int8_weight",
"int8_dynamic_activation_int8_semi_sparse_weight",
"int4_weight_only",
"int8_weight_only",
"intx_quantization_aware_training",
"float8_weight_only",
"uintx_weight_only",
"fpx_weight_only",
"gemlite_uintx_weight_only",
"float8_dynamic_activation_float8_weight",
"float8_static_activation_float8_weight",
"Int8DynActInt4WeightQuantizer",
"Float8DynamicActivationFloat8SemiSparseWeightConfig",
"ModuleFqnToConfig",
@@ -507,7 +495,7 @@ def quantize_(
# Int8DynamicActivationInt8WeightConfig (optimized with int8 mm op and torch.compile)
# Int4WeightOnlyConfig (optimized with int4 tinygemm kernel and torch.compile)
# Int8WeightOnlyConfig (optimized with int8 mm op and torch.compile)
- from torchao.quantization.quant_api import int4_weight_only
+ from torchao.quantization.quant_api import Int4WeightOnlyConfig

m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
quantize_(m, Int4WeightOnlyConfig(group_size=32, version=1))
@@ -629,12 +617,6 @@ def __post_init__(self):
)


# for BC
int8_dynamic_activation_int4_weight = _ConfigDeprecationWrapper(
"int8_dynamic_activation_int4_weight", Int8DynamicActivationInt4WeightConfig
)

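
The `_ConfigDeprecationWrapper` shims removed throughout this file follow one pattern: a callable alias that warns, then delegates to the Config class. Here is a hypothetical reimplementation inferred from the removed test's "only once" assertion; names and details are illustrative, not torchao's actual code:

```python
import warnings


class _ConfigDeprecationWrapper:
    """Hypothetical sketch: a callable alias that warns once, then builds the config."""

    def __init__(self, deprecated_name: str, config_cls: type):
        self._deprecated_name = deprecated_name
        self._config_cls = config_cls
        self._warned = False

    def __call__(self, *args, **kwargs):
        if not self._warned:  # the removed test expected one warning per alias
            warnings.warn(
                f"`{self._deprecated_name}` is deprecated and will be removed in a "
                f"future release, use `{self._config_cls.__name__}` instead",
                DeprecationWarning,
                stacklevel=2,
            )
            self._warned = True
        return self._config_cls(*args, **kwargs)
```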

@register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig)
def _int8_dynamic_activation_int4_weight_transform(
module: torch.nn.Module,
@@ -1000,12 +982,6 @@ def __post_init__(self):
)


# for bc
int4_dynamic_activation_int4_weight = _ConfigDeprecationWrapper(
"int4_dynamic_activation_int4_weight", Int4DynamicActivationInt4WeightConfig
)


@register_quantize_module_handler(Int4DynamicActivationInt4WeightConfig)
def _int4_dynamic_activation_int4_weight_transform(
module: torch.nn.Module, config: Int4DynamicActivationInt4WeightConfig
@@ -1063,12 +1039,6 @@ def __post_init__(self):
)


# for BC
gemlite_uintx_weight_only = _ConfigDeprecationWrapper(
"gemlite_uintx_weight_only", GemliteUIntXWeightOnlyConfig
)


@register_quantize_module_handler(GemliteUIntXWeightOnlyConfig)
def _gemlite_uintx_weight_only_transform(
module: torch.nn.Module, config: GemliteUIntXWeightOnlyConfig
@@ -1146,11 +1116,6 @@ def __post_init__(self):
torch._C._log_api_usage_once("torchao.quantization.Int4WeightOnlyConfig")


# for BC
# TODO maybe change other callsites
int4_weight_only = _ConfigDeprecationWrapper("int4_weight_only", Int4WeightOnlyConfig)


def _int4_weight_only_quantize_tensor(weight, config):
# TODO(future PR): perhaps move this logic to a different file, to keep the API
# file clean of implementation details
@@ -1362,10 +1327,6 @@ def __post_init__(self):
torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig")


# for BC
int8_weight_only = _ConfigDeprecationWrapper("int8_weight_only", Int8WeightOnlyConfig)


def _int8_weight_only_quantize_tensor(weight, config):
mapping_type = MappingType.SYMMETRIC
target_dtype = torch.int8
@@ -1523,12 +1484,6 @@ def __post_init__(self):
)


# for BC
int8_dynamic_activation_int8_weight = _ConfigDeprecationWrapper(
"int8_dynamic_activation_int8_weight", Int8DynamicActivationInt8WeightConfig
)


def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
layout = config.layout
act_mapping_type = config.act_mapping_type
@@ -1634,12 +1589,6 @@ def __post_init__(self):
torch._C._log_api_usage_once("torchao.quantization.Float8WeightOnlyConfig")


# for BC
float8_weight_only = _ConfigDeprecationWrapper(
"float8_weight_only", Float8WeightOnlyConfig
)


def _float8_weight_only_quant_tensor(weight, config):
if config.version == 1:
warnings.warn(
@@ -1798,12 +1747,6 @@ def __post_init__(self):
self.granularity = [activation_granularity, weight_granularity]


# for bc
float8_dynamic_activation_float8_weight = _ConfigDeprecationWrapper(
"float8_dynamic_activation_float8_weight", Float8DynamicActivationFloat8WeightConfig
)


def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
activation_dtype = config.activation_dtype
weight_dtype = config.weight_dtype
@@ -1979,12 +1922,6 @@ def __post_init__(self):
)


# for bc
float8_static_activation_float8_weight = _ConfigDeprecationWrapper(
"float8_static_activation_float8_weight", Float8StaticActivationFloat8WeightConfig
)


@register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig)
def _float8_static_activation_float8_weight_transform(
module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig
@@ -2067,12 +2004,6 @@ def __post_init__(self):
torch._C._log_api_usage_once("torchao.quantization.UIntXWeightOnlyConfig")


# for BC
uintx_weight_only = _ConfigDeprecationWrapper(
"uintx_weight_only", UIntXWeightOnlyConfig
)


@register_quantize_module_handler(UIntXWeightOnlyConfig)
def _uintx_weight_only_transform(
module: torch.nn.Module, config: UIntXWeightOnlyConfig
@@ -2351,10 +2282,6 @@ def __post_init__(self):
torch._C._log_api_usage_once("torchao.quantization.FPXWeightOnlyConfig")


# for BC
fpx_weight_only = _ConfigDeprecationWrapper("fpx_weight_only", FPXWeightOnlyConfig)


@register_quantize_module_handler(FPXWeightOnlyConfig)
def _fpx_weight_only_transform(
module: torch.nn.Module, config: FPXWeightOnlyConfig