Skip to content

Commit

Permalink
DOC: explain usage of argument() function (#396)
Browse files Browse the repository at this point in the history
* ENH: emit warning if `_latex_repr_` is mistyped
* FIX: use correct docstring syntax
  • Loading branch information
redeboer authored Feb 12, 2024
1 parent bb17900 commit 58f5614
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 5 deletions.
66 changes: 66 additions & 0 deletions docs/usage/sympy.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,72 @@
"Math(aslatex({e: e.doit() for e in exprs}))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"By default, instance attributes are converted ['sympified'](https://docs.sympy.org/latest/modules/core.html#module-sympy.core.sympify). To avoid this behavior, use the {func}`.argument` function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import Callable\n",
"\n",
"from ampform.sympy import argument\n",
"\n",
"\n",
"class Transformation:\n",
" def __init__(self, power: int) -> None:\n",
" self.power = power\n",
"\n",
" def __call__(self, x: sp.Basic, y: sp.Basic) -> sp.Expr:\n",
" return x + y**self.power\n",
"\n",
"\n",
"@unevaluated\n",
"class MyExpr(sp.Expr):\n",
" x: Any\n",
" y: Any\n",
" functor: Callable = argument(sympify=False)\n",
"\n",
" def evaluate(self) -> sp.Expr:\n",
" return self.functor(self.x, self.y)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice how the `functor` attribute has not been sympified (there is no SymPy equivalent for a callable object), but the `functor` can be called in the `evaluate()`/`doit()` method."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"a, b, k = sp.symbols(\"a b k\")\n",
"expr = MyExpr(a, y=b, functor=Transformation(power=k))\n",
"assert expr.x is a\n",
"assert expr.y is b\n",
"assert not isinstance(expr.functor, sp.Basic)\n",
"Math(aslatex({expr: expr.doit()}))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
":::{tip}\n",
"An example where this is used, is in the {class}`.EnergyDependentWidth` class, where we do not want to sympify the {attr}`~.EnergyDependentWidth.phsp_factor` protocol.\n",
":::"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
12 changes: 7 additions & 5 deletions src/ampform/sympy/_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import functools
import inspect
import sys
import warnings
from collections import abc
from dataclasses import MISSING, Field
from dataclasses import astuple as _get_arguments
Expand Down Expand Up @@ -152,7 +153,7 @@ def unevaluated(
y
Attributes to the class are fed to the `~object.__new__` constructor of the
:class`~sympy.core.expr.Expr` class and are therefore also called "arguments". Just
:class:`~sympy.core.expr.Expr` class and are therefore also called "arguments". Just
like in the :class:`~sympy.core.expr.Expr` class, these arguments are automatically
`sympified
<https://docs.sympy.org/latest/modules/core.html#module-sympy.core.sympify>`_.
Expand Down Expand Up @@ -187,6 +188,11 @@ def decorator(cls: type[ExprClass]) -> type[ExprClass]:
cls = _implement_new_method(cls)
if implement_doit:
cls = _implement_doit(cls)
typos = ["_latex_repr"]
for typo in typos:
if hasattr(cls, typo):
msg = f"Class defines a {typo} attribute, but it should be _latex_repr_"
warnings.warn(msg, category=UserWarning, stacklevel=1)
if hasattr(cls, "_latex_repr_"):
cls = _implement_latex_repr(cls)
_set_assumptions(**assumptions)(cls)
Expand Down Expand Up @@ -345,10 +351,6 @@ def __call__(self, printer: LatexPrinter, *args) -> str: ...
@dataclass_transform(field_specifiers=(argument, _create_field))
def _implement_latex_repr(cls: type[T]) -> type[T]:
repr_name = "_latex_repr_"
repr_mistyped = "_latex_repr"
if hasattr(cls, repr_mistyped):
msg = f"Class defines a {repr_mistyped} attribute, but it should be {repr_name}"
raise AttributeError(msg)
_latex_repr_: LatexMethod | str | None = getattr(cls, repr_name, None)
if _latex_repr_ is None:
msg = (
Expand Down
16 changes: 16 additions & 0 deletions tests/sympy/decorator/test_unevaluated.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import inspect
from typing import Any, ClassVar

import pytest
import sympy as sp

from ampform.sympy._decorator import argument, unevaluated
Expand Down Expand Up @@ -124,6 +125,21 @@ class MyExpr(sp.Expr):
)


def test_latex_repr_typo_warning():
with pytest.warns(
UserWarning,
match=r"Class defines a _latex_repr attribute, but it should be _latex_repr_",
):

@unevaluated(real=False)
class MyExpr(sp.Expr): # pyright: ignore[reportUnusedClass]
x: sp.Symbol
_latex_repr = "<The attribute name is wrong>"

def evaluate(self) -> sp.Expr:
return self.x


def test_no_implement_doit():
@unevaluated(implement_doit=False)
class Squared(sp.Expr):
Expand Down

0 comments on commit 58f5614

Please sign in to comment.