Skip to content

Commit 092da80

Browse files
[mypy] deprecation.py (#3188)
### Changes - Enable mypy for `nncf/common/deprecation.py` - Rework deprecation decorator
1 parent f355847 commit 092da80

File tree

3 files changed

+89
-42
lines changed

3 files changed

+89
-42
lines changed

nncf/common/deprecation.py

+68-41
Original file line numberDiff line numberDiff line change
@@ -9,58 +9,85 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
import functools
13-
import inspect
12+
import types
1413
import warnings
15-
from typing import Callable, Type, TypeVar
14+
from functools import wraps
15+
from typing import Any, Callable, Optional, TypeVar, cast
1616

17-
from packaging import version
17+
TObj = TypeVar("TObj")
1818

1919

2020
def warning_deprecated(msg: str) -> None:
21+
"""
22+
Display a warning message indicating that a certain functionality is deprecated.
23+
24+
:param msg: The warning message to display.
25+
"""
2126
# Note: must use FutureWarning in order not to get suppressed by default
2227
warnings.warn(msg, FutureWarning, stacklevel=2)
2328

2429

25-
ClassOrFn = TypeVar("ClassOrFn", Callable, Type)
30+
def deprecated(
31+
msg: Optional[str] = None, start_version: Optional[str] = None, end_version: Optional[str] = None
32+
) -> Callable[[TObj], TObj]:
33+
"""
34+
Decorator to mark a function or class as deprecated.
2635
36+
:param msg: Message to provide additional information about the deprecation.
37+
:param start_version: Start version from which the function or class is deprecated.
38+
:param end_version: End version until which the function or class is deprecated.
2739
28-
class deprecated:
40+
:return: The decorator function.
2941
"""
30-
A decorator for marking function calls or class instantiations as deprecated. A call to the marked function or an
31-
instantiation of an object of the marked class will trigger a `FutureWarning`. If a class is marked as
32-
@deprecated, only the instantiations will trigger a warning, but static attribute accesses or method calls will not.
42+
43+
def decorator(obj: TObj) -> TObj:
44+
45+
if isinstance(obj, types.FunctionType):
46+
47+
@wraps(obj)
48+
def wrapper(*args: Any, **kwargs: Any) -> Any:
49+
name = f"function '{obj.__module__}.{obj.__name__}'"
50+
text = _generate_deprecation_message(name, msg, start_version, end_version)
51+
warning_deprecated(text)
52+
return obj(*args, **kwargs)
53+
54+
return cast(TObj, wrapper)
55+
56+
if isinstance(obj, type):
57+
original_init = obj.__init__ # type: ignore[misc]
58+
59+
@wraps(original_init)
60+
def wrapped_init(*args: Any, **kwargs: Any) -> Any:
61+
name = f"class '{obj.__module__}.{obj.__name__}'"
62+
text = _generate_deprecation_message(name, msg, start_version, end_version)
63+
warning_deprecated(text)
64+
return original_init(*args, **kwargs)
65+
66+
obj.__init__ = wrapped_init # type: ignore[misc]
67+
68+
return cast(TObj, obj)
69+
70+
raise TypeError("The @deprecated decorator can only be used on functions or classes.")
71+
72+
return decorator
73+
74+
75+
def _generate_deprecation_message(
76+
name: str, text: Optional[str], start_version: Optional[str], end_version: Optional[str]
77+
) -> str:
3378
"""
79+
Generate a deprecation message for a given name, with optional start and end versions.
3480
35-
def __init__(self, msg: str = None, start_version: str = None, end_version: str = None):
36-
"""
37-
:param msg: Custom message to be added after the boilerplate deprecation text.
38-
"""
39-
self.msg = msg
40-
self.start_version = version.parse(start_version) if start_version is not None else None
41-
self.end_version = version.parse(end_version) if end_version is not None else None
42-
43-
def __call__(self, fn_or_class: ClassOrFn) -> ClassOrFn:
44-
name = fn_or_class.__module__ + "." + fn_or_class.__name__
45-
if inspect.isclass(fn_or_class):
46-
fn_or_class.__init__ = self._get_wrapper(fn_or_class.__init__, name)
47-
return fn_or_class
48-
return self._get_wrapper(fn_or_class, name)
49-
50-
def _get_wrapper(self, fn_to_wrap: Callable, name: str) -> Callable:
51-
@functools.wraps(fn_to_wrap)
52-
def wrapped(*args, **kwargs):
53-
msg = f"Usage of {name} is deprecated "
54-
if self.start_version is not None:
55-
msg += f"starting from NNCF v{str(self.start_version)} "
56-
msg += "and will be removed in "
57-
if self.end_version is not None:
58-
msg += f"NNCF v{str(self.end_version)}."
59-
else:
60-
msg += "a future NNCF version."
61-
if self.msg is not None:
62-
msg += "\n" + self.msg
63-
warning_deprecated(msg)
64-
return fn_to_wrap(*args, **kwargs)
65-
66-
return wrapped
81+
:param name: The name of the deprecated feature.
82+
:param text: Additional text to include in the deprecation message.
83+
:param start_version: The version from which the feature is deprecated.
84+
:param end_version: The version in which the feature will be removed.
85+
:return: The deprecation message.
86+
"""
87+
msg = (
88+
f"Usage of {name} is deprecated {f'starting from NNCF v{start_version} ' if start_version else ''}"
89+
f"and will be removed in {f'NNCF v{end_version}.' if end_version else 'a future NNCF version.'}"
90+
)
91+
if text:
92+
return "\n".join([msg, text])
93+
return msg

pyproject.toml

-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,6 @@ files = [
108108
exclude = [
109109
"nncf/common/composite_compression.py",
110110
"nncf/common/compression.py",
111-
"nncf/common/deprecation.py",
112111
"nncf/common/logging/progress_bar.py",
113112
"nncf/common/logging/track_progress.py",
114113
"nncf/common/pruning/clusterization.py",

tests/common/test_deprecation_warnings.py

+21
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# limitations under the License.
1111
import pytest
1212

13+
from nncf.common.deprecation import _generate_deprecation_message
1314
from nncf.common.deprecation import deprecated
1415
from nncf.common.logging.logger import NNCFDeprecationWarning
1516

@@ -50,3 +51,23 @@ def test_warnings_are_shown_for_deprecated_function_call_with_versions():
5051
def test_warnings_are_shown_for_deprecated_class_instantiation():
5152
with pytest.warns(NNCFDeprecationWarning, match=EXAMPLE_MSG):
5253
DeprecatedClass()
54+
55+
56+
def test_generate_deprecation_message():
57+
ret = _generate_deprecation_message("foo", "text", "1.2.3", "4.5.6")
58+
assert ret == "Usage of foo is deprecated starting from NNCF v1.2.3 and will be removed in NNCF v4.5.6.\ntext"
59+
60+
ret = _generate_deprecation_message("foo", "text", "1.2.3", None)
61+
assert (
62+
ret
63+
== "Usage of foo is deprecated starting from NNCF v1.2.3 and will be removed in a future NNCF version.\ntext"
64+
)
65+
66+
ret = _generate_deprecation_message("foo", "text", None, None)
67+
assert ret == "Usage of foo is deprecated and will be removed in a future NNCF version.\ntext"
68+
69+
ret = _generate_deprecation_message("foo", "text", None, "4.5.6")
70+
assert ret == "Usage of foo is deprecated and will be removed in NNCF v4.5.6.\ntext"
71+
72+
ret = _generate_deprecation_message("foo", None, None, None)
73+
assert ret == "Usage of foo is deprecated and will be removed in a future NNCF version."

0 commit comments

Comments
 (0)