Skip to content

Commit af33921

Browse files
authored
BREAK: deprecate UnevaluatedExpression templates (#383)
* BREAK: issue deprecation warnings from `deprecated` module * MAINT: move expression classes to `ampform.sympy.deprecated` * MAINT: remove remaining `UnevaluatedExpresssion` calls and related
1 parent 7cd9a32 commit af33921

File tree

7 files changed

+338
-266
lines changed

7 files changed

+338
-266
lines changed

docs/conf.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,30 +49,28 @@
4949
add_module_names = False
5050
api_github_repo = f"{ORGANIZATION}/{REPO_NAME}"
5151
api_target_substitutions: dict[str, str | tuple[str, str]] = {
52-
"T": "TypeVar",
5352
"BuilderReturnType": ("obj", "ampform.dynamics.builder.BuilderReturnType"),
54-
"DecoratedClass": ("obj", "ampform.sympy.DecoratedClass"),
55-
"DecoratedExpr": ("obj", "ampform.sympy.DecoratedExpr"),
56-
"ExprClass": "ampform.sympy.ExprClass",
53+
"DecoratedClass": ("obj", "ampform.sympy.deprecated.DecoratedClass"),
54+
"DecoratedExpr": ("obj", "ampform.sympy.deprecated.DecoratedExpr"),
5755
"FourMomenta": ("obj", "ampform.kinematics.lorentz.FourMomenta"),
5856
"FourMomentumSymbol": ("obj", "ampform.kinematics.lorentz.FourMomentumSymbol"),
5957
"InteractionProperties": "qrules.quantum_numbers.InteractionProperties",
6058
"LatexPrinter": "sympy.printing.printer.Printer",
6159
"Literal[(-1, 1)]": "typing.Literal",
6260
"Literal[-1, 1]": "typing.Literal",
61+
"NumPyPrintable": ("class", "ampform.sympy.NumPyPrintable"),
6362
"NumPyPrinter": "sympy.printing.printer.Printer",
6463
"ParameterValue": ("obj", "ampform.helicity.ParameterValue"),
6564
"Particle": "qrules.particle.Particle",
6665
"ReactionInfo": "qrules.transition.ReactionInfo",
6766
"Slider": ("obj", "symplot.Slider"),
6867
"State": "qrules.transition.State",
6968
"StateTransition": "qrules.transition.StateTransition",
69+
"T": "TypeVar",
7070
"Topology": "qrules.topology.Topology",
7171
"WignerD": "sympy.physics.quantum.spin.WignerD",
7272
"ampform.helicity._T": "typing.TypeVar",
7373
"ampform.sympy._decorator.ExprClass": ("obj", "ampform.sympy.ExprClass"),
74-
"ampform.sympy._decorator.SymPyAssumptions": "ampform.sympy.SymPyAssumptions",
75-
"an object providing a view on D's values": "typing.ValuesView",
7674
"sp.Basic": "sympy.core.basic.Basic",
7775
"sp.Expr": "sympy.core.expr.Expr",
7876
"sp.Float": "sympy.core.numbers.Float",
@@ -289,7 +287,9 @@
289287
nb_output_stderr = "remove"
290288
nitpick_ignore = [
291289
("py:class", "ArraySum"),
290+
("py:class", "ExprClass"),
292291
("py:class", "MatrixMultiplication"),
292+
("py:class", "ampform.sympy._decorator.SymPyAssumptions"),
293293
]
294294
nitpicky = True
295295
primary_domain = "py"

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ filterwarnings = [
249249
"error",
250250
"ignore:.*invalid value encountered in sqrt.*:RuntimeWarning",
251251
"ignore:.*is deprecated and slated for removal in Python 3.14:DeprecationWarning",
252+
"ignore:.*the @ampform.sympy.unevaluated_expression decorator instead( with commutative=True)?:DeprecationWarning",
252253
"ignore:Passing a schema to Validator.iter_errors is deprecated.*:DeprecationWarning",
253254
"ignore:The .* argument to NotebookFile is deprecated.*:pytest.PytestRemovedIn8Warning",
254255
"ignore:The distutils package is deprecated.*:DeprecationWarning",

src/ampform/sympy/__init__.py

Lines changed: 28 additions & 243 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
1111
"""
1212

13-
# cspell:ignore mhash
1413
from __future__ import annotations
1514

1615
import functools
@@ -23,7 +22,7 @@
2322
from abc import abstractmethod
2423
from os.path import abspath, dirname, expanduser
2524
from textwrap import dedent
26-
from typing import TYPE_CHECKING, Callable, Iterable, Sequence, SupportsFloat, TypeVar
25+
from typing import TYPE_CHECKING, Iterable, Sequence, SupportsFloat
2726

2827
import sympy as sp
2928
from sympy.printing.conventions import split_super_sub
@@ -35,6 +34,13 @@
3534
argument, # noqa: F401 # pyright: ignore[reportUnusedImport]
3635
unevaluated, # noqa: F401 # pyright: ignore[reportUnusedImport]
3736
)
37+
from .deprecated import (
38+
UnevaluatedExpression, # noqa: F401 # pyright: ignore[reportUnusedImport]
39+
create_expression, # noqa: F401 # pyright: ignore[reportUnusedImport]
40+
implement_doit_method, # noqa: F401 # pyright: ignore[reportUnusedImport]
41+
implement_expr, # pyright: ignore[reportUnusedImport] # noqa: F401
42+
make_commutative, # pyright: ignore[reportUnusedImport] # noqa: F401
43+
)
3844

3945
if TYPE_CHECKING:
4046
from sympy.printing.latex import LatexPrinter
@@ -43,133 +49,13 @@
4349
_LOGGER = logging.getLogger(__name__)
4450

4551

46-
class UnevaluatedExpression(sp.Expr):
47-
"""Base class for expression classes with an :meth:`evaluate` method.
48-
49-
Deriving from `~sympy.core.expr.Expr` allows us to keep expression trees condense
50-
before unfolding them with their `~sympy.core.basic.Basic.doit` method. This allows
51-
us to:
52-
53-
1. condense the LaTeX representation of an expression tree by providing a custom
54-
:meth:`_latex` method.
55-
2. overwrite its printer methods (see `NumPyPrintable` and e.g.
56-
:doc:`compwa-org:report/001`).
57-
58-
The `UnevaluatedExpression` base class makes implementations of its derived classes
59-
more secure by enforcing the developer to provide implementations for these methods,
60-
so that SymPy mechanisms work correctly. Decorators like :func:`implement_expr` and
61-
:func:`implement_doit_method` provide convenient means to implement the missing
62-
methods.
63-
64-
.. autolink-preface::
65-
66-
import sympy as sp
67-
from ampform.sympy import UnevaluatedExpression, create_expression
68-
69-
.. automethod:: __new__
70-
.. automethod:: evaluate
71-
.. automethod:: _latex
72-
"""
73-
74-
# https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L74-L77
75-
__slots__: tuple[str] = ("_name",)
76-
_name: str | None
77-
"""Optional instance attribute that can be used in LaTeX representations."""
78-
79-
def __new__(
80-
cls: type[DecoratedClass],
81-
*args,
82-
name: str | None = None,
83-
**hints,
84-
) -> DecoratedClass:
85-
"""Constructor for a class derived from `UnevaluatedExpression`.
86-
87-
This :meth:`~object.__new__` method correctly sets the
88-
`~sympy.core.basic.Basic.args`, assumptions etc. Overwrite it in order to
89-
further specify its signature. The function :func:`create_expression` can be
90-
used in its implementation, like so:
91-
92-
>>> class MyExpression(UnevaluatedExpression):
93-
... def __new__(
94-
... cls, x: sp.Symbol, y: sp.Symbol, n: int, **hints
95-
... ) -> "MyExpression":
96-
... return create_expression(cls, x, y, n, **hints)
97-
...
98-
... def evaluate(self) -> sp.Expr:
99-
... x, y, n = self.args
100-
... return (x + y)**n
101-
...
102-
>>> x, y = sp.symbols("x y")
103-
>>> expr = MyExpression(x, y, n=3)
104-
>>> expr
105-
MyExpression(x, y, 3)
106-
>>> expr.evaluate()
107-
(x + y)**3
108-
"""
109-
# https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L113-L119
110-
obj = object.__new__(cls)
111-
obj._args = args
112-
obj._assumptions = cls.default_assumptions # type: ignore[attr-defined]
113-
obj._mhash = None
114-
obj._name = name
115-
return obj
116-
117-
def __getnewargs_ex__(self) -> tuple[tuple, dict]:
118-
# Pickling support, see
119-
# https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L124-L126
120-
args = tuple(self.args)
121-
kwargs = {"name": self._name}
122-
return args, kwargs
123-
124-
def _hashable_content(self) -> tuple:
125-
# https://github.com/sympy/sympy/blob/1.10/sympy/core/basic.py#L157-L165
126-
# name is converted to string because unstable hash for None
127-
return (*super()._hashable_content(), str(self._name))
128-
129-
@abstractmethod
130-
def evaluate(self) -> sp.Expr:
131-
"""Evaluate and 'unfold' this `UnevaluatedExpression` by one level.
132-
133-
>>> from ampform.dynamics import BreakupMomentumSquared
134-
>>> s, m1, m2 = sp.symbols("s m1 m2")
135-
>>> expr = BreakupMomentumSquared(s, m1, m2)
136-
>>> expr
137-
BreakupMomentumSquared(s, m1, m2)
138-
>>> expr.evaluate()
139-
(s - (m1 - m2)**2)*(s - (m1 + m2)**2)/(4*s)
140-
>>> expr.doit(deep=False)
141-
(s - (m1 - m2)**2)*(s - (m1 + m2)**2)/(4*s)
142-
143-
.. note:: When decorating this class with :func:`implement_doit_method`,
144-
its :meth:`evaluate` method is equivalent to
145-
:meth:`~sympy.core.basic.Basic.doit` with :code:`deep=False`.
146-
"""
147-
148-
def _latex(self, printer: LatexPrinter, *args) -> str:
149-
r"""Provide a mathematical Latex representation for pretty printing.
150-
151-
>>> from ampform.dynamics import BreakupMomentumSquared
152-
>>> s, m1 = sp.symbols("s m1")
153-
>>> expr = BreakupMomentumSquared(s, m1, m1)
154-
>>> print(sp.latex(expr))
155-
q^2\left(s\right)
156-
>>> print(sp.latex(expr.doit()))
157-
- m_{1}^{2} + \frac{s}{4}
158-
"""
159-
args = tuple(map(printer._print, self.args))
160-
name = type(self).__name__
161-
if self._name is not None:
162-
name = self._name
163-
return f"{name}{args}"
164-
165-
16652
class NumPyPrintable(sp.Expr):
16753
r"""`~sympy.core.expr.Expr` class that can lambdify to NumPy code.
16854
169-
This interface for classes that derive from `sympy.Expr <sympy.core.expr.Expr>`
170-
enforce the implementation of a :meth:`_numpycode` method in case the class does not
171-
correctly :func:`~sympy.utilities.lambdify.lambdify` to NumPy code. For more info on
172-
SymPy printers, see :doc:`sympy:modules/printing`.
55+
This interface is for classes that derive from `sympy.Expr <sympy.core.expr.Expr>`
56+
and that require a :meth:`_numpycode` method in case the class does not correctly
57+
:func:`~sympy.utilities.lambdify.lambdify` to NumPy code. For more info on SymPy
58+
printers, see :doc:`sympy:modules/printing`.
17359
17460
Several computational frameworks try to converge their interface to that of NumPy.
17561
See for instance `TensorFlow's NumPy API
@@ -179,9 +65,9 @@ class NumPyPrintable(sp.Expr):
17965
:func:`~sympy.utilities.lambdify.lambdify` SymPy expressions to these different
18066
backends with the same lambdification code.
18167
182-
.. note:: This interface differs from `UnevaluatedExpression` in that it **should
183-
not** implement an :meth:`.evaluate` (and therefore a
184-
:meth:`~sympy.core.basic.Basic.doit`) method.
68+
.. warning:: If you decorate this class with :func:`unevaluated`, you usually want
69+
to do so with :code:`implement_doit=False`, because you do not want the class
70+
to be 'unfolded' with :meth:`~sympy.core.basic.Basic.doit` before lambdification.
18571
18672
18773
.. warning:: The implemented :meth:`_numpycode` method should countain as little
@@ -201,117 +87,6 @@ def _numpycode(self, printer: NumPyPrinter, *args) -> str:
20187
"""Lambdify this `NumPyPrintable` class to NumPy code."""
20288

20389

204-
DecoratedClass = TypeVar("DecoratedClass", bound=UnevaluatedExpression)
205-
"""`~typing.TypeVar` for decorators like :func:`implement_doit_method`."""
206-
207-
208-
def implement_expr(
209-
n_args: int,
210-
) -> Callable[[type[DecoratedClass]], type[DecoratedClass]]:
211-
"""Decorator for classes that derive from `UnevaluatedExpression`.
212-
213-
Implement a :meth:`~object.__new__` and :meth:`~sympy.core.basic.Basic.doit` method
214-
for a class that derives from `~sympy.core.expr.Expr` (via `UnevaluatedExpression`).
215-
"""
216-
217-
def decorator(
218-
decorated_class: type[DecoratedClass],
219-
) -> type[DecoratedClass]:
220-
decorated_class = implement_new_method(n_args)(decorated_class)
221-
return implement_doit_method(decorated_class)
222-
223-
return decorator
224-
225-
226-
def implement_new_method(
227-
n_args: int,
228-
) -> Callable[[type[DecoratedClass]], type[DecoratedClass]]:
229-
"""Implement :meth:`UnevaluatedExpression.__new__` on a derived class.
230-
231-
Implement a :meth:`~object.__new__` method for a class that derives from
232-
`~sympy.core.expr.Expr` (via `UnevaluatedExpression`).
233-
"""
234-
235-
def decorator(
236-
decorated_class: type[DecoratedClass],
237-
) -> type[DecoratedClass]:
238-
def new_method(
239-
cls: type[DecoratedClass],
240-
*args: sp.Symbol,
241-
evaluate: bool = False,
242-
**hints,
243-
) -> DecoratedClass:
244-
if len(args) != n_args:
245-
msg = f"{n_args} parameters expected, got {len(args)}"
246-
raise ValueError(msg)
247-
args = sp.sympify(args)
248-
expr = UnevaluatedExpression.__new__(cls, *args)
249-
if evaluate:
250-
return expr.evaluate() # type: ignore[return-value]
251-
return expr
252-
253-
decorated_class.__new__ = new_method # type: ignore[assignment]
254-
return decorated_class
255-
256-
return decorator
257-
258-
259-
def implement_doit_method(
260-
decorated_class: type[DecoratedClass],
261-
) -> type[DecoratedClass]:
262-
"""Implement ``doit()`` method for an `UnevaluatedExpression` class.
263-
264-
Implement a :meth:`~sympy.core.basic.Basic.doit` method for a class that derives
265-
from `~sympy.core.expr.Expr` (via `UnevaluatedExpression`). A
266-
:meth:`~sympy.core.basic.Basic.doit` method is an extension of an
267-
:meth:`~.UnevaluatedExpression.evaluate` method in the sense that it can work
268-
recursively on deeper expression trees.
269-
"""
270-
271-
@functools.wraps(decorated_class.doit) # type: ignore[attr-defined]
272-
def doit_method(self: UnevaluatedExpression, deep: bool = True) -> sp.Expr:
273-
expr = self.evaluate()
274-
if deep:
275-
return expr.doit()
276-
return expr
277-
278-
decorated_class.doit = doit_method # type: ignore[assignment]
279-
return decorated_class
280-
281-
282-
DecoratedExpr = TypeVar("DecoratedExpr", bound=sp.Expr)
283-
"""`~typing.TypeVar` for decorators like :func:`make_commutative`."""
284-
285-
286-
def make_commutative(
287-
decorated_class: type[DecoratedExpr],
288-
) -> type[DecoratedExpr]:
289-
"""Set commutative and 'extended real' assumptions on expression class.
290-
291-
.. seealso:: :doc:`sympy:guides/assumptions`
292-
"""
293-
decorated_class.is_commutative = True # type: ignore[attr-defined]
294-
decorated_class.is_extended_real = True # type: ignore[attr-defined]
295-
return decorated_class
296-
297-
298-
def create_expression(
299-
cls: type[DecoratedExpr],
300-
*args,
301-
evaluate: bool = False,
302-
name: str | None = None,
303-
**kwargs,
304-
) -> DecoratedExpr:
305-
"""Helper function for implementing `UnevaluatedExpression.__new__`."""
306-
args = sp.sympify(args)
307-
if issubclass(cls, UnevaluatedExpression):
308-
expr = UnevaluatedExpression.__new__(cls, *args, name=name, **kwargs)
309-
if evaluate:
310-
return expr.evaluate() # type: ignore[return-value]
311-
return expr # type: ignore[return-value]
312-
return sp.Expr.__new__(cls, *args, **kwargs) # type: ignore[return-value]
313-
314-
31590
def create_symbol_matrix(name: str, m: int, n: int) -> sp.MutableDenseMatrix:
31691
"""Create a `~sympy.matrices.dense.Matrix` with symbols as elements.
31792
@@ -332,8 +107,7 @@ def create_symbol_matrix(name: str, m: int, n: int) -> sp.MutableDenseMatrix:
332107
return sp.Matrix([[symbol[i, j] for j in range(n)] for i in range(m)])
333108

334109

335-
@implement_doit_method
336-
class PoolSum(UnevaluatedExpression):
110+
class PoolSum(sp.Expr):
337111
r"""Sum over indices where the values are taken from a domain set.
338112
339113
>>> i, j, m, n = sp.symbols("i j m n")
@@ -352,6 +126,7 @@ def __new__(
352126
cls,
353127
expression,
354128
*indices: tuple[sp.Symbol, Iterable[sp.Basic]],
129+
evaluate: bool = False,
355130
**hints,
356131
) -> PoolSum:
357132
converted_indices = []
@@ -361,7 +136,11 @@ def __new__(
361136
msg = f"No values provided for index {idx_symbol}"
362137
raise ValueError(msg)
363138
converted_indices.append((idx_symbol, values))
364-
return create_expression(cls, expression, *converted_indices, **hints)
139+
args = sp.sympify((expression, *converted_indices))
140+
expr: PoolSum = sp.Expr.__new__(cls, *args, **hints)
141+
if evaluate:
142+
return expr.evaluate() # type: ignore[return-value]
143+
return expr
365144

366145
@property
367146
def expression(self) -> sp.Expr:
@@ -375,6 +154,12 @@ def indices(self) -> list[tuple[sp.Symbol, tuple[sp.Float, ...]]]:
375154
def free_symbols(self) -> set[sp.Basic]:
376155
return super().free_symbols - {s for s, _ in self.indices}
377156

157+
def doit(self, deep: bool = True) -> sp.Expr: # type: ignore[override]
158+
expr = self.evaluate()
159+
if deep:
160+
return expr.doit()
161+
return expr
162+
378163
def evaluate(self) -> sp.Expr:
379164
indices = {symbol: tuple(values) for symbol, values in self.indices}
380165
return sp.Add(*[

0 commit comments

Comments
 (0)