From 8da81c839fb9aec414d7215563570116d2b48f19 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 8 Mar 2024 14:13:49 +0100 Subject: [PATCH 1/5] ENH: write cache to system cache directory --- src/ampform/sympy/__init__.py | 46 ++++++++++++++++++++++++++++------- 1 file changed, 37 insertions(+), 9 deletions(-) diff --git a/src/ampform/sympy/__init__.py b/src/ampform/sympy/__init__.py index a2f983dec..799e9210c 100644 --- a/src/ampform/sympy/__init__.py +++ b/src/ampform/sympy/__init__.py @@ -22,7 +22,7 @@ import sys import warnings from abc import abstractmethod -from os.path import abspath, dirname, expanduser +from os.path import abspath, dirname from textwrap import dedent from typing import TYPE_CHECKING, Iterable, Sequence, SupportsFloat @@ -291,9 +291,9 @@ def determine_indices(symbol: sp.Basic) -> list[int]: def perform_cached_doit( - unevaluated_expr: sp.Expr, directory: str | None = None + unevaluated_expr: sp.Expr, cache_directory: str | None = None ) -> sp.Expr: - """Perform :meth:`~sympy.core.basic.Basic.doit` cache the result to disk. + """Perform :meth:`~sympy.core.basic.Basic.doit` and cache the result to disk. The cached result is fetched from disk if the hash of the original expression is the same as the hash embedded in the filename. @@ -301,18 +301,21 @@ def perform_cached_doit( Args: unevaluated_expr: A `sympy.Expr ` on which to call :meth:`~sympy.core.basic.Basic.doit`. - directory: The directory in which to cache the result. If `None`, the cache - directory will be put under the home directory. + cache_directory: The directory in which to cache the result. Defaults to + :file:`ampform` under the system cache directory (see + :func:`_get_system_cache_directory`). .. tip:: For a faster cache, set `PYTHONHASHSEED `_ to a fixed value. + + .. autofunction:: _get_system_cache_directory """ - if directory is None: - home_directory = expanduser("~") - directory = abspath(f"{home_directory}/.sympy-cache") + if cache_directory is None: + system_cache_dir = _get_system_cache_directory() + cache_directory = abspath(f"{system_cache_dir}/ampform") h = get_readable_hash(unevaluated_expr) - filename = f"{directory}/{h}.pkl" + filename = f"{cache_directory}/{h}.pkl" os.makedirs(dirname(filename), exist_ok=True) if os.path.exists(filename): with open(filename, "rb") as f: @@ -326,6 +329,31 @@ def perform_cached_doit( return unfolded_expr +def _get_system_cache_directory() -> str: + r"""Return the system cache directory for the current platform. + + >>> import sys, pytest + >>> if sys.platform.startswith("darwin"): + ... assert _get_system_cache_directory().endswith("/Library/Caches") + >>> if sys.platform.startswith("linux"): + ... assert _get_system_cache_directory().endswith("/.cache") + >>> if sys.platform.startswith("win"): + ... assert _get_system_cache_directory().endswith(R"\AppData\Local") + """ + if sys.platform.startswith("linux"): + cache_directory = os.getenv("XDG_CACHE_HOME") + if cache_directory is not None: + return cache_directory + if sys.platform.startswith("darwin"): # macos + return os.path.expanduser("~/Library/Caches") + if sys.platform.startswith("win"): + cache_directory = os.getenv("LocalAppData") # noqa: SIM112 + if cache_directory is not None: + return cache_directory + return os.path.expanduser("~/AppData/Local") + return os.path.expanduser("~/.cache") + + def get_readable_hash(obj) -> str: python_hash_seed = _get_python_hash_seed() if python_hash_seed is not None: From f3f2034cf7a57d135861d93e9fe86ef527922050 Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 8 Mar 2024 14:13:50 +0100 Subject: [PATCH 2/5] BREAK: make `_get_readable_hash()` private and document --- src/ampform/sympy/__init__.py | 11 +++++++++-- tests/sympy/test_caching.py | 8 ++++---- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/ampform/sympy/__init__.py b/src/ampform/sympy/__init__.py index 799e9210c..8d17a5bf0 100644 --- a/src/ampform/sympy/__init__.py +++ b/src/ampform/sympy/__init__.py @@ -310,11 +310,12 @@ def perform_cached_doit( fixed value. .. autofunction:: _get_system_cache_directory + .. autofunction:: _get_readable_hash """ if cache_directory is None: system_cache_dir = _get_system_cache_directory() cache_directory = abspath(f"{system_cache_dir}/ampform") - h = get_readable_hash(unevaluated_expr) + h = _get_readable_hash(unevaluated_expr) filename = f"{cache_directory}/{h}.pkl" os.makedirs(dirname(filename), exist_ok=True) if os.path.exists(filename): @@ -354,7 +355,13 @@ def _get_system_cache_directory() -> str: return os.path.expanduser("~/.cache") -def get_readable_hash(obj) -> str: +def _get_readable_hash(obj) -> str: + """Get a human-readable hash of any Python object. + + The algorithm is fastest if `PYTHONHASHSEED + `_ is set. + Otherwise, it falls back to computing the hash with :func:`hashlib.sha256()`. + """ python_hash_seed = _get_python_hash_seed() if python_hash_seed is not None: return f"pythonhashseed-{python_hash_seed}{hash(obj):+}" diff --git a/tests/sympy/test_caching.py b/tests/sympy/test_caching.py index 292996bfd..036744ab1 100644 --- a/tests/sympy/test_caching.py +++ b/tests/sympy/test_caching.py @@ -9,7 +9,7 @@ import sympy as sp from ampform.dynamics import EnergyDependentWidth -from ampform.sympy import _warn_about_unsafe_hash, get_readable_hash +from ampform.sympy import _get_readable_hash, _warn_about_unsafe_hash if TYPE_CHECKING: from _pytest.logging import LogCaptureFixture @@ -61,7 +61,7 @@ def test_get_readable_hash(assumptions, expected_hashes, caplog: LogCaptureFixtu caplog.set_level(logging.WARNING) x, y = sp.symbols("x y", **assumptions) expr = x**2 + y - h_str = get_readable_hash(expr) + h_str = _get_readable_hash(expr) python_hash_seed = os.environ.get("PYTHONHASHSEED") if python_hash_seed is None: assert h_str[:7] == "bbc9833" @@ -88,7 +88,7 @@ def test_get_readable_hash_energy_dependent_width(): angular_momentum=angular_momentum, meson_radius=d, ) - h = get_readable_hash(expr) + h = _get_readable_hash(expr) python_hash_seed = os.environ.get("PYTHONHASHSEED") if python_hash_seed is None: pytest.skip("PYTHONHASHSEED has not been set") @@ -124,4 +124,4 @@ def test_get_readable_hash_large(amplitude_model: tuple[str, HelicityModel]): "canonical-helicity": "pythonhashseed-0-8505502895987205495", "helicity": "pythonhashseed-0-1430245260241162669", }[formalism] - assert get_readable_hash(model.expression) == expected_hash + assert _get_readable_hash(model.expression) == expected_hash From 4af3e31767a6c87d71ddcb98f91b823b9bc4a38f Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 8 Mar 2024 14:13:51 +0100 Subject: [PATCH 3/5] ENH: add `ignore_hash_seed` flag --- src/ampform/sympy/__init__.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/ampform/sympy/__init__.py b/src/ampform/sympy/__init__.py index 8d17a5bf0..cc4371103 100644 --- a/src/ampform/sympy/__init__.py +++ b/src/ampform/sympy/__init__.py @@ -355,18 +355,24 @@ def _get_system_cache_directory() -> str: return os.path.expanduser("~/.cache") -def _get_readable_hash(obj) -> str: - """Get a human-readable hash of any Python object. +def _get_readable_hash(obj, ignore_hash_seed: bool = False) -> str: + """Get a human-readable hash of any hashable Python object. The algorithm is fastest if `PYTHONHASHSEED `_ is set. Otherwise, it falls back to computing the hash with :func:`hashlib.sha256()`. + + Args: + obj: Any hashable object, mutable or immutable, to be hashed. + ignore_hash_seed: Ignore the :code:`PYTHONHASHSEED` environment variable. If + :code:`True`, the hash seed is ignored and the hash is computed with + :func:`hashlib.sha256`. """ python_hash_seed = _get_python_hash_seed() - if python_hash_seed is not None: - return f"pythonhashseed-{python_hash_seed}{hash(obj):+}" - b = _to_bytes(obj) - return hashlib.sha256(b).hexdigest() + if ignore_hash_seed or python_hash_seed is None: + b = _to_bytes(obj) + return hashlib.sha256(b).hexdigest() + return f"pythonhashseed-{python_hash_seed}{hash(obj):+}" def _to_bytes(obj) -> bytes: From fde39299778179b877894bac1a7fb95911aec30a Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 8 Mar 2024 14:13:52 +0100 Subject: [PATCH 4/5] MAINT: move cache functions to end of module --- src/ampform/sympy/__init__.py | 92 +++++++++++++++++------------------ 1 file changed, 46 insertions(+), 46 deletions(-) diff --git a/src/ampform/sympy/__init__.py b/src/ampform/sympy/__init__.py index cc4371103..76bf00a2e 100644 --- a/src/ampform/sympy/__init__.py +++ b/src/ampform/sympy/__init__.py @@ -290,6 +290,52 @@ def determine_indices(symbol: sp.Basic) -> list[int]: return list(indices) +class UnevaluatableIntegral(sp.Integral): + abs_tolerance = 1e-5 + rel_tolerance = 1e-5 + limit = 50 + dummify = True + + @override + def doit(self, **hints): + args = [arg.doit(**hints) for arg in self.args] + return self.func(*args) + + @override + def _numpycode(self, printer, *args): + _warn_if_scipy_not_installed() + integration_vars, limits = _unpack_integral_limits(self) + if len(limits) != 1 or len(integration_vars) != 1: + msg = f"Cannot handle {len(limits)}-dimensional integrals" + raise ValueError(msg) + x = integration_vars[0] + a, b = limits[0] + expr = self.args[0] + if self.dummify: + dummy = sp.Dummy() + expr = expr.xreplace({x: dummy}) + x = dummy + integrate_func = "quad_vec" + printer.module_imports["scipy.integrate"].add(integrate_func) + return ( + f"{integrate_func}(lambda {printer._print(x)}: {printer._print(expr)}," + f" {printer._print(a)}, {printer._print(b)}," + f" epsabs={self.abs_tolerance}, epsrel={self.abs_tolerance}," + f" limit={self.limit})[0]" + ) + + +def _warn_if_scipy_not_installed() -> None: + try: + import scipy # noqa: F401, PLC0415 # pyright: ignore[reportUnusedImport, reportMissingImports] + except ImportError: + warnings.warn( + "Scipy is not installed. Install with 'pip install scipy' or with 'pip" + " install ampform[scipy]'", + stacklevel=1, + ) + + def perform_cached_doit( unevaluated_expr: sp.Expr, cache_directory: str | None = None ) -> sp.Expr: @@ -400,49 +446,3 @@ def _warn_about_unsafe_hash(): """ message = dedent(message).replace("\n", " ").strip() _LOGGER.warning(message) - - -class UnevaluatableIntegral(sp.Integral): - abs_tolerance = 1e-5 - rel_tolerance = 1e-5 - limit = 50 - dummify = True - - @override - def doit(self, **hints): - args = [arg.doit(**hints) for arg in self.args] - return self.func(*args) - - @override - def _numpycode(self, printer, *args): - _warn_if_scipy_not_installed() - integration_vars, limits = _unpack_integral_limits(self) - if len(limits) != 1 or len(integration_vars) != 1: - msg = f"Cannot handle {len(limits)}-dimensional integrals" - raise ValueError(msg) - x = integration_vars[0] - a, b = limits[0] - expr = self.args[0] - if self.dummify: - dummy = sp.Dummy() - expr = expr.xreplace({x: dummy}) - x = dummy - integrate_func = "quad_vec" - printer.module_imports["scipy.integrate"].add(integrate_func) - return ( - f"{integrate_func}(lambda {printer._print(x)}: {printer._print(expr)}," - f" {printer._print(a)}, {printer._print(b)}," - f" epsabs={self.abs_tolerance}, epsrel={self.abs_tolerance}," - f" limit={self.limit})[0]" - ) - - -def _warn_if_scipy_not_installed() -> None: - try: - import scipy # noqa: F401, PLC0415 # pyright: ignore[reportUnusedImport, reportMissingImports] - except ImportError: - warnings.warn( - "Scipy is not installed. Install with 'pip install scipy' or with 'pip" - " install ampform[scipy]'", - stacklevel=1, - ) From 55a15eb472ea67a715ed6be6ecb13ec3c51ba63c Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Fri, 8 Mar 2024 14:13:53 +0100 Subject: [PATCH 5/5] MAINT: move cache helper functions to `sympy._cache` --- src/ampform/sympy/__init__.py | 87 ++----------------- src/ampform/sympy/_cache.py | 87 +++++++++++++++++++ .../sympy/{test_caching.py => test_cache.py} | 8 +- 3 files changed, 97 insertions(+), 85 deletions(-) create mode 100644 src/ampform/sympy/_cache.py rename tests/sympy/{test_caching.py => test_cache.py} (95%) diff --git a/src/ampform/sympy/__init__.py b/src/ampform/sympy/__init__.py index 76bf00a2e..babfe0c8a 100644 --- a/src/ampform/sympy/__init__.py +++ b/src/ampform/sympy/__init__.py @@ -12,8 +12,6 @@ from __future__ import annotations -import functools -import hashlib import itertools import logging import os @@ -23,7 +21,6 @@ import warnings from abc import abstractmethod from os.path import abspath, dirname -from textwrap import dedent from typing import TYPE_CHECKING, Iterable, Sequence, SupportsFloat import sympy as sp @@ -31,6 +28,7 @@ from sympy.printing.precedence import PRECEDENCE from sympy.printing.pycode import _unpack_integral_limits # noqa: PLC2701 +from ._cache import get_readable_hash, get_system_cache_directory from ._decorator import ( ExprClass, # noqa: F401 # pyright: ignore[reportUnusedImport] SymPyAssumptions, # noqa: F401 # pyright: ignore[reportUnusedImport] @@ -342,26 +340,25 @@ def perform_cached_doit( """Perform :meth:`~sympy.core.basic.Basic.doit` and cache the result to disk. The cached result is fetched from disk if the hash of the original expression is the - same as the hash embedded in the filename. + same as the hash embedded in the filename (see :func:`.get_readable_hash`). Args: unevaluated_expr: A `sympy.Expr ` on which to call :meth:`~sympy.core.basic.Basic.doit`. cache_directory: The directory in which to cache the result. Defaults to :file:`ampform` under the system cache directory (see - :func:`_get_system_cache_directory`). + :func:`.get_system_cache_directory`). .. tip:: For a faster cache, set `PYTHONHASHSEED `_ to a fixed value. - .. autofunction:: _get_system_cache_directory - .. autofunction:: _get_readable_hash + .. automodule:: ampform.sympy._cache """ if cache_directory is None: - system_cache_dir = _get_system_cache_directory() + system_cache_dir = get_system_cache_directory() cache_directory = abspath(f"{system_cache_dir}/ampform") - h = _get_readable_hash(unevaluated_expr) + h = get_readable_hash(unevaluated_expr) filename = f"{cache_directory}/{h}.pkl" os.makedirs(dirname(filename), exist_ok=True) if os.path.exists(filename): @@ -374,75 +371,3 @@ def perform_cached_doit( with open(filename, "wb") as f: pickle.dump(unfolded_expr, f) return unfolded_expr - - -def _get_system_cache_directory() -> str: - r"""Return the system cache directory for the current platform. - - >>> import sys, pytest - >>> if sys.platform.startswith("darwin"): - ... assert _get_system_cache_directory().endswith("/Library/Caches") - >>> if sys.platform.startswith("linux"): - ... assert _get_system_cache_directory().endswith("/.cache") - >>> if sys.platform.startswith("win"): - ... assert _get_system_cache_directory().endswith(R"\AppData\Local") - """ - if sys.platform.startswith("linux"): - cache_directory = os.getenv("XDG_CACHE_HOME") - if cache_directory is not None: - return cache_directory - if sys.platform.startswith("darwin"): # macos - return os.path.expanduser("~/Library/Caches") - if sys.platform.startswith("win"): - cache_directory = os.getenv("LocalAppData") # noqa: SIM112 - if cache_directory is not None: - return cache_directory - return os.path.expanduser("~/AppData/Local") - return os.path.expanduser("~/.cache") - - -def _get_readable_hash(obj, ignore_hash_seed: bool = False) -> str: - """Get a human-readable hash of any hashable Python object. - - The algorithm is fastest if `PYTHONHASHSEED - `_ is set. - Otherwise, it falls back to computing the hash with :func:`hashlib.sha256()`. - - Args: - obj: Any hashable object, mutable or immutable, to be hashed. - ignore_hash_seed: Ignore the :code:`PYTHONHASHSEED` environment variable. If - :code:`True`, the hash seed is ignored and the hash is computed with - :func:`hashlib.sha256`. - """ - python_hash_seed = _get_python_hash_seed() - if ignore_hash_seed or python_hash_seed is None: - b = _to_bytes(obj) - return hashlib.sha256(b).hexdigest() - return f"pythonhashseed-{python_hash_seed}{hash(obj):+}" - - -def _to_bytes(obj) -> bytes: - if isinstance(obj, sp.Expr): - # Using the str printer is slower and not necessarily unique, - # but pickle.dumps() does not always result in the same bytes stream. - _warn_about_unsafe_hash() - return str(obj).encode() - return pickle.dumps(obj) - - -def _get_python_hash_seed() -> int | None: - python_hash_seed = os.environ.get("PYTHONHASHSEED", "") - if python_hash_seed is not None and python_hash_seed.isdigit(): - return int(python_hash_seed) - return None - - -@functools.lru_cache(maxsize=None) # warn once -def _warn_about_unsafe_hash(): - message = """ - PYTHONHASHSEED has not been set. For faster and safer hashing of SymPy expressions, - set the PYTHONHASHSEED environment variable to a fixed value and rerun the program. - See https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED - """ - message = dedent(message).replace("\n", " ").strip() - _LOGGER.warning(message) diff --git a/src/ampform/sympy/_cache.py b/src/ampform/sympy/_cache.py new file mode 100644 index 000000000..421f4d89c --- /dev/null +++ b/src/ampform/sympy/_cache.py @@ -0,0 +1,87 @@ +"""Helper functions for :func:`.perform_cached_doit`.""" + +from __future__ import annotations + +import functools +import hashlib +import logging +import os +import pickle # noqa: S403 +import sys +from textwrap import dedent + +import sympy as sp + +_LOGGER = logging.getLogger(__name__) + + +def get_system_cache_directory() -> str: + r"""Return the system cache directory for the current platform. + + >>> import sys, pytest + >>> if sys.platform.startswith("darwin"): + ... assert get_system_cache_directory().endswith("/Library/Caches") + >>> if sys.platform.startswith("linux"): + ... assert get_system_cache_directory().endswith("/.cache") + >>> if sys.platform.startswith("win"): + ... assert get_system_cache_directory().endswith(R"\AppData\Local") + """ + if sys.platform.startswith("linux"): + cache_directory = os.getenv("XDG_CACHE_HOME") + if cache_directory is not None: + return cache_directory + if sys.platform.startswith("darwin"): # macos + return os.path.expanduser("~/Library/Caches") + if sys.platform.startswith("win"): + cache_directory = os.getenv("LocalAppData") # noqa: SIM112 + if cache_directory is not None: + return cache_directory + return os.path.expanduser("~/AppData/Local") + return os.path.expanduser("~/.cache") + + +def get_readable_hash(obj, ignore_hash_seed: bool = False) -> str: + """Get a human-readable hash of any hashable Python object. + + The algorithm is fastest if `PYTHONHASHSEED + `_ is set. + Otherwise, it falls back to computing the hash with :func:`hashlib.sha256()`. + + Args: + obj: Any hashable object, mutable or immutable, to be hashed. + ignore_hash_seed: Ignore the :code:`PYTHONHASHSEED` environment variable. If + :code:`True`, the hash seed is ignored and the hash is computed with + :func:`hashlib.sha256`. + """ + python_hash_seed = _get_python_hash_seed() + if ignore_hash_seed or python_hash_seed is None: + b = _to_bytes(obj) + return hashlib.sha256(b).hexdigest() + return f"pythonhashseed-{python_hash_seed}{hash(obj):+}" + + +def _to_bytes(obj) -> bytes: + if isinstance(obj, sp.Expr): + # Using the str printer is slower and not necessarily unique, + # but pickle.dumps() does not always result in the same bytes stream. + _warn_about_unsafe_hash() + return str(obj).encode() + return pickle.dumps(obj) + + +def _get_python_hash_seed() -> int | None: + python_hash_seed = os.environ.get("PYTHONHASHSEED", "") + if python_hash_seed is not None and python_hash_seed.isdigit(): + return int(python_hash_seed) + return None + + +@functools.lru_cache(maxsize=None) # warn once +def _warn_about_unsafe_hash(): + message = """ + PYTHONHASHSEED has not been set. For faster and safer hashing of SymPy expressions, + set the PYTHONHASHSEED environment variable to a fixed value and rerun the program. + See https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED + """ + message = dedent(message).replace("\n", " ").strip() + _LOGGER.warning(message) diff --git a/tests/sympy/test_caching.py b/tests/sympy/test_cache.py similarity index 95% rename from tests/sympy/test_caching.py rename to tests/sympy/test_cache.py index 036744ab1..41fd50906 100644 --- a/tests/sympy/test_caching.py +++ b/tests/sympy/test_cache.py @@ -9,7 +9,7 @@ import sympy as sp from ampform.dynamics import EnergyDependentWidth -from ampform.sympy import _get_readable_hash, _warn_about_unsafe_hash +from ampform.sympy._cache import _warn_about_unsafe_hash, get_readable_hash if TYPE_CHECKING: from _pytest.logging import LogCaptureFixture @@ -61,7 +61,7 @@ def test_get_readable_hash(assumptions, expected_hashes, caplog: LogCaptureFixtu caplog.set_level(logging.WARNING) x, y = sp.symbols("x y", **assumptions) expr = x**2 + y - h_str = _get_readable_hash(expr) + h_str = get_readable_hash(expr) python_hash_seed = os.environ.get("PYTHONHASHSEED") if python_hash_seed is None: assert h_str[:7] == "bbc9833" @@ -88,7 +88,7 @@ def test_get_readable_hash_energy_dependent_width(): angular_momentum=angular_momentum, meson_radius=d, ) - h = _get_readable_hash(expr) + h = get_readable_hash(expr) python_hash_seed = os.environ.get("PYTHONHASHSEED") if python_hash_seed is None: pytest.skip("PYTHONHASHSEED has not been set") @@ -124,4 +124,4 @@ def test_get_readable_hash_large(amplitude_model: tuple[str, HelicityModel]): "canonical-helicity": "pythonhashseed-0-8505502895987205495", "helicity": "pythonhashseed-0-1430245260241162669", }[formalism] - assert _get_readable_hash(model.expression) == expected_hash + assert get_readable_hash(model.expression) == expected_hash