# cspell:ignore mhash
# pylint: disable=invalid-getnewargs-ex-returned, protected-access, W0223
# https://stackoverflow.com/a/22224042
"""Tools that facilitate in building :mod:`sympy` expressions."""
from __future__ import annotations

import functools
import itertools
import math
from abc import abstractmethod
from typing import Callable, Iterable, Sequence, SupportsFloat, TypeVar

import sympy as sp
from sympy.printing.latex import LatexPrinter
from sympy.printing.numpy import NumPyPrinter
from sympy.printing.precedence import PRECEDENCE
class UnevaluatedExpression(sp.Expr):
    """Base class for expression classes with an :meth:`evaluate` method.

    Deriving from `~sympy.core.expr.Expr` allows us to keep expression trees
    condensed before unfolding them with their `~sympy.core.basic.Basic.doit`
    method. This allows us to:

    1. condense the LaTeX representation of an expression tree by providing a
       custom :meth:`_latex` method.
    2. overwrite its printer methods (see `NumPyPrintable` and e.g.
       :doc:`compwa-org:report/001`).

    The `UnevaluatedExpression` base class makes implementations of its derived
    classes more secure by forcing the developer to provide implementations
    for these methods, so that SymPy mechanisms work correctly. Decorators like
    :func:`implement_expr` and :func:`implement_doit_method` provide convenient
    means to implement the missing methods.

    .. autolink-preface::

        import sympy as sp
        from ampform.sympy import UnevaluatedExpression, create_expression

    .. automethod:: __new__
    .. automethod:: evaluate
    .. automethod:: _latex
    """

    # https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L74-L77
    __slots__: tuple[str] = ("_name",)
    _name: str | None
    """Optional instance attribute that can be used in LaTeX representations."""

    def __new__(  # pylint: disable=unused-argument
        cls: type[DecoratedClass],
        *args,
        name: str | None = None,
        **hints,
    ) -> DecoratedClass:
        """Constructor for a class derived from `UnevaluatedExpression`.

        This :meth:`~object.__new__` method correctly sets the
        `~sympy.core.basic.Basic.args`, assumptions etc. Overwrite it in order
        to further specify its signature. The function
        :func:`create_expression` can be used in its implementation, like so:

        >>> class MyExpression(UnevaluatedExpression):
        ...     def __new__(
        ...         cls, x: sp.Symbol, y: sp.Symbol, n: int, **hints
        ...     ) -> "MyExpression":
        ...         return create_expression(cls, x, y, n, **hints)
        ...
        ...     def evaluate(self) -> sp.Expr:
        ...         x, y, n = self.args
        ...         return (x + y)**n
        ...
        >>> x, y = sp.symbols("x y")
        >>> expr = MyExpression(x, y, n=3)
        >>> expr
        MyExpression(x, y, 3)
        >>> expr.evaluate()
        (x + y)**3

        Any remaining :code:`hints` are accepted but not used here.
        """
        # Mimic Basic.__new__ without sympifying, so that the optional _name
        # slot can be set as well:
        # https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L113-L119
        obj = object.__new__(cls)
        obj._args = args
        obj._assumptions = cls.default_assumptions  # type: ignore[attr-defined]
        obj._mhash = None
        obj._name = name
        return obj

    def __getnewargs_ex__(self) -> tuple[tuple, dict]:
        # Pickling support: the unpickler calls __new__(cls, *args, **kwargs),
        # so _name has to be passed back as a keyword argument. See
        # https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L124-L126
        args = tuple(self.args)
        kwargs = {"name": self._name}
        return args, kwargs

    @abstractmethod
    def evaluate(self) -> sp.Expr:
        """Evaluate and 'unfold' this `UnevaluatedExpression` by one level.

        >>> from ampform.dynamics import BreakupMomentumSquared
        >>> issubclass(BreakupMomentumSquared, UnevaluatedExpression)
        True
        >>> s, m1, m2 = sp.symbols("s m1 m2")
        >>> expr = BreakupMomentumSquared(s, m1, m2)
        >>> expr
        BreakupMomentumSquared(s, m1, m2)
        >>> expr.evaluate()
        (s - (m1 - m2)**2)*(s - (m1 + m2)**2)/(4*s)
        >>> expr.doit(deep=False)
        (s - (m1 - m2)**2)*(s - (m1 + m2)**2)/(4*s)

        .. note:: When decorating this class with :func:`implement_doit_method`,
            its :meth:`evaluate` method is equivalent to
            :meth:`~sympy.core.basic.Basic.doit` with :code:`deep=False`.
        """

    def _latex(self, printer: LatexPrinter, *args) -> str:
        r"""Provide a mathematical Latex representation for pretty printing.

        By default, the expression is rendered as its :code:`name` (if one was
        given to :meth:`__new__`) or else its class name, followed by the
        comma-separated LaTeX representations of its
        `~sympy.core.basic.Basic.args` in parentheses. Derived classes can
        overwrite this method for a more condensed notation:

        >>> from ampform.dynamics import BreakupMomentumSquared
        >>> issubclass(BreakupMomentumSquared, UnevaluatedExpression)
        True
        >>> s, m1 = sp.symbols("s m1")
        >>> expr = BreakupMomentumSquared(s, m1, m1)
        >>> print(sp.latex(expr))
        q^2\left(s\right)
        >>> print(sp.latex(expr.doit()))
        - m_{1}^{2} + \frac{s}{4}
        """
        # Join the rendered args explicitly: formatting a tuple of strings
        # would leak Python repr quotes (rendered as primes in LaTeX) and a
        # trailing comma for expressions with a single argument.
        rendered_args = ", ".join(map(printer._print, self.args))
        name = type(self).__name__ if self._name is None else self._name
        return Rf"{name}\left({rendered_args}\right)"
class NumPyPrintable(sp.Expr):
    r"""`~sympy.core.expr.Expr` class that can lambdify to NumPy code.

    This interface for classes that derive from `sympy.Expr
    <sympy.core.expr.Expr>` enforces the implementation of a
    :meth:`_numpycode` method in case the class does not correctly
    :func:`~sympy.utilities.lambdify.lambdify` to NumPy code. For more info on
    SymPy printers, see :doc:`sympy:modules/printing`.

    Several computational frameworks try to converge their interface to that of
    NumPy. See for instance `TensorFlow's NumPy API
    <https://www.tensorflow.org/guide/tf_numpy>`_ and `jax.numpy
    <https://jax.readthedocs.io/en/latest/jax.numpy.html>`_. This fact is used
    in `TensorWaves <https://tensorwaves.rtfd.io>`_ to
    :func:`~sympy.utilities.lambdify.lambdify` SymPy expressions to these
    different backends with the same lambdification code.

    .. note:: This interface differs from `UnevaluatedExpression` in that it
        **should not** implement an :meth:`.evaluate` (and therefore a
        :meth:`~sympy.core.basic.Basic.doit`) method.

    .. warning:: The implemented :meth:`_numpycode` method should contain as
        few SymPy computations as possible. Instead, it should get most
        information from its construction `~sympy.core.basic.Basic.args`, so
        that SymPy can use printer tricks like
        :func:`~sympy.simplify.cse_main.cse`, prior to expanding with
        :meth:`~sympy.core.basic.Basic.doit`, and other simplifications that
        can make the generated code shorter. An example is the `.BoostZMatrix`
        class, which takes :math:`\beta` as input instead of the
        `.FourMomentumSymbol` from which :math:`\beta` is computed.

    .. automethod:: _numpycode
    """

    @abstractmethod
    def _numpycode(self, printer: NumPyPrinter, *args) -> str:
        """Lambdify this `NumPyPrintable` class to NumPy code."""
# Type variable bound to UnevaluatedExpression. The class decorators below
# annotate their argument and return value with it, so that they are typed as
# returning the same class they receive.
DecoratedClass = TypeVar("DecoratedClass", bound=UnevaluatedExpression)
"""`~typing.TypeVar` for decorators like :func:`implement_doit_method`."""
def implement_expr(
    n_args: int,
) -> Callable[[type[DecoratedClass]], type[DecoratedClass]]:
    """Decorator for classes that derive from `UnevaluatedExpression`.

    Implement a :meth:`~object.__new__` and
    :meth:`~sympy.core.basic.Basic.doit` method for a class that derives from
    `~sympy.core.expr.Expr` (via `UnevaluatedExpression`).

    This is shorthand for decorating with :func:`implement_new_method` (using
    :code:`n_args`) and then with :func:`implement_doit_method`.
    """

    def decorator(
        decorated_class: type[DecoratedClass],
    ) -> type[DecoratedClass]:
        add_new_method = implement_new_method(n_args)
        with_new_method = add_new_method(decorated_class)
        return implement_doit_method(with_new_method)

    return decorator
def implement_new_method(
    n_args: int,
) -> Callable[[type[DecoratedClass]], type[DecoratedClass]]:
    """Implement :meth:`UnevaluatedExpression.__new__` on a derived class.

    Implement a :meth:`~object.__new__` method for a class that derives from
    `~sympy.core.expr.Expr` (via `UnevaluatedExpression`).

    The generated constructor behaves as follows:

    - It requires exactly :code:`n_args` positional arguments and raises a
      `ValueError` otherwise.
    - It sympifies those arguments.
    - It forwards the keyword argument :code:`name` to
      :meth:`UnevaluatedExpression.__new__`, just like
      :func:`create_expression` does.
    - It returns the result of :meth:`~.UnevaluatedExpression.evaluate` right
      away if :code:`evaluate=True`.

    Any other hints are accepted and ignored.
    """

    def decorator(
        decorated_class: type[DecoratedClass],
    ) -> type[DecoratedClass]:
        def new_method(  # pylint: disable=unused-argument
            cls: type[DecoratedClass],
            *args: sp.Symbol,
            evaluate: bool = False,
            name: str | None = None,
            **hints,
        ) -> DecoratedClass:
            if len(args) != n_args:
                raise ValueError(
                    f"{n_args} parameters expected, got {len(args)}"
                )
            args = sp.sympify(args)
            # Forward ``name``: __getnewargs_ex__ (pickling) and callers pass
            # it as a keyword, and it used to be silently swallowed by
            # **hints.
            expr = UnevaluatedExpression.__new__(cls, *args, name=name)
            if evaluate:
                return expr.evaluate()  # type: ignore[return-value]
            return expr

        decorated_class.__new__ = new_method  # type: ignore[assignment]
        return decorated_class

    return decorator
def implement_doit_method(
    decorated_class: type[DecoratedClass],
) -> type[DecoratedClass]:
    """Implement ``doit()`` method for an `UnevaluatedExpression` class.

    Implement a :meth:`~sympy.core.basic.Basic.doit` method for a class that
    derives from `~sympy.core.expr.Expr` (via `UnevaluatedExpression`). A
    :meth:`~sympy.core.basic.Basic.doit` method is an extension of an
    :meth:`~.UnevaluatedExpression.evaluate` method in the sense that it can
    work recursively on deeper expression trees.

    The generated method first calls :meth:`~.UnevaluatedExpression.evaluate`.
    With :code:`deep=False`, it returns that result as is. Otherwise, it calls
    :meth:`~sympy.core.basic.Basic.doit` on the result and passes on any
    additional hints.
    """

    @functools.wraps(decorated_class.doit)  # type: ignore[attr-defined]
    def doit_method(
        self: UnevaluatedExpression, deep: bool = True, **hints
    ) -> sp.Expr:
        expr = self.evaluate()
        if deep:
            # SymPy's Basic.doit(**hints) passes all hints down to every
            # sub-expression, so they have to be accepted here (instead of
            # raising a TypeError) and forwarded into the evaluated tree.
            return expr.doit(**hints)
        return expr

    decorated_class.doit = doit_method  # type: ignore[assignment]
    return decorated_class
def _implement_latex_subscript(  # pyright: reportUnusedFunction=false
    subscript: str,
) -> Callable[[type[UnevaluatedExpression]], type[UnevaluatedExpression]]:
    """Create a class decorator that renders ``_momentum`` with a subscript.

    The decorated class gets a :code:`_latex` method that prints the
    instance's :code:`_momentum` attribute and appends :code:`_{subscript}`.
    The momentum is wrapped in :code:`\\left(...\\right)` if the printer
    reports that it needs multiplication brackets, and in curly braces
    otherwise. NOTE(review): ``_momentum`` is not enforced here; presumably
    the decorated classes define it.
    """

    def decorator(
        decorated_class: type[UnevaluatedExpression],
    ) -> type[UnevaluatedExpression]:
        # pylint: disable=protected-access, unused-argument
        # No functools.wraps(decorated_class.doit) here: that copied the
        # name, docstring and __wrapped__ of *doit* onto this _latex method.
        def _latex(self: sp.Expr, printer: LatexPrinter, *args) -> str:
            """Render the momentum of this expression with a subscript."""
            momentum = printer._print(self._momentum)  # type: ignore[attr-defined]
            if printer._needs_mul_brackets(self._momentum):  # type: ignore[attr-defined]
                momentum = Rf"\left({momentum}\right)"
            else:
                momentum = Rf"{{{momentum}}}"
            return f"{momentum}_{subscript}"

        decorated_class._latex = _latex  # type: ignore[assignment]
        return decorated_class

    return decorator
# Type variable bound to sympy.Expr rather than UnevaluatedExpression. It is
# used by helpers such as make_commutative() and create_expression(), which
# accept any SymPy expression class.
DecoratedExpr = TypeVar("DecoratedExpr", bound=sp.Expr)
"""`~typing.TypeVar` for decorators like :func:`make_commutative`."""
def make_commutative(
    decorated_class: type[DecoratedExpr],
) -> type[DecoratedExpr]:
    """Set commutative and 'extended real' assumptions on expression class.

    The class attributes :code:`is_commutative` and :code:`is_extended_real`
    are both overwritten with `True`. The same class object is returned, so
    this function can be used as a class decorator.

    .. seealso:: :doc:`sympy:guides/assumptions`
    """
    for assumption in ("is_commutative", "is_extended_real"):
        setattr(decorated_class, assumption, True)
    return decorated_class
def create_expression(
    cls: type[DecoratedExpr],
    *args,
    evaluate: bool = False,
    name: str | None = None,
    **kwargs,
) -> DecoratedExpr:
    """Helper function for implementing `UnevaluatedExpression.__new__`.

    All positional arguments are sympified first.

    - For subclasses of `UnevaluatedExpression`, :code:`name` and the
      remaining keyword arguments are passed on to
      :meth:`UnevaluatedExpression.__new__`. If :code:`evaluate` is set, the
      result of :meth:`~.UnevaluatedExpression.evaluate` is returned instead
      of the new instance.
    - Any other class is constructed through :code:`sp.Expr.__new__` with the
      remaining keyword arguments; :code:`evaluate` and :code:`name` are then
      not used.
    """
    sympified_args = sp.sympify(args)
    if not issubclass(cls, UnevaluatedExpression):
        return sp.Expr.__new__(cls, *sympified_args, **kwargs)  # type: ignore[return-value]
    instance = UnevaluatedExpression.__new__(
        cls, *sympified_args, name=name, **kwargs
    )
    return instance.evaluate() if evaluate else instance  # type: ignore[return-value]
def create_symbol_matrix(name: str, m: int, n: int) -> sp.MutableDenseMatrix:
    """Create a `~sympy.matrices.dense.Matrix` with symbols as elements.

    The `~sympy.matrices.expressions.MatrixSymbol` has some issues when one is
    interested in the elements of the matrix. This function instead creates a
    `~sympy.matrices.dense.Matrix` where the elements are
    `~sympy.tensor.indexed.Indexed` instances of one
    `~sympy.tensor.indexed.IndexedBase` with shape :code:`(m, n)`.

    To convert these `~sympy.tensor.indexed.Indexed` instances to a
    `~sympy.core.symbol.Symbol`, use
    :func:`symplot.substitute_indexed_symbols`.

    >>> create_symbol_matrix("A", m=2, n=3)
    Matrix([
    [A[0, 0], A[0, 1], A[0, 2]],
    [A[1, 0], A[1, 1], A[1, 2]]])
    """
    indexed = sp.IndexedBase(name, shape=(m, n))
    rows = []
    for row in range(m):
        rows.append([indexed[row, col] for col in range(n)])
    return sp.Matrix(rows)
@implement_doit_method
class PoolSum(UnevaluatedExpression):
    r"""Sum over indices where the values are taken from a domain set.

    >>> i, j, m, n = sp.symbols("i j m n")
    >>> expr = PoolSum(i**m + j**n, (i, (-1, 0, +1)), (j, (2, 4, 5)))
    >>> expr
    PoolSum(i**m + j**n, (i, (-1, 0, 1)), (j, (2, 4, 5)))
    >>> print(sp.latex(expr))
    \sum_{i=-1}^{1} \sum_{j\in\left\{2,4,5\right\}}{i^{m} + j^{n}}
    >>> expr.doit()
    3*(-1)**m + 3*0**m + 3*2**n + 3*4**n + 3*5**n + 3
    """

    precedence = PRECEDENCE["Mul"]

    def __new__(
        cls,
        expression,
        *indices: tuple[sp.Symbol, Iterable[sp.Basic]],
        **hints,
    ) -> PoolSum:
        # Materialize each value iterable so it can be validated and stored;
        # an index without values is rejected.
        converted_indices = []
        for idx_symbol, values in indices:
            values = tuple(values)
            if len(values) == 0:
                raise ValueError(f"No values provided for index {idx_symbol}")
            converted_indices.append((idx_symbol, values))
        return create_expression(cls, expression, *converted_indices, **hints)

    @property
    def expression(self) -> sp.Expr:
        """The summand, which is the first argument."""
        return self.args[0]  # type: ignore[return-value]

    @property
    def indices(self) -> tuple[tuple[sp.Symbol, tuple[sp.Basic, ...]], ...]:
        """The :code:`(symbol, values)` pairs, which are all remaining arguments."""
        return self.args[1:]  # type: ignore[return-value]

    @property
    def free_symbols(self) -> set[sp.Basic]:
        # Summation indices are bound variables, so they are not free.
        return super().free_symbols - {s for s, _ in self.indices}

    def evaluate(self) -> sp.Expr:
        """Expand into an `~sympy.core.add.Add` over all index combinations."""
        indices = {symbol: tuple(values) for symbol, values in self.indices}
        return sp.Add(
            *[
                self.expression.subs(zip(indices, combi))
                for combi in itertools.product(*indices.values())
            ]
        )

    def _latex(self, printer: LatexPrinter, *args) -> str:
        # One sum symbol per index, followed by the summand in curly braces.
        indices = dict(self.indices)
        sum_symbols: list[str] = []
        for idx, values in indices.items():
            sum_symbols.append(_render_sum_symbol(printer, idx, values))
        expression = printer._print(self.expression)
        return R" ".join(sum_symbols) + f"{{{expression}}}"

    def cleanup(self) -> sp.Expr | PoolSum:
        """Remove redundant summations, like indices with one or no value.

        Indices with a single value are substituted into the expression.
        Indices that do not appear in the expression are removed as well.
        Summing over such an index only repeats the expression, so the
        expression is multiplied by the number of values of that index. This
        way, :meth:`cleanup` does not change the value of the sum.

        >>> x, i = sp.symbols("x i")
        >>> PoolSum(x**i, (i, [0, 1, 2])).cleanup().doit()
        x**2 + x + 1
        >>> PoolSum(x, (i, [0, 1, 2])).cleanup()
        3*x
        >>> PoolSum(x).cleanup()
        x
        >>> PoolSum(x**i, (i, [0])).cleanup()
        1
        """
        free_symbols = self.expression.free_symbols
        substitutions = {}
        new_indices = []
        multiplicity = 1
        for idx, values in self.indices:
            if idx not in free_symbols:
                # Summing a term that does not depend on idx repeats it
                # len(values) times; dropping the index must keep that factor.
                multiplicity *= len(values)
                continue
            if len(values) == 0:
                # Not reachable for instances constructed through __new__.
                continue
            if len(values) == 1:
                substitutions[idx] = values[0]
            else:
                new_indices.append((idx, values))
        new_expression = self.expression.xreplace(substitutions)
        if multiplicity != 1:
            new_expression = multiplicity * new_expression
        if len(new_indices) == 0:
            return new_expression
        return PoolSum(new_expression, *new_indices)
def _render_sum_symbol(
    printer: LatexPrinter, idx: sp.Symbol, values: Sequence[SupportsFloat]
) -> str:
    r"""Render the LaTeX sum symbol for one summation index.

    - No values render as an empty string.
    - A single value renders as :code:`\sum_{idx=value}`.
    - A series with unit steps renders as :code:`\sum_{idx=first}^{last}`.
    - Anything else renders as a set: :code:`\sum_{idx\in\left\{...\right\}}`.

    Only the index and the set members are rendered with the printer. The
    bound values of the first two forms are inserted with :func:`str`.
    """
    n_values = len(values)
    if n_values == 0:
        return ""
    rendered_idx = printer._print(idx)
    if n_values == 1:
        return Rf"\sum_{{{rendered_idx}={values[0]}}}"
    if not _is_regular_series(values):
        members = ",".join(printer._print(value) for value in values)
        return Rf"\sum_{{{rendered_idx}\in\left\{{{members}\right\}}}}"
    ordered = sorted(values, key=float)
    lower, upper = ordered[0], ordered[-1]
    return Rf"\sum_{{{rendered_idx}={lower}}}^{{{upper}}}"
def _is_regular_series(values: Sequence[SupportsFloat]) -> bool:
    """Check whether a set of values is a series with unit distances.

    The values are sorted numerically first, so their order does not matter.
    Consecutive differences are compared to 1 with :func:`math.isclose`.
    Floating-point round-off (e.g. :code:`1.1 - 0.1`) therefore does not break
    the check. Fewer than two values never count as a series.

    >>> _is_regular_series([0, 1, 2])
    True
    >>> _is_regular_series([-0.5, +0.5])
    True
    >>> _is_regular_series([+0.5, -0.5, 1.5])
    True
    >>> _is_regular_series([0.1, 1.1, 2.1])
    True
    >>> _is_regular_series([-1, +1])
    False
    >>> _is_regular_series([1])
    False
    >>> _is_regular_series([])
    False
    """
    if len(values) <= 1:
        return False
    sorted_values = sorted(values, key=float)
    return all(
        math.isclose(float(next_val) - float(val), 1.0)
        for val, next_val in zip(sorted_values, sorted_values[1:])
    )