Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 0 additions & 50 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import warnings
from unittest.mock import patch

import torch
Expand Down Expand Up @@ -37,55 +36,6 @@ def test_torch_version_at_least(self):
f"Failed for torch.__version__={torch_version}, comparing with {compare_version}",
)

def test_torch_version_deprecation(self):
"""
Test that TORCH_VERSION_AT_LEAST* and TORCH_VERSION_AFTER*
trigger deprecation warnings on use, not on import.
"""
# Reset deprecation warning state, otherwise we won't log warnings here
warnings.resetwarnings()

# Importing and referencing should not trigger deprecation warning
with warnings.catch_warnings(record=True) as _warnings:
from torchao.utils import (
TORCH_VERSION_AFTER_2_2,
TORCH_VERSION_AFTER_2_3,
TORCH_VERSION_AFTER_2_4,
TORCH_VERSION_AFTER_2_5,
TORCH_VERSION_AT_LEAST_2_2,
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_4,
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
TORCH_VERSION_AT_LEAST_2_7,
TORCH_VERSION_AT_LEAST_2_8,
)

deprecated_api_to_name = [
(TORCH_VERSION_AT_LEAST_2_8, "TORCH_VERSION_AT_LEAST_2_8"),
(TORCH_VERSION_AT_LEAST_2_7, "TORCH_VERSION_AT_LEAST_2_7"),
(TORCH_VERSION_AT_LEAST_2_6, "TORCH_VERSION_AT_LEAST_2_6"),
(TORCH_VERSION_AT_LEAST_2_5, "TORCH_VERSION_AT_LEAST_2_5"),
(TORCH_VERSION_AT_LEAST_2_4, "TORCH_VERSION_AT_LEAST_2_4"),
(TORCH_VERSION_AT_LEAST_2_3, "TORCH_VERSION_AT_LEAST_2_3"),
(TORCH_VERSION_AT_LEAST_2_2, "TORCH_VERSION_AT_LEAST_2_2"),
(TORCH_VERSION_AFTER_2_5, "TORCH_VERSION_AFTER_2_5"),
(TORCH_VERSION_AFTER_2_4, "TORCH_VERSION_AFTER_2_4"),
(TORCH_VERSION_AFTER_2_3, "TORCH_VERSION_AFTER_2_3"),
(TORCH_VERSION_AFTER_2_2, "TORCH_VERSION_AFTER_2_2"),
]
self.assertEqual(len(_warnings), 0)

# Accessing the boolean value should trigger deprecation warning
with warnings.catch_warnings(record=True) as _warnings:
for api, name in deprecated_api_to_name:
num_warnings_before = len(_warnings)
if api:
pass
regex = f"{name} is deprecated and will be removed"
self.assertEqual(len(_warnings), num_warnings_before + 1)
self.assertIn(regex, str(_warnings[-1].message))


class TestTorchAOBaseTensor(unittest.TestCase):
def test_print_arg_types(self):
Expand Down
67 changes: 0 additions & 67 deletions torchao/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import itertools
import re
import time
import warnings
from functools import reduce
from importlib.metadata import version
from math import gcd
Expand All @@ -34,17 +33,6 @@
"is_sm_at_least_90",
"is_package_at_least",
"DummyModule",
# Deprecated
"TORCH_VERSION_AT_LEAST_2_2",
"TORCH_VERSION_AT_LEAST_2_3",
"TORCH_VERSION_AT_LEAST_2_4",
"TORCH_VERSION_AT_LEAST_2_5",
"TORCH_VERSION_AT_LEAST_2_6",
"TORCH_VERSION_AT_LEAST_2_7",
"TORCH_VERSION_AFTER_2_2",
"TORCH_VERSION_AFTER_2_3",
"TORCH_VERSION_AFTER_2_4",
"TORCH_VERSION_AFTER_2_5",
]


Expand Down Expand Up @@ -378,61 +366,6 @@ def torch_version_at_least(min_version):
return parse_version(torch.__version__) >= parse_version(min_version)


def _deprecated_torch_version_at_least(version_str: str) -> str:
"""
Wrapper for existing TORCH_VERSION_AT_LEAST* variables that will log
a deprecation warning if the variable is used.
"""
version_str_var_name = "_".join(version_str.split(".")[:2])
deprecation_msg = f"TORCH_VERSION_AT_LEAST_{version_str_var_name} is deprecated and will be removed in torchao 0.14.0"
return _BoolDeprecationWrapper(
torch_version_at_least(version_str),
deprecation_msg,
)


def _deprecated_torch_version_after(version_str: str) -> str:
"""
Wrapper for existing TORCH_VERSION_AFTER* variables that will log
a deprecation warning if the variable is used.
"""
bool_value = is_fbcode() or version("torch") >= version_str
version_str_var_name = "_".join(version_str.split(".")[:2])
deprecation_msg = f"TORCH_VERSION_AFTER_{version_str_var_name} is deprecated and will be removed in torchao 0.14.0"
return _BoolDeprecationWrapper(bool_value, deprecation_msg)


class _BoolDeprecationWrapper:
"""
A deprecation wrapper that logs a warning when the given bool value is accessed.
"""

def __init__(self, bool_value: bool, msg: str):
self.bool_value = bool_value
self.msg = msg

def __bool__(self):
warnings.warn(self.msg)
return self.bool_value

def __eq__(self, other):
return bool(self) == bool(other)


# Deprecated, use `torch_version_at_least` directly instead
TORCH_VERSION_AT_LEAST_2_8 = _deprecated_torch_version_at_least("2.8.0")
TORCH_VERSION_AT_LEAST_2_7 = _deprecated_torch_version_at_least("2.7.0")
TORCH_VERSION_AT_LEAST_2_6 = _deprecated_torch_version_at_least("2.6.0")
TORCH_VERSION_AT_LEAST_2_5 = _deprecated_torch_version_at_least("2.5.0")
TORCH_VERSION_AT_LEAST_2_4 = _deprecated_torch_version_at_least("2.4.0")
TORCH_VERSION_AT_LEAST_2_3 = _deprecated_torch_version_at_least("2.3.0")
TORCH_VERSION_AT_LEAST_2_2 = _deprecated_torch_version_at_least("2.2.0")
TORCH_VERSION_AFTER_2_5 = _deprecated_torch_version_after("2.5.0.dev")
TORCH_VERSION_AFTER_2_4 = _deprecated_torch_version_after("2.4.0.dev")
TORCH_VERSION_AFTER_2_3 = _deprecated_torch_version_after("2.3.0.dev")
TORCH_VERSION_AFTER_2_2 = _deprecated_torch_version_after("2.2.0.dev")


"""
Helper function for implementing aten op or torch function dispatch
and dispatching to these implementations.
Expand Down
Loading