Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: write cache to user cache directory #407

Merged
merged 5 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 41 additions & 75 deletions src/ampform/sympy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@

from __future__ import annotations

import functools
import hashlib
import itertools
import logging
import os
Expand All @@ -22,15 +20,15 @@
import sys
import warnings
from abc import abstractmethod
from os.path import abspath, dirname, expanduser
from textwrap import dedent
from os.path import abspath, dirname
from typing import TYPE_CHECKING, Iterable, Sequence, SupportsFloat

import sympy as sp
from sympy.printing.conventions import split_super_sub
from sympy.printing.precedence import PRECEDENCE
from sympy.printing.pycode import _unpack_integral_limits # noqa: PLC2701

from ._cache import get_readable_hash, get_system_cache_directory
from ._decorator import (
ExprClass, # noqa: F401 # pyright: ignore[reportUnusedImport]
SymPyAssumptions, # noqa: F401 # pyright: ignore[reportUnusedImport]
Expand Down Expand Up @@ -290,77 +288,6 @@ def determine_indices(symbol: sp.Basic) -> list[int]:
return list(indices)


def perform_cached_doit(
    unevaluated_expr: sp.Expr, directory: str | None = None
) -> sp.Expr:
    """Perform :meth:`~sympy.core.basic.Basic.doit` and cache the result to disk.

    The cached result is fetched from disk if the hash of the original expression is the
    same as the hash embedded in the filename.

    Args:
        unevaluated_expr: A `sympy.Expr <sympy.core.expr.Expr>` on which to call
            :meth:`~sympy.core.basic.Basic.doit`.
        directory: The directory in which to cache the result. If `None`, the cache
            directory will be put under the home directory.

    Returns:
        The result of :code:`unevaluated_expr.doit()`, either unpickled from the
        cache file or freshly computed and then written to the cache file.

    .. tip:: For a faster cache, set `PYTHONHASHSEED
        <https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED>`_ to a
        fixed value.
    """
    if directory is None:
        home_directory = expanduser("~")
        directory = abspath(f"{home_directory}/.sympy-cache")
    h = get_readable_hash(unevaluated_expr)
    filename = f"{directory}/{h}.pkl"
    os.makedirs(dirname(filename), exist_ok=True)
    if os.path.exists(filename):
        # NOTE(review): pickle.load() executes whatever the file contains, so the
        # cache directory must only ever be writable by trusted users.
        with open(filename, "rb") as f:
            return pickle.load(f)  # noqa: S301
    _LOGGER.warning(
        f"Cached expression file {filename} not found, performing doit()..."
    )
    unfolded_expr = unevaluated_expr.doit()
    # Dump to a process-specific temporary file and move it into place afterwards,
    # so that an interrupted or failed dump never leaves a truncated cache file
    # behind that later calls would try (and fail) to unpickle.
    temporary_filename = f"{filename}.{os.getpid()}.tmp"
    try:
        with open(temporary_filename, "wb") as f:
            pickle.dump(unfolded_expr, f)
        os.replace(temporary_filename, filename)
    finally:
        # Only still present if dumping or replacing failed.
        if os.path.exists(temporary_filename):
            os.remove(temporary_filename)
    return unfolded_expr


def get_readable_hash(obj) -> str:
    """Get a human-readable hash string for a Python object.

    If :code:`PYTHONHASHSEED` holds an integer value, the built-in :func:`hash` is
    used and the seed is embedded in the returned string. Otherwise, the object is
    converted to bytes and hashed with :func:`hashlib.sha256`.
    """
    seed = _get_python_hash_seed()
    if seed is None:
        digest = hashlib.sha256(_to_bytes(obj))
        return digest.hexdigest()
    # The "+" format spec always prints the sign of the hash value, which
    # separates it visually from the seed.
    return f"pythonhashseed-{seed}{hash(obj):+}"


def _to_bytes(obj) -> bytes:
    """Convert an object to the bytes that are fed to the SHA-256 hash."""
    if not isinstance(obj, sp.Expr):
        return pickle.dumps(obj)
    # Using the str printer is slower and not necessarily unique,
    # but pickle.dumps() does not always result in the same bytes stream.
    _warn_about_unsafe_hash()
    printed_expr = str(obj)
    return printed_expr.encode()


def _get_python_hash_seed() -> int | None:
    """Return the value of :code:`PYTHONHASHSEED` as an integer, if it is one.

    Returns:
        The integer value of the :code:`PYTHONHASHSEED` environment variable, or
        `None` if the variable is unset or is not a plain non-negative integer
        (for instance :code:`"random"`).
    """
    python_hash_seed = os.environ.get("PYTHONHASHSEED", "")
    # str.isdecimal() only accepts characters that int() can parse. str.isdigit()
    # also accepts characters such as superscript digits, on which int() raises.
    if python_hash_seed.isdecimal():
        return int(python_hash_seed)
    return None


@functools.lru_cache(maxsize=None)  # warn once
def _warn_about_unsafe_hash():
    """Log a warning that :code:`PYTHONHASHSEED` has not been set.

    The function takes no arguments, so the cache decorator makes every call after
    the first one a no-op until the cache is cleared.
    """
    multiline_text = """
    PYTHONHASHSEED has not been set. For faster and safer hashing of SymPy expressions,
    set the PYTHONHASHSEED environment variable to a fixed value and rerun the program.
    See https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED
    """
    # Collapse the indented multi-line text into a single line.
    single_line = dedent(multiline_text).replace("\n", " ")
    _LOGGER.warning(single_line.strip())


class UnevaluatableIntegral(sp.Integral):
abs_tolerance = 1e-5
rel_tolerance = 1e-5
Expand Down Expand Up @@ -405,3 +332,42 @@ def _warn_if_scipy_not_installed() -> None:
" install ampform[scipy]'",
stacklevel=1,
)


def perform_cached_doit(
    unevaluated_expr: sp.Expr, cache_directory: str | None = None
) -> sp.Expr:
    """Perform :meth:`~sympy.core.basic.Basic.doit` and cache the result to disk.

    The cached result is fetched from disk if the hash of the original expression is the
    same as the hash embedded in the filename (see :func:`.get_readable_hash`).

    Args:
        unevaluated_expr: A `sympy.Expr <sympy.core.expr.Expr>` on which to call
            :meth:`~sympy.core.basic.Basic.doit`.
        cache_directory: The directory in which to cache the result. Defaults to
            :file:`ampform` under the system cache directory (see
            :func:`.get_system_cache_directory`).

    Returns:
        The result of :code:`unevaluated_expr.doit()`, either unpickled from the
        cache file or freshly computed and then written to the cache file.

    .. tip:: For a faster cache, set `PYTHONHASHSEED
        <https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED>`_ to a
        fixed value.

    .. automodule:: ampform.sympy._cache
    """
    if cache_directory is None:
        system_cache_dir = get_system_cache_directory()
        cache_directory = abspath(f"{system_cache_dir}/ampform")
    h = get_readable_hash(unevaluated_expr)
    filename = f"{cache_directory}/{h}.pkl"
    os.makedirs(dirname(filename), exist_ok=True)
    if os.path.exists(filename):
        # NOTE(review): pickle.load() executes whatever the file contains, so the
        # cache directory must only ever be writable by trusted users.
        with open(filename, "rb") as f:
            return pickle.load(f)  # noqa: S301
    _LOGGER.warning(
        f"Cached expression file {filename} not found, performing doit()..."
    )
    unfolded_expr = unevaluated_expr.doit()
    # Dump to a process-specific temporary file and move it into place afterwards,
    # so that an interrupted or failed dump never leaves a truncated cache file
    # behind that later calls would try (and fail) to unpickle.
    temporary_filename = f"{filename}.{os.getpid()}.tmp"
    try:
        with open(temporary_filename, "wb") as f:
            pickle.dump(unfolded_expr, f)
        os.replace(temporary_filename, filename)
    finally:
        # Only still present if dumping or replacing failed.
        if os.path.exists(temporary_filename):
            os.remove(temporary_filename)
    return unfolded_expr
87 changes: 87 additions & 0 deletions src/ampform/sympy/_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Helper functions for :func:`.perform_cached_doit`."""

from __future__ import annotations

import functools
import hashlib
import logging
import os
import pickle # noqa: S403
import sys
from textwrap import dedent

import sympy as sp

_LOGGER = logging.getLogger(__name__)


def get_system_cache_directory() -> str:
    r"""Return the system cache directory for the current platform.

    On Linux, the :code:`XDG_CACHE_HOME` environment variable takes precedence and on
    Windows the :code:`LocalAppData` environment variable does. A variable that is
    unset or empty is ignored, in which case a default under the home directory
    is returned.

    >>> import sys
    >>> if sys.platform.startswith("darwin"):
    ...     assert get_system_cache_directory().endswith("/Library/Caches")
    >>> if sys.platform.startswith("linux"):
    ...     assert get_system_cache_directory().endswith("/.cache")
    >>> if sys.platform.startswith("win"):
    ...     assert get_system_cache_directory().endswith(R"\AppData\Local")
    """
    if sys.platform.startswith("linux"):
        # The XDG Base Directory Specification treats an empty value like an unset
        # one, so a truthiness check (rather than "is not None") is intended here.
        cache_directory = os.getenv("XDG_CACHE_HOME")
        if cache_directory:
            return cache_directory
    if sys.platform.startswith("darwin"):  # macos
        return os.path.expanduser("~/Library/Caches")
    if sys.platform.startswith("win"):
        cache_directory = os.getenv("LocalAppData")  # noqa: SIM112
        if cache_directory:
            return cache_directory
        # Join the components so that the fallback uses native path separators.
        return os.path.join(os.path.expanduser("~"), "AppData", "Local")
    return os.path.expanduser("~/.cache")


def get_readable_hash(obj, ignore_hash_seed: bool = False) -> str:
    """Get a human-readable hash of any hashable Python object.

    Hashing is fastest when `PYTHONHASHSEED
    <https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED>`_ is set,
    because the built-in :func:`hash` can then be used. If it is not set, the hash
    is computed with :func:`hashlib.sha256` instead.

    Args:
        obj: Any hashable object, mutable or immutable, to be hashed.
        ignore_hash_seed: Ignore the :code:`PYTHONHASHSEED` environment variable. If
            :code:`True`, the hash is always computed with :func:`hashlib.sha256`.

    Returns:
        Either a SHA-256 hex digest, or a string of the form
        :code:`pythonhashseed-<seed><signed hash>`.
    """
    seed = _get_python_hash_seed()
    if not ignore_hash_seed and seed is not None:
        # The "+" format spec always prints the sign of the hash value, which
        # separates it visually from the seed.
        return f"pythonhashseed-{seed}{hash(obj):+}"
    digest = hashlib.sha256(_to_bytes(obj))
    return digest.hexdigest()


def _to_bytes(obj) -> bytes:
    """Convert an object to the bytes that are fed to the SHA-256 hash."""
    if not isinstance(obj, sp.Expr):
        return pickle.dumps(obj)
    # Using the str printer is slower and not necessarily unique,
    # but pickle.dumps() does not always result in the same bytes stream.
    _warn_about_unsafe_hash()
    printed_expr = str(obj)
    return printed_expr.encode()


def _get_python_hash_seed() -> int | None:
    """Return the value of :code:`PYTHONHASHSEED` as an integer, if it is one.

    Returns:
        The integer value of the :code:`PYTHONHASHSEED` environment variable, or
        `None` if the variable is unset or is not a plain non-negative integer
        (for instance :code:`"random"`).
    """
    python_hash_seed = os.environ.get("PYTHONHASHSEED", "")
    # str.isdecimal() only accepts characters that int() can parse. str.isdigit()
    # also accepts characters such as superscript digits, on which int() raises.
    if python_hash_seed.isdecimal():
        return int(python_hash_seed)
    return None


@functools.lru_cache(maxsize=None)  # warn once
def _warn_about_unsafe_hash():
    """Log a warning that :code:`PYTHONHASHSEED` has not been set.

    The function takes no arguments, so the cache decorator makes every call after
    the first one a no-op until the cache is cleared.
    """
    multiline_text = """
    PYTHONHASHSEED has not been set. For faster and safer hashing of SymPy expressions,
    set the PYTHONHASHSEED environment variable to a fixed value and rerun the program.
    See https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED
    """
    # Collapse the indented multi-line text into a single line.
    single_line = dedent(multiline_text).replace("\n", " ")
    _LOGGER.warning(single_line.strip())
2 changes: 1 addition & 1 deletion tests/sympy/test_caching.py → tests/sympy/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import sympy as sp

from ampform.dynamics import EnergyDependentWidth
from ampform.sympy import _warn_about_unsafe_hash, get_readable_hash
from ampform.sympy._cache import _warn_about_unsafe_hash, get_readable_hash

if TYPE_CHECKING:
from _pytest.logging import LogCaptureFixture
Expand Down
Loading