1010
1111""" 
1212
13- # cspell:ignore mhash 
1413from  __future__ import  annotations 
1514
1615import  functools 
2322from  abc  import  abstractmethod 
2423from  os .path  import  abspath , dirname , expanduser 
2524from  textwrap  import  dedent 
26- from  typing  import  TYPE_CHECKING , Callable ,  Iterable , Sequence , SupportsFloat ,  TypeVar 
25+ from  typing  import  TYPE_CHECKING , Iterable , Sequence , SupportsFloat 
2726
2827import  sympy  as  sp 
2928from  sympy .printing .conventions  import  split_super_sub 
3534    argument ,  # noqa: F401  # pyright: ignore[reportUnusedImport] 
3635    unevaluated ,  # noqa: F401  # pyright: ignore[reportUnusedImport] 
3736)
37+ from  .deprecated  import  (
38+     UnevaluatedExpression ,  # noqa: F401  # pyright: ignore[reportUnusedImport] 
39+     create_expression ,  # noqa: F401  # pyright: ignore[reportUnusedImport] 
40+     implement_doit_method ,  # noqa: F401  # pyright: ignore[reportUnusedImport] 
41+     implement_expr ,  # pyright: ignore[reportUnusedImport]  # noqa: F401 
42+     make_commutative ,  # pyright: ignore[reportUnusedImport]  # noqa: F401 
43+ )
3844
3945if  TYPE_CHECKING :
4046    from  sympy .printing .latex  import  LatexPrinter 
4349_LOGGER  =  logging .getLogger (__name__ )
4450
4551
46- class  UnevaluatedExpression (sp .Expr ):
47-     """Base class for expression classes with an :meth:`evaluate` method. 
48- 
49-     Deriving from `~sympy.core.expr.Expr` allows us to keep expression trees condense 
50-     before unfolding them with their `~sympy.core.basic.Basic.doit` method. This allows 
51-     us to: 
52- 
53-     1. condense the LaTeX representation of an expression tree by providing a custom 
54-        :meth:`_latex` method. 
55-     2. overwrite its printer methods (see `NumPyPrintable` and e.g. 
56-        :doc:`compwa-org:report/001`). 
57- 
58-     The `UnevaluatedExpression` base class makes implementations of its derived classes 
59-     more secure by enforcing the developer to provide implementations for these methods, 
60-     so that SymPy mechanisms work correctly. Decorators like :func:`implement_expr` and 
61-     :func:`implement_doit_method` provide convenient means to implement the missing 
62-     methods. 
63- 
64-     .. autolink-preface:: 
65- 
66-         import sympy as sp 
67-         from ampform.sympy import UnevaluatedExpression, create_expression 
68- 
69-     .. automethod:: __new__ 
70-     .. automethod:: evaluate 
71-     .. automethod:: _latex 
72-     """ 
73- 
74-     # https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L74-L77 
75-     __slots__ : tuple [str ] =  ("_name" ,)
76-     _name : str  |  None 
77-     """Optional instance attribute that can be used in LaTeX representations.""" 
78- 
79-     def  __new__ (
80-         cls : type [DecoratedClass ],
81-         * args ,
82-         name : str  |  None  =  None ,
83-         ** hints ,
84-     ) ->  DecoratedClass :
85-         """Constructor for a class derived from `UnevaluatedExpression`. 
86- 
87-         This :meth:`~object.__new__` method correctly sets the 
88-         `~sympy.core.basic.Basic.args`, assumptions etc. Overwrite it in order to 
89-         further specify its signature. The function :func:`create_expression` can be 
90-         used in its implementation, like so: 
91- 
92-         >>> class MyExpression(UnevaluatedExpression): 
93-         ...    def __new__( 
94-         ...        cls, x: sp.Symbol, y: sp.Symbol, n: int, **hints 
95-         ...    ) -> "MyExpression": 
96-         ...        return create_expression(cls, x, y, n, **hints) 
97-         ... 
98-         ...    def evaluate(self) -> sp.Expr: 
99-         ...        x, y, n = self.args 
100-         ...        return (x + y)**n 
101-         ... 
102-         >>> x, y = sp.symbols("x y") 
103-         >>> expr = MyExpression(x, y, n=3) 
104-         >>> expr 
105-         MyExpression(x, y, 3) 
106-         >>> expr.evaluate() 
107-         (x + y)**3 
108-         """ 
109-         # https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L113-L119 
110-         obj  =  object .__new__ (cls )
111-         obj ._args  =  args 
112-         obj ._assumptions  =  cls .default_assumptions   # type: ignore[attr-defined] 
113-         obj ._mhash  =  None 
114-         obj ._name  =  name 
115-         return  obj 
116- 
117-     def  __getnewargs_ex__ (self ) ->  tuple [tuple , dict ]:
118-         # Pickling support, see 
119-         # https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L124-L126 
120-         args  =  tuple (self .args )
121-         kwargs  =  {"name" : self ._name }
122-         return  args , kwargs 
123- 
124-     def  _hashable_content (self ) ->  tuple :
125-         # https://github.com/sympy/sympy/blob/1.10/sympy/core/basic.py#L157-L165 
126-         # name is converted to string because unstable hash for None 
127-         return  (* super ()._hashable_content (), str (self ._name ))
128- 
129-     @abstractmethod  
130-     def  evaluate (self ) ->  sp .Expr :
131-         """Evaluate and 'unfold' this `UnevaluatedExpression` by one level. 
132- 
133-         >>> from ampform.dynamics import BreakupMomentumSquared 
134-         >>> s, m1, m2 = sp.symbols("s m1 m2") 
135-         >>> expr = BreakupMomentumSquared(s, m1, m2) 
136-         >>> expr 
137-         BreakupMomentumSquared(s, m1, m2) 
138-         >>> expr.evaluate() 
139-         (s - (m1 - m2)**2)*(s - (m1 + m2)**2)/(4*s) 
140-         >>> expr.doit(deep=False) 
141-         (s - (m1 - m2)**2)*(s - (m1 + m2)**2)/(4*s) 
142- 
143-         .. note:: When decorating this class with :func:`implement_doit_method`, 
144-             its :meth:`evaluate` method is equivalent to 
145-             :meth:`~sympy.core.basic.Basic.doit` with :code:`deep=False`. 
146-         """ 
147- 
148-     def  _latex (self , printer : LatexPrinter , * args ) ->  str :
149-         r"""Provide a mathematical Latex representation for pretty printing. 
150- 
151-         >>> from ampform.dynamics import BreakupMomentumSquared 
152-         >>> s, m1 = sp.symbols("s m1") 
153-         >>> expr = BreakupMomentumSquared(s, m1, m1) 
154-         >>> print(sp.latex(expr)) 
155-         q^2\left(s\right) 
156-         >>> print(sp.latex(expr.doit())) 
157-         - m_{1}^{2} + \frac{s}{4} 
158-         """ 
159-         args  =  tuple (map (printer ._print , self .args ))
160-         name  =  type (self ).__name__ 
161-         if  self ._name  is  not None :
162-             name  =  self ._name 
163-         return  f"{ name } { args }  
164- 
165- 
16652class  NumPyPrintable (sp .Expr ):
16753    r"""`~sympy.core.expr.Expr` class that can lambdify to NumPy code. 
16854
169-     This interface for classes that derive from `sympy.Expr <sympy.core.expr.Expr>` 
170-     enforce the implementation of  a :meth:`_numpycode` method in case the class does not 
171-     correctly  :func:`~sympy.utilities.lambdify.lambdify` to NumPy code. For more info on 
172-     SymPy  printers, see :doc:`sympy:modules/printing`. 
55+     This interface is  for classes that derive from `sympy.Expr <sympy.core.expr.Expr>` 
56+     and that require  a :meth:`_numpycode` method in case the class does not correctly  
57+     :func:`~sympy.utilities.lambdify.lambdify` to NumPy code. For more info on SymPy  
58+     printers, see :doc:`sympy:modules/printing`. 
17359
17460    Several computational frameworks try to converge their interface to that of NumPy. 
17561    See for instance `TensorFlow's NumPy API 
@@ -179,9 +65,9 @@ class NumPyPrintable(sp.Expr):
17965    :func:`~sympy.utilities.lambdify.lambdify` SymPy expressions to these different 
18066    backends with the same lambdification code. 
18167
182-     .. note :: This interface differs from `UnevaluatedExpression` in that it **should  
183-         not** implement an :meth:`.evaluate` (and therefore a  
184-         :meth:`~sympy.core.basic.Basic.doit`) method . 
68+     .. warning :: If you decorate this class with :func:`unevaluated`, you usually want  
69+         to do so with :code:`implement_doit=False`, because you do not want the class  
70+         to be 'unfolded' with  :meth:`~sympy.core.basic.Basic.doit` before lambdification . 
18571
18672
18773    .. warning:: The implemented :meth:`_numpycode` method should countain as little 
@@ -201,117 +87,6 @@ def _numpycode(self, printer: NumPyPrinter, *args) -> str:
20187        """Lambdify this `NumPyPrintable` class to NumPy code.""" 
20288
20389
204- DecoratedClass  =  TypeVar ("DecoratedClass" , bound = UnevaluatedExpression )
205- """`~typing.TypeVar` for decorators like :func:`implement_doit_method`.""" 
206- 
207- 
208- def  implement_expr (
209-     n_args : int ,
210- ) ->  Callable [[type [DecoratedClass ]], type [DecoratedClass ]]:
211-     """Decorator for classes that derive from `UnevaluatedExpression`. 
212- 
213-     Implement a :meth:`~object.__new__` and :meth:`~sympy.core.basic.Basic.doit` method 
214-     for a class that derives from `~sympy.core.expr.Expr` (via `UnevaluatedExpression`). 
215-     """ 
216- 
217-     def  decorator (
218-         decorated_class : type [DecoratedClass ],
219-     ) ->  type [DecoratedClass ]:
220-         decorated_class  =  implement_new_method (n_args )(decorated_class )
221-         return  implement_doit_method (decorated_class )
222- 
223-     return  decorator 
224- 
225- 
226- def  implement_new_method (
227-     n_args : int ,
228- ) ->  Callable [[type [DecoratedClass ]], type [DecoratedClass ]]:
229-     """Implement :meth:`UnevaluatedExpression.__new__` on a derived class. 
230- 
231-     Implement a :meth:`~object.__new__` method for a class that derives from 
232-     `~sympy.core.expr.Expr` (via `UnevaluatedExpression`). 
233-     """ 
234- 
235-     def  decorator (
236-         decorated_class : type [DecoratedClass ],
237-     ) ->  type [DecoratedClass ]:
238-         def  new_method (
239-             cls : type [DecoratedClass ],
240-             * args : sp .Symbol ,
241-             evaluate : bool  =  False ,
242-             ** hints ,
243-         ) ->  DecoratedClass :
244-             if  len (args ) !=  n_args :
245-                 msg  =  f"{ n_args } { len (args )}  
246-                 raise  ValueError (msg )
247-             args  =  sp .sympify (args )
248-             expr  =  UnevaluatedExpression .__new__ (cls , * args )
249-             if  evaluate :
250-                 return  expr .evaluate ()  # type: ignore[return-value] 
251-             return  expr 
252- 
253-         decorated_class .__new__  =  new_method   # type: ignore[assignment] 
254-         return  decorated_class 
255- 
256-     return  decorator 
257- 
258- 
259- def  implement_doit_method (
260-     decorated_class : type [DecoratedClass ],
261- ) ->  type [DecoratedClass ]:
262-     """Implement ``doit()`` method for an `UnevaluatedExpression` class. 
263- 
264-     Implement a :meth:`~sympy.core.basic.Basic.doit` method for a class that derives 
265-     from `~sympy.core.expr.Expr` (via `UnevaluatedExpression`). A 
266-     :meth:`~sympy.core.basic.Basic.doit` method is an extension of an 
267-     :meth:`~.UnevaluatedExpression.evaluate` method in the sense that it can work 
268-     recursively on deeper expression trees. 
269-     """ 
270- 
271-     @functools .wraps (decorated_class .doit )  # type: ignore[attr-defined]  
272-     def  doit_method (self : UnevaluatedExpression , deep : bool  =  True ) ->  sp .Expr :
273-         expr  =  self .evaluate ()
274-         if  deep :
275-             return  expr .doit ()
276-         return  expr 
277- 
278-     decorated_class .doit  =  doit_method   # type: ignore[assignment] 
279-     return  decorated_class 
280- 
281- 
282- DecoratedExpr  =  TypeVar ("DecoratedExpr" , bound = sp .Expr )
283- """`~typing.TypeVar` for decorators like :func:`make_commutative`.""" 
284- 
285- 
286- def  make_commutative (
287-     decorated_class : type [DecoratedExpr ],
288- ) ->  type [DecoratedExpr ]:
289-     """Set commutative and 'extended real' assumptions on expression class. 
290- 
291-     .. seealso:: :doc:`sympy:guides/assumptions` 
292-     """ 
293-     decorated_class .is_commutative  =  True   # type: ignore[attr-defined] 
294-     decorated_class .is_extended_real  =  True   # type: ignore[attr-defined] 
295-     return  decorated_class 
296- 
297- 
298- def  create_expression (
299-     cls : type [DecoratedExpr ],
300-     * args ,
301-     evaluate : bool  =  False ,
302-     name : str  |  None  =  None ,
303-     ** kwargs ,
304- ) ->  DecoratedExpr :
305-     """Helper function for implementing `UnevaluatedExpression.__new__`.""" 
306-     args  =  sp .sympify (args )
307-     if  issubclass (cls , UnevaluatedExpression ):
308-         expr  =  UnevaluatedExpression .__new__ (cls , * args , name = name , ** kwargs )
309-         if  evaluate :
310-             return  expr .evaluate ()  # type: ignore[return-value] 
311-         return  expr   # type: ignore[return-value] 
312-     return  sp .Expr .__new__ (cls , * args , ** kwargs )  # type: ignore[return-value] 
313- 
314- 
31590def  create_symbol_matrix (name : str , m : int , n : int ) ->  sp .MutableDenseMatrix :
31691    """Create a `~sympy.matrices.dense.Matrix` with symbols as elements. 
31792
@@ -332,8 +107,7 @@ def create_symbol_matrix(name: str, m: int, n: int) -> sp.MutableDenseMatrix:
332107    return  sp .Matrix ([[symbol [i , j ] for  j  in  range (n )] for  i  in  range (m )])
333108
334109
335- @implement_doit_method  
336- class  PoolSum (UnevaluatedExpression ):
110+ class  PoolSum (sp .Expr ):
337111    r"""Sum over indices where the values are taken from a domain set. 
338112
339113    >>> i, j, m, n = sp.symbols("i j m n") 
@@ -352,6 +126,7 @@ def __new__(
352126        cls ,
353127        expression ,
354128        * indices : tuple [sp .Symbol , Iterable [sp .Basic ]],
129+         evaluate : bool  =  False ,
355130        ** hints ,
356131    ) ->  PoolSum :
357132        converted_indices  =  []
@@ -361,7 +136,11 @@ def __new__(
361136                msg  =  f"No values provided for index { idx_symbol }  
362137                raise  ValueError (msg )
363138            converted_indices .append ((idx_symbol , values ))
364-         return  create_expression (cls , expression , * converted_indices , ** hints )
139+         args  =  sp .sympify ((expression , * converted_indices ))
140+         expr : PoolSum  =  sp .Expr .__new__ (cls , * args , ** hints )
141+         if  evaluate :
142+             return  expr .evaluate ()  # type: ignore[return-value] 
143+         return  expr 
365144
366145    @property  
367146    def  expression (self ) ->  sp .Expr :
@@ -375,6 +154,12 @@ def indices(self) -> list[tuple[sp.Symbol, tuple[sp.Float, ...]]]:
375154    def  free_symbols (self ) ->  set [sp .Basic ]:
376155        return  super ().free_symbols  -  {s  for  s , _  in  self .indices }
377156
157+     def  doit (self , deep : bool  =  True ) ->  sp .Expr :  # type: ignore[override] 
158+         expr  =  self .evaluate ()
159+         if  deep :
160+             return  expr .doit ()
161+         return  expr 
162+ 
378163    def  evaluate (self ) ->  sp .Expr :
379164        indices  =  {symbol : tuple (values ) for  symbol , values  in  self .indices }
380165        return  sp .Add (* [
0 commit comments