diff --git a/test/test_utils.py b/test/test_utils.py index b46d600053..0e77388f13 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. import unittest -import warnings from unittest.mock import patch import torch @@ -37,55 +36,6 @@ def test_torch_version_at_least(self): f"Failed for torch.__version__={torch_version}, comparing with {compare_version}", ) - def test_torch_version_deprecation(self): - """ - Test that TORCH_VERSION_AT_LEAST* and TORCH_VERSION_AFTER* - trigger deprecation warnings on use, not on import. - """ - # Reset deprecation warning state, otherwise we won't log warnings here - warnings.resetwarnings() - - # Importing and referencing should not trigger deprecation warning - with warnings.catch_warnings(record=True) as _warnings: - from torchao.utils import ( - TORCH_VERSION_AFTER_2_2, - TORCH_VERSION_AFTER_2_3, - TORCH_VERSION_AFTER_2_4, - TORCH_VERSION_AFTER_2_5, - TORCH_VERSION_AT_LEAST_2_2, - TORCH_VERSION_AT_LEAST_2_3, - TORCH_VERSION_AT_LEAST_2_4, - TORCH_VERSION_AT_LEAST_2_5, - TORCH_VERSION_AT_LEAST_2_6, - TORCH_VERSION_AT_LEAST_2_7, - TORCH_VERSION_AT_LEAST_2_8, - ) - - deprecated_api_to_name = [ - (TORCH_VERSION_AT_LEAST_2_8, "TORCH_VERSION_AT_LEAST_2_8"), - (TORCH_VERSION_AT_LEAST_2_7, "TORCH_VERSION_AT_LEAST_2_7"), - (TORCH_VERSION_AT_LEAST_2_6, "TORCH_VERSION_AT_LEAST_2_6"), - (TORCH_VERSION_AT_LEAST_2_5, "TORCH_VERSION_AT_LEAST_2_5"), - (TORCH_VERSION_AT_LEAST_2_4, "TORCH_VERSION_AT_LEAST_2_4"), - (TORCH_VERSION_AT_LEAST_2_3, "TORCH_VERSION_AT_LEAST_2_3"), - (TORCH_VERSION_AT_LEAST_2_2, "TORCH_VERSION_AT_LEAST_2_2"), - (TORCH_VERSION_AFTER_2_5, "TORCH_VERSION_AFTER_2_5"), - (TORCH_VERSION_AFTER_2_4, "TORCH_VERSION_AFTER_2_4"), - (TORCH_VERSION_AFTER_2_3, "TORCH_VERSION_AFTER_2_3"), - (TORCH_VERSION_AFTER_2_2, "TORCH_VERSION_AFTER_2_2"), - ] - self.assertEqual(len(_warnings), 0) - - # Accessing the boolean value should trigger deprecation warning - with warnings.catch_warnings(record=True) as _warnings: - for api, name in deprecated_api_to_name: - num_warnings_before = len(_warnings) - if api: - pass - regex = f"{name} is deprecated and will be removed" - self.assertEqual(len(_warnings), num_warnings_before + 1) - self.assertIn(regex, str(_warnings[-1].message)) - class TestTorchAOBaseTensor(unittest.TestCase): def test_print_arg_types(self): diff --git a/torchao/utils.py b/torchao/utils.py index 9dfebfb6fb..b010c4f9b8 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -8,7 +8,6 @@ import itertools import re import time -import warnings from functools import reduce from importlib.metadata import version from math import gcd @@ -34,17 +33,6 @@ "is_sm_at_least_90", "is_package_at_least", "DummyModule", - # Deprecated - "TORCH_VERSION_AT_LEAST_2_2", - "TORCH_VERSION_AT_LEAST_2_3", - "TORCH_VERSION_AT_LEAST_2_4", - "TORCH_VERSION_AT_LEAST_2_5", - "TORCH_VERSION_AT_LEAST_2_6", - "TORCH_VERSION_AT_LEAST_2_7", - "TORCH_VERSION_AFTER_2_2", - "TORCH_VERSION_AFTER_2_3", - "TORCH_VERSION_AFTER_2_4", - "TORCH_VERSION_AFTER_2_5", ] @@ -378,61 +366,6 @@ def torch_version_at_least(min_version): return parse_version(torch.__version__) >= parse_version(min_version) -def _deprecated_torch_version_at_least(version_str: str) -> str: - """ - Wrapper for existing TORCH_VERSION_AT_LEAST* variables that will log - a deprecation warning if the variable is used. - """ - version_str_var_name = "_".join(version_str.split(".")[:2]) - deprecation_msg = f"TORCH_VERSION_AT_LEAST_{version_str_var_name} is deprecated and will be removed in torchao 0.14.0" - return _BoolDeprecationWrapper( - torch_version_at_least(version_str), - deprecation_msg, - ) - - -def _deprecated_torch_version_after(version_str: str) -> str: - """ - Wrapper for existing TORCH_VERSION_AFTER* variables that will log - a deprecation warning if the variable is used. - """ - bool_value = is_fbcode() or version("torch") >= version_str - version_str_var_name = "_".join(version_str.split(".")[:2]) - deprecation_msg = f"TORCH_VERSION_AFTER_{version_str_var_name} is deprecated and will be removed in torchao 0.14.0" - return _BoolDeprecationWrapper(bool_value, deprecation_msg) - - -class _BoolDeprecationWrapper: - """ - A deprecation wrapper that logs a warning when the given bool value is accessed. - """ - - def __init__(self, bool_value: bool, msg: str): - self.bool_value = bool_value - self.msg = msg - - def __bool__(self): - warnings.warn(self.msg) - return self.bool_value - - def __eq__(self, other): - return bool(self) == bool(other) - - -# Deprecated, use `torch_version_at_least` directly instead -TORCH_VERSION_AT_LEAST_2_8 = _deprecated_torch_version_at_least("2.8.0") -TORCH_VERSION_AT_LEAST_2_7 = _deprecated_torch_version_at_least("2.7.0") -TORCH_VERSION_AT_LEAST_2_6 = _deprecated_torch_version_at_least("2.6.0") -TORCH_VERSION_AT_LEAST_2_5 = _deprecated_torch_version_at_least("2.5.0") -TORCH_VERSION_AT_LEAST_2_4 = _deprecated_torch_version_at_least("2.4.0") -TORCH_VERSION_AT_LEAST_2_3 = _deprecated_torch_version_at_least("2.3.0") -TORCH_VERSION_AT_LEAST_2_2 = _deprecated_torch_version_at_least("2.2.0") -TORCH_VERSION_AFTER_2_5 = _deprecated_torch_version_after("2.5.0.dev") -TORCH_VERSION_AFTER_2_4 = _deprecated_torch_version_after("2.4.0.dev") -TORCH_VERSION_AFTER_2_3 = _deprecated_torch_version_after("2.3.0.dev") -TORCH_VERSION_AFTER_2_2 = _deprecated_torch_version_after("2.2.0.dev") - - """ Helper function for implementing aten op or torch function dispatch and dispatching to these implementations.