diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml index 117141b..c2e78c0 100644 --- a/.github/workflows/pythonpackage.yml +++ b/.github/workflows/pythonpackage.yml @@ -44,16 +44,15 @@ jobs: - name: Install dependencies shell: bash - env: - SPECIAL: ${{ matrix.special }} run: | - if [[ $SPECIAL == '; no-optional' ]]; then - pip install -e .[test_no_optional] - elif [[ $SPECIAL == '; pre-release' ]]; then - pip install --pre -e .[test] --upgrade --force-reinstall - else - pip install -e .[test] - fi + case "${{ matrix.special }}" in + "; no-optional") + pip install -e .[test_no_optional] ;; + "; pre-release") + pip install --pre -e .[test] --upgrade --force-reinstall ;; + *) + pip install -e .[test] ;; + esac - name: Python info run: | @@ -65,14 +64,13 @@ jobs: - name: Test with pytest shell: bash - env: - SPECIAL: ${{ matrix.special }} run: | - if [[ $SPECIAL == '; no-optional' ]]; then - pytest --mypy - else - pytest --mypy --doctest-modules - fi + case "${{ matrix.special }}" in + "; no-optional") + pytest --mypy ;; + *) + pytest --mypy --doctest-modules ;; + esac - name: Run codecov uses: codecov/codecov-action@v2 diff --git a/CHANGELOG.rst b/CHANGELOG.rst index cd62be9..6cb3858 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -6,6 +6,14 @@ All notable changes to this project will be documented in this file. This project adheres to `Semantic Versioning `_. +2.3.0 +***** +* Added ``UserMapping`` entry points for the IPython key completioner + and pretty printer. +* Added a decorator for applying the effect of ``warnings.filterwarnings`` + to the decorated function. + + 2.2.0 ***** * Added a decorator for constructing positional-only signatures. diff --git a/README.rst b/README.rst index 1a8bdfc..f34daa6 100644 --- a/README.rst +++ b/README.rst @@ -22,7 +22,7 @@ ################ -Nano-Utils 2.2.0 +Nano-Utils 2.3.0 ################ Utility functions used throughout the various nlesc-nano repositories. diff --git a/nanoutils/__version__.py b/nanoutils/__version__.py index b66927e..fff80ed 100644 --- a/nanoutils/__version__.py +++ b/nanoutils/__version__.py @@ -1,3 +1,3 @@ """The **Nano-Utils** version.""" -__version__ = '2.2.0' +__version__ = '2.3.0' diff --git a/nanoutils/_user_dict.py b/nanoutils/_user_dict.py index 4f8b35a..85141dc 100644 --- a/nanoutils/_user_dict.py +++ b/nanoutils/_user_dict.py @@ -26,6 +26,12 @@ from .utils import positional_only from .typing_utils import Protocol, runtime_checkable +if TYPE_CHECKING: + from IPython.lib.pretty import RepresentationPrinter + + class _ReprFunc(Protocol[_KT, _VT]): + def __call__(self, __dct: dict[_KT, _VT], *, width: int) -> str: ... + __all__ = ["UserMapping", "MutableUserMapping", "_DictLike", "_SupportsKeysAndGetItem"] _SENTINEL = object() @@ -62,6 +68,17 @@ def __getitem__(self, __key: _KT) -> _VT_co: ] +def _repr_func(self: UserMapping[_KT, _VT], func: _ReprFunc[_KT, _VT]) -> str: + """Helper function for :meth:`UserMapping.__repr__`.""" + cls = type(self) + dict_repr = func(self._dict, width=76) + if len(dict_repr) <= 76: + return f"{cls.__name__}({dict_repr})" + else: + dict_repr2 = textwrap.indent(dict_repr[1:-1], 3 * " ") + return f"{cls.__name__}({{\n {dict_repr2},\n}})" + + class UserMapping(Mapping[_KT, _VT_co]): """Base class for user-defined immutable mappings.""" @@ -104,14 +121,21 @@ def copy(self: _ST1) -> _ST1: @reprlib.recursive_repr(fillvalue='...') def __repr__(self) -> str: """Implement :func:`repr(self) `.""" - cls = type(self) - width = 80 - 2 - len(cls.__name__) - dct_repr = pformat(self._dict, width=width) - if len(dct_repr) <= width: - return f"{cls.__name__}({dct_repr})" - else: - dct_repr2 = textwrap.indent(dct_repr[1:-1], 3 * " ") - return f"{cls.__name__}({{\n {dct_repr2},\n}})" + return _repr_func(self, func=pformat) + + def _repr_pretty_(self, p: RepresentationPrinter, cycle: bool) -> None: + """Entry point for the :mod:`IPython ` pretty printer.""" + if cycle: + p.text(f"{type(self).__name__}(...)") + return None + + from IPython.lib.pretty import pretty + string = _repr_func(self, func=lambda dct, width: pretty(dct, max_width=width)) + p.text(string) + + def _ipython_key_completions_(self) -> KeysView[_KT]: + """Entry point for the IPython key completioner.""" + return self.keys() def __hash__(self) -> int: """Implement :func:`hash(self) `. diff --git a/nanoutils/utils.py b/nanoutils/utils.py index e035463..7e412e5 100644 --- a/nanoutils/utils.py +++ b/nanoutils/utils.py @@ -18,6 +18,7 @@ import warnings import importlib import inspect +import functools from types import ModuleType from functools import wraps from typing import ( @@ -34,9 +35,10 @@ MutableMapping, Collection, cast, - overload + overload, ) +from .typing_utils import Literal from .empty import EMPTY_CONTAINER __all__ = [ @@ -58,6 +60,7 @@ 'positional_only', 'UserMapping', 'MutableUserMapping', + 'warning_filter', ] _T = TypeVar('_T') @@ -737,6 +740,75 @@ def positional_only(func: _FT) -> _FT: return func +def warning_filter( + action: Literal["default", "error", "ignore", "always", "module", "once"], + message: str = "", + category: type[Warning] = Warning, + module: str = "", + lineno: int = 0, + append: bool = False, +) -> Callable[[_FT], _FT]: + """A decorator for wrapping function calls with :func:`warnings.filterwarnings`. + + Examples + -------- + .. code-block:: python + + >>> from nanoutils import warning_filter + >>> import warnings + + >>> @warning_filter("error", category=UserWarning) + ... def func(): + ... warnings.warn("test", UserWarning) + + >>> func() + Traceback (most recent call last): + ... + UserWarning: test + + Parameters + ---------- + action : :class:`str` + One of the following strings: + + * ``"default"``: Print the first occurrence of matching warnings for each location (module + line number) where the warning is issued + * ``"error"``: Turn matching warnings into exceptions + * ``"ignore"``: Never print matching warnings + * ``"always"``: Always print matching warnings + * ``"module"``: Print the first occurrence of matching warnings for each module where the warning is issued (regardless of line number) + * ``"once"``: Print only the first occurrence of matching warnings, regardless of location + + message : :class:`str`, optional + A string containing a regular expression that the start of the warning message must match. + The expression is compiled to always be case-insensitive. + category : :class:`type[Warning] ` + The to-be affected :class:`Warning` (sub-)class. + module : :class:`str`, optional + A string containing a regular expression that the module name must match. + The expression is compiled to be case-sensitive. + lineno : :class:`int` + An integer that the line number where the warning occurred must match, + or 0 to match all line numbers. + append : :class:`bool` + Whether the warning entry is inserted at the end. + + See Also + -------- + :func:`warnings.filterwarnings` : + Insert a simple entry into the list of warnings filters (at the front). + + """ + def decorator(func: _FT) -> _FT: + @functools.wraps(func) + def wrapper(*args, **kwargs): + with warnings.catch_warnings(): + warnings.filterwarnings(action, message, category, module, lineno, append) + ret = func(*args, **kwargs) + return ret + return cast(_FT, wrapper) + return decorator + + # Move to the end to reduce the risk of circular imports from ._partial import PartialPrepend from ._set_attr import SetAttr @@ -747,5 +819,5 @@ def positional_only(func: _FT) -> _FT: __doc__ = construct_api_doc( globals(), - decorators={'set_docstring', 'raise_if', 'ignore_if', 'positional_only'}, + decorators={'set_docstring', 'raise_if', 'ignore_if', 'positional_only', 'warning_filter'}, ) diff --git a/setup.py b/setup.py index ebaa4c2..11c26f2 100644 --- a/setup.py +++ b/setup.py @@ -42,6 +42,7 @@ 'pyyaml', 'h5py', 'numpy', + 'ipython', ] tests_require += docs_require tests_require += build_requires diff --git a/tests/test_dtype_mapping.py b/tests/test_dtype_mapping.py index 7b1f39b..91f1978 100644 --- a/tests/test_dtype_mapping.py +++ b/tests/test_dtype_mapping.py @@ -3,7 +3,7 @@ import sys import textwrap from typing import TYPE_CHECKING, no_type_check -from collections.abc import Iterator +from collections.abc import Iterator, Callable import pytest from assertionlib import assertion @@ -18,6 +18,14 @@ import numpy.typing as npt import _pytest +try: + from IPython.lib.pretty import pretty +except ModuleNotFoundError: + IPYTHON: bool = False + pretty = NotImplemented +else: + IPYTHON = True + class BasicMapping: def __init__(self, dct: dict[str, npt.DTypeLike]) -> None: @@ -128,16 +136,21 @@ def test_repr(self, obj: DTypeMapping) -> None: )""").strip() assertion.str_eq(obj, string1, str_converter=repr) - string2 = textwrap.dedent(f""" + string2 = f"{type(obj).__name__}()" + assertion.str_eq(type(obj)(), string2, str_converter=repr) + + @pytest.mark.parametrize("str_func", [ + str, + pytest.param(pretty, marks=pytest.mark.skipif(not IPYTHON, reason="Requires IPython")), + ], ids=["str", "pretty"]) + def test_str(self, obj: DTypeMapping, str_func: Callable[[object], str]) -> None: + string = textwrap.dedent(f""" {type(obj).__name__}( a = int64, b = float64, c = None: @@ -75,9 +83,14 @@ def test_eq(self, obj: UserMapping[str, int]) -> None: def test_getitem(self, obj: UserMapping[str, int], key: str, value: int) -> None: assertion.eq(obj[key], value) - def test_repr(self, obj: UserMapping[str, int]) -> None: + @pytest.mark.parametrize("str_func", [ + str, + repr, + pytest.param(pretty, marks=pytest.mark.skipif(not IPYTHON, reason="Requires IPython")), + ], ids=["str", "repr", "pretty"]) + def test_repr(self, obj: UserMapping[str, int], str_func: Callable[[object], str]) -> None: string1 = f"{type(obj).__name__}({{'a': 0, 'b': 1, 'c': 2}})" - assertion.str_eq(obj, string1) + assertion.str_eq(obj, string1, str_converter=str_func) cls = type(obj) ref2 = cls(zip(string.ascii_lowercase[:12], range(12))) @@ -97,7 +110,12 @@ def test_repr(self, obj: UserMapping[str, int]) -> None: 'l': 11, }}) """).strip() - assertion.str_eq(ref2, string2) + assertion.str_eq(ref2, string2, str_converter=str_func) + + @pytest.mark.skipif(not IPYTHON, reason="Rquires IPython") + def test_pretty_repr(self, obj: UserMapping[str, int]) -> None: + string1 = f"{type(obj).__name__}({{'a': 0, 'b': 1, 'c': 2}})" + assertion.str_eq(obj, string1, str_converter=pretty) def test_hash(self, obj: UserMapping[str, int]) -> None: if isinstance(obj, MutableUserMapping): @@ -134,6 +152,10 @@ def test_fromkeys(self, obj: UserMapping[str, int]) -> None: assertion.isinstance(dct, cls) assertion.eq(dct.keys(), obj.keys()) + def test_key_completions(self, obj: UserMapping[str, int]) -> None: + assertion.isinstance(obj._ipython_key_completions_(), KeysView) + assertion.eq(obj._ipython_key_completions_(), obj.keys()) + def test_get(self, obj: UserMapping[str, int]) -> None: assertion.eq(obj.get("a"), 0) assertion.is_(obj.get("d"), None)