Source code for ampform.sympy

# cspell:ignore mhash
# pylint: disable=invalid-getnewargs-ex-returned, protected-access
"""Tools that facilitate in building :mod:`sympy` expressions."""

import functools
import itertools
from abc import abstractmethod
from typing import (
    Any,
    Callable,
    Iterable,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
)

import sympy as sp
from sympy.printing.latex import LatexPrinter
from sympy.printing.numpy import NumPyPrinter
from sympy.printing.precedence import PRECEDENCE


[docs]class UnevaluatedExpression(sp.Expr): """Base class for expression classes with an :meth:`evaluate` method. Deriving from `~sympy.core.expr.Expr` allows us to keep expression trees condense before unfolding them with their `~sympy.core.basic.Basic.doit` method. This allows us to: 1. condense the LaTeX representation of an expression tree by providing a custom :meth:`_latex` method. 2. overwrite its printer methods (see `NumPyPrintable` and e.g. :doc:`compwa-org:report/001`). The `UnevaluatedExpression` base class makes implementations of its derived classes more secure by enforcing the developer to provide implementations for these methods, so that SymPy mechanisms work correctly. Decorators like :func:`implement_expr` and :func:`implement_doit_method` provide convenient means to implement the missing methods. .. autolink-preface:: import sympy as sp from ampform.sympy import UnevaluatedExpression, create_expression .. automethod:: __new__ .. automethod:: evaluate .. automethod:: _latex """ # https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L74-L77 __slots__: Tuple[str] = ("_name",) _name: Optional[str] """Optional instance attribute that can be used in LaTeX representations."""
[docs] def __new__( # pylint: disable=unused-argument cls: Type["DecoratedClass"], *args: Any, name: Optional[str] = None, **hints: Any, ) -> "DecoratedClass": """Constructor for a class derived from `UnevaluatedExpression`. This :meth:`~object.__new__` method correctly sets the `~sympy.core.basic.Basic.args`, assumptions etc. Overwrite it in order to further specify its signature. The function :func:`create_expression` can be used in its implementation, like so: >>> class MyExpression(UnevaluatedExpression): ... def __new__( ... cls, x: sp.Symbol, y: sp.Symbol, n: int, **hints ... ) -> "MyExpression": ... return create_expression(cls, x, y, n, **hints) ... ... def evaluate(self) -> sp.Expr: ... x, y, n = self.args ... return (x + y)**n ... >>> x, y = sp.symbols("x y") >>> expr = MyExpression(x, y, n=3) >>> expr MyExpression(x, y, 3) >>> expr.evaluate() (x + y)**3 """ # https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L113-L119 obj = object.__new__(cls) obj._args = args obj._assumptions = cls.default_assumptions obj._mhash = None obj._name = name return obj
def __getnewargs_ex__(self) -> Tuple[tuple, dict]: # Pickling support, see # https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L124-L126 args = tuple(self.args) kwargs = {"name": self._name} return args, kwargs
[docs] @abstractmethod def evaluate(self) -> sp.Expr: """Evaluate and 'unfold' this `UnevaluatedExpression` by one level. >>> from ampform.dynamics import BreakupMomentumSquared >>> issubclass(BreakupMomentumSquared, UnevaluatedExpression) True >>> s, m1, m2 = sp.symbols("s m1 m2") >>> expr = BreakupMomentumSquared(s, m1, m2) >>> expr BreakupMomentumSquared(s, m1, m2) >>> expr.evaluate() (s - (m1 - m2)**2)*(s - (m1 + m2)**2)/(4*s) >>> expr.doit(deep=False) (s - (m1 - m2)**2)*(s - (m1 + m2)**2)/(4*s) .. note:: When decorating this class with :func:`implement_doit_method`, its :meth:`evaluate` method is equivalent to :meth:`~sympy.core.basic.Basic.doit` with :code:`deep=False`. """
[docs] def _latex(self, printer: LatexPrinter, *args: Any) -> str: r"""Provide a mathematical Latex representation for pretty printing. >>> from ampform.dynamics import BreakupMomentumSquared >>> issubclass(BreakupMomentumSquared, UnevaluatedExpression) True >>> s, m1 = sp.symbols("s m1") >>> expr = BreakupMomentumSquared(s, m1, m1) >>> print(sp.latex(expr)) q^2\left(s\right) >>> print(sp.latex(expr.doit())) - m_{1}^{2} + \frac{s}{4} """ args = tuple(map(printer._print, self.args)) name = type(self).__name__ if self._name is not None: name = self._name return f"{name}{args}"
[docs]class NumPyPrintable(sp.Expr): r"""`~sympy.core.expr.Expr` class that can lambdify to NumPy code. This interface for classes that derive from `sympy.Expr <sympy.core.expr.Expr>` enforce the implementation of a :meth:`_numpycode` method in case the class does not correctly :func:`~sympy.utilities.lambdify.lambdify` to NumPy code. For more info on SymPy printers, see :doc:`sympy:modules/printing`. Several computational frameworks try to converge their interface to that of NumPy. See for instance `TensorFlow's NumPy API <https://www.tensorflow.org/guide/tf_numpy>`_ and `jax.numpy <https://jax.readthedocs.io/en/latest/jax.numpy.html>`_. This fact is used in `TensorWaves <https://tensorwaves.rtfd.io>`_ to :func:`~sympy.utilities.lambdify.lambdify` SymPy expressions to these different backends with the same lambdification code. .. note:: This interface differs from `UnevaluatedExpression` in that it **should not** implement an :meth:`.evaluate` (and therefore a :meth:`~sympy.core.basic.Basic.doit`) method. .. warning:: The implemented :meth:`_numpycode` method should countain as little SymPy computations as possible. Instead, it should get most information from its construction `~sympy.core.basic.Basic.args`, so that SymPy can use printer tricks like :func:`~sympy.simplify.cse_main.cse`, prior expanding with :meth:`~sympy.core.basic.Basic.doit`, and other simplifications that can make the generated code shorter. An example is the `.BoostZMatrix` class, which takes :math:`\beta` as input instead of the `.FourMomentumSymbol` from which :math:`\beta` is computed. .. automethod:: _numpycode """
[docs] @abstractmethod def _numpycode(self, printer: NumPyPrinter, *args: Any) -> str: """Lambdify this `NumPyPrintable` class to NumPy code."""
DecoratedClass = TypeVar("DecoratedClass", bound=UnevaluatedExpression) """`~typing.TypeVar` for decorators like :func:`make_commutative`."""
[docs]def implement_expr( n_args: int, ) -> Callable[[Type[DecoratedClass]], Type[DecoratedClass]]: """Decorator for classes that derive from `UnevaluatedExpression`. Implement a :meth:`~object.__new__` and :meth:`~sympy.core.basic.Basic.doit` method for a class that derives from `~sympy.core.expr.Expr` (via `UnevaluatedExpression`). """ def decorator( decorated_class: Type[DecoratedClass], ) -> Type[DecoratedClass]: decorated_class = implement_new_method(n_args)(decorated_class) decorated_class = implement_doit_method(decorated_class) return decorated_class return decorator
[docs]def implement_new_method( n_args: int, ) -> Callable[[Type[DecoratedClass]], Type[DecoratedClass]]: """Implement :meth:`UnevaluatedExpression.__new__` on a derived class. Implement a :meth:`~object.__new__` method for a class that derives from `~sympy.core.expr.Expr` (via `UnevaluatedExpression`). """ def decorator( decorated_class: Type[DecoratedClass], ) -> Type[DecoratedClass]: def new_method( # pylint: disable=unused-argument cls: Type, *args: sp.Symbol, evaluate: bool = False, **hints: Any, ) -> bool: if len(args) != n_args: raise ValueError( f"{n_args} parameters expected, got {len(args)}" ) args = sp.sympify(args) expr = UnevaluatedExpression.__new__(cls, *args) if evaluate: return expr.evaluate() return expr decorated_class.__new__ = new_method # type: ignore[assignment] return decorated_class return decorator
[docs]def implement_doit_method( decorated_class: Type[DecoratedClass], ) -> Type[DecoratedClass]: """Implement ``doit()`` method for an `UnevaluatedExpression` class. Implement a :meth:`~sympy.core.basic.Basic.doit` method for a class that derives from `~sympy.core.expr.Expr` (via `UnevaluatedExpression`). A :meth:`~sympy.core.basic.Basic.doit` method is an extension of an :meth:`~.UnevaluatedExpression.evaluate` method in the sense that it can work recursively on deeper expression trees. """ @functools.wraps(decorated_class.doit) # type: ignore[attr-defined] def doit_method(self: UnevaluatedExpression, deep: bool = True) -> sp.Expr: expr = self.evaluate() if deep: return expr.doit() return expr decorated_class.doit = doit_method # type: ignore[attr-defined] return decorated_class
def _implement_latex_subscript( # pyright: reportUnusedFunction=false subscript: str, ) -> Callable[[Type[UnevaluatedExpression]], Type[UnevaluatedExpression]]: def decorator( decorated_class: Type[UnevaluatedExpression], ) -> Type[UnevaluatedExpression]: # pylint: disable=protected-access, unused-argument @functools.wraps(decorated_class.doit) def _latex(self: sp.Expr, printer: LatexPrinter, *args: Any) -> str: momentum = printer._print(self._momentum) if printer._needs_mul_brackets(self._momentum): momentum = Rf"\left({momentum}\right)" else: momentum = Rf"{{{momentum}}}" return f"{momentum}_{subscript}" decorated_class._latex = _latex # type: ignore[assignment] return decorated_class return decorator
[docs]def make_commutative( decorated_class: Type[DecoratedClass], ) -> Type[DecoratedClass]: """Set commutative and 'extended real' assumptions on expression class. .. seealso:: :doc:`sympy:guides/assumptions` """ decorated_class.is_commutative = True # type: ignore[attr-defined] decorated_class.is_extended_real = True # type: ignore[attr-defined] return decorated_class
[docs]def create_expression( cls: Type[UnevaluatedExpression], *args: Any, evaluate: bool = False, name: Optional[str] = None, **kwargs: Any, ) -> sp.Expr: """Helper function for implementing `UnevaluatedExpression.__new__`.""" args = sp.sympify(args) expr = UnevaluatedExpression.__new__(cls, *args, name=name, **kwargs) if evaluate: return expr.evaluate() return expr
[docs]def create_symbol_matrix(name: str, m: int, n: int) -> sp.Matrix: """Create a `~sympy.matrices.dense.Matrix` with symbols as elements. The `~sympy.matrices.expressions.MatrixSymbol` has some issues when one is interested in the elements of the matrix. This function instead creates a `~sympy.matrices.dense.Matrix` where the elements are `~sympy.tensor.indexed.Indexed` instances. To convert these `~sympy.tensor.indexed.Indexed` instances to a `~sympy.core.symbol.Symbol`, use :func:`symplot.substitute_indexed_symbols`. >>> create_symbol_matrix("A", m=2, n=3) Matrix([ [A[0, 0], A[0, 1], A[0, 2]], [A[1, 0], A[1, 1], A[1, 2]]]) """ symbol = sp.IndexedBase(name, shape=(m, n)) return sp.Matrix([[symbol[i, j] for j in range(n)] for i in range(m)])
[docs]@implement_doit_method class PoolSum(UnevaluatedExpression): # pylint: disable=line-too-long r"""Sum over indices where the values are taken from a domain set. >>> i, j, m, n = sp.symbols("i j m n") >>> expr = PoolSum(i**m + j**n, (i, (-1, 0, +1)), (j, (2, 4, 5))) >>> expr PoolSum(i**m + j**n, (i, (-1, 0, 1)), (j, (2, 4, 5))) >>> print(sp.latex(expr)) \sum_{i=-1}^{1} \sum_{j\in\left\{2,4,5\right\}}{i^{m} + j^{n}} >>> expr.doit() 3*(-1)**m + 3*0**m + 3*2**n + 3*4**n + 3*5**n + 3 """ precedence = PRECEDENCE["Mul"] def __new__( cls, expression: sp.Expr, *indices: Tuple[sp.Symbol, Iterable[sp.Float]], **hints: Any, ) -> "PoolSum": converted_indices = [] for idx_symbol, values in indices: values = tuple(values) if len(values) == 0: raise ValueError(f"No values provided for index {idx_symbol}") converted_indices.append((idx_symbol, values)) return create_expression(cls, expression, *converted_indices, **hints) @property def expression(self) -> sp.Expr: return self.args[0] @property def indices(self) -> List[Tuple[sp.Symbol, Tuple[sp.Float, ...]]]: return self.args[1:] @property def free_symbols(self) -> Set[sp.Symbol]: return super().free_symbols - {s for s, _ in self.indices} def evaluate(self) -> sp.Expr: indices = {symbol: tuple(values) for symbol, values in self.indices} return sp.Add( *[ self.expression.subs(zip(indices, combi)) for combi in itertools.product(*indices.values()) ] ) def _latex(self, printer: LatexPrinter, *args: Any) -> str: indices = dict(self.indices) sum_symbols: List[str] = [] for idx, values in indices.items(): sum_symbols.append(_render_sum_symbol(printer, idx, values)) expression = printer._print(self.expression) return R" ".join(sum_symbols) + f"{{{expression}}}"
[docs] def cleanup(self) -> Union[sp.Expr, "PoolSum"]: """Remove redundant summations, like indices with one or no value. >>> x, i = sp.symbols("x i") >>> PoolSum(x**i, (i, [0, 1, 2])).cleanup().doit() x**2 + x + 1 >>> PoolSum(x, (i, [0, 1, 2])).cleanup() x >>> PoolSum(x).cleanup() x >>> PoolSum(x**i, (i, [0])).cleanup() 1 """ substitutions = {} new_indices = [] for idx, values in self.indices: if idx not in self.expression.free_symbols: continue if len(values) == 0: continue if len(values) == 1: substitutions[idx] = values[0] else: new_indices.append((idx, values)) new_expression = self.expression.xreplace(substitutions) if len(new_indices) == 0: return new_expression return PoolSum(new_expression, *new_indices)
def _render_sum_symbol( printer: LatexPrinter, idx: sp.Symbol, values: Sequence[float] ) -> str: if len(values) == 0: return "" idx = printer._print(idx) if len(values) == 1: value = values[0] return Rf"\sum_{{{idx}={value}}}" if _is_regular_series(values): sorted_values = sorted(values) first_value = sorted_values[0] last_value = sorted_values[-1] return Rf"\sum_{{{idx}={first_value}}}^{{{last_value}}}" idx_values = ",".join(map(printer._print, values)) return Rf"\sum_{{{idx}\in\left\{{{idx_values}\right\}}}}" def _is_regular_series(values: Sequence[float]) -> bool: """Check whether a set of values is a series with unit distances. >>> _is_regular_series([0, 1, 2]) True >>> _is_regular_series([-0.5, +0.5]) True >>> _is_regular_series([+0.5, -0.5, 1.5]) True >>> _is_regular_series([-1, +1]) False >>> _is_regular_series([1]) False >>> _is_regular_series([]) False """ if len(values) <= 1: return False sorted_values = sorted(values) for val, next_val in zip(sorted_values, sorted_values[1:]): difference = float(next_val - val) if difference != 1.0: return False return True