Source code for ampform.sympy

# cspell:ignore mhash
# pylint: disable=invalid-getnewargs-ex-returned, protected-access, W0223
# https://stackoverflow.com/a/22224042
"""Tools that facilitate building :mod:`sympy` expressions."""
from __future__ import annotations

import functools
import itertools
from abc import abstractmethod
from typing import Callable, Iterable, Sequence, SupportsFloat, TypeVar

import sympy as sp
from sympy.printing.latex import LatexPrinter
from sympy.printing.numpy import NumPyPrinter
from sympy.printing.precedence import PRECEDENCE


class UnevaluatedExpression(sp.Expr):
    """Base class for expression classes with an :meth:`evaluate` method.

    Deriving from `~sympy.core.expr.Expr` lets us keep expression trees
    condensed until they are unfolded with their
    `~sympy.core.basic.Basic.doit` method. This makes it possible to:

    1. condense the LaTeX representation of an expression tree by providing a
       custom :meth:`_latex` method.
    2. overwrite its printer methods (see `NumPyPrintable` and e.g.
       :doc:`compwa-org:report/001`).

    The `UnevaluatedExpression` base class makes implementations of its
    derived classes more secure by requiring the developer to provide
    implementations for these methods, so that SymPy mechanisms work
    correctly. Decorators like :func:`implement_expr` and
    :func:`implement_doit_method` provide convenient means to implement the
    missing methods.

    .. autolink-preface::

        import sympy as sp
        from ampform.sympy import UnevaluatedExpression, create_expression

    .. automethod:: __new__
    .. automethod:: evaluate
    .. automethod:: _latex
    """

    # https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L74-L77
    __slots__: tuple[str] = ("_name",)
    _name: str | None
    """Optional instance attribute that can be used in LaTeX representations."""

    def __new__(  # pylint: disable=unused-argument
        cls: type[DecoratedClass],
        *args,
        name: str | None = None,
        **hints,
    ) -> DecoratedClass:
        """Constructor for a class derived from `UnevaluatedExpression`.

        This :meth:`~object.__new__` method sets the
        `~sympy.core.basic.Basic.args`, assumptions, cached hash and the
        optional ``name`` on the new instance. Overwrite it in order to
        further specify its signature. The function :func:`create_expression`
        can be used in its implementation, like so:

        >>> class MyExpression(UnevaluatedExpression):
        ...    def __new__(
        ...        cls, x: sp.Symbol, y: sp.Symbol, n: int, **hints
        ...    ) -> "MyExpression":
        ...        return create_expression(cls, x, y, n, **hints)
        ...
        ...    def evaluate(self) -> sp.Expr:
        ...        x, y, n = self.args
        ...        return (x + y)**n
        ...
        >>> x, y = sp.symbols("x y")
        >>> expr = MyExpression(x, y, n=3)
        >>> expr
        MyExpression(x, y, 3)
        >>> expr.evaluate()
        (x + y)**3
        """
        # Mirrors the attribute set-up of sympy's own Basic.__new__, see
        # https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L113-L119
        instance = object.__new__(cls)
        instance._args = args
        instance._assumptions = cls.default_assumptions  # type: ignore[attr-defined]
        instance._mhash = None
        instance._name = name
        return instance

    def __getnewargs_ex__(self) -> tuple[tuple, dict]:
        # Pickling support: positional args plus the ``name`` keyword, see
        # https://github.com/sympy/sympy/blob/1.8/sympy/core/basic.py#L124-L126
        return tuple(self.args), {"name": self._name}

    @abstractmethod
    def evaluate(self) -> sp.Expr:
        """Evaluate and 'unfold' this `UnevaluatedExpression` by one level.

        >>> from ampform.dynamics import BreakupMomentumSquared
        >>> issubclass(BreakupMomentumSquared, UnevaluatedExpression)
        True
        >>> s, m1, m2 = sp.symbols("s m1 m2")
        >>> expr = BreakupMomentumSquared(s, m1, m2)
        >>> expr
        BreakupMomentumSquared(s, m1, m2)
        >>> expr.evaluate()
        (s - (m1 - m2)**2)*(s - (m1 + m2)**2)/(4*s)
        >>> expr.doit(deep=False)
        (s - (m1 - m2)**2)*(s - (m1 + m2)**2)/(4*s)

        .. note:: When decorating this class with
            :func:`implement_doit_method`, its :meth:`evaluate` method is
            equivalent to :meth:`~sympy.core.basic.Basic.doit` with
            :code:`deep=False`.
        """

    def _latex(self, printer: LatexPrinter, *args) -> str:
        r"""Provide a mathematical Latex representation for pretty printing.

        The fallback defined here joins the display name (``_name`` if set,
        otherwise the class name) with the tuple of printed arguments.
        Derived classes typically override it:

        >>> from ampform.dynamics import BreakupMomentumSquared
        >>> issubclass(BreakupMomentumSquared, UnevaluatedExpression)
        True
        >>> s, m1 = sp.symbols("s m1")
        >>> expr = BreakupMomentumSquared(s, m1, m1)
        >>> print(sp.latex(expr))
        q^2\left(s\right)
        >>> print(sp.latex(expr.doit()))
        - m_{1}^{2} + \frac{s}{4}
        """
        printed_args = tuple(map(printer._print, self.args))
        label = type(self).__name__ if self._name is None else self._name
        # NOTE(review): this interpolates the *tuple* of strings, so the
        # output carries Python tuple formatting -- confirm this is intended.
        return f"{label}{printed_args}"
class NumPyPrintable(sp.Expr):
    r"""`~sympy.core.expr.Expr` class that can lambdify to NumPy code.

    This interface for classes that derive from `sympy.Expr
    <sympy.core.expr.Expr>` enforces the implementation of a
    :meth:`_numpycode` method in case the class does not correctly
    :func:`~sympy.utilities.lambdify.lambdify` to NumPy code. For more info on
    SymPy printers, see :doc:`sympy:modules/printing`.

    Several computational frameworks try to converge their interface to that
    of NumPy. See for instance `TensorFlow's NumPy API
    <https://www.tensorflow.org/guide/tf_numpy>`_ and `jax.numpy
    <https://jax.readthedocs.io/en/latest/jax.numpy.html>`_. This fact is used
    in `TensorWaves <https://tensorwaves.rtfd.io>`_ to
    :func:`~sympy.utilities.lambdify.lambdify` SymPy expressions to these
    different backends with the same lambdification code.

    .. note:: This interface differs from `UnevaluatedExpression` in that it
        **should not** implement an :meth:`.evaluate` (and therefore a
        :meth:`~sympy.core.basic.Basic.doit`) method.

    .. warning:: The implemented :meth:`_numpycode` method should contain as
        few SymPy computations as possible. Instead, it should get most
        information from its construction `~sympy.core.basic.Basic.args`, so
        that SymPy can use printer tricks like
        :func:`~sympy.simplify.cse_main.cse`, prior to expanding with
        :meth:`~sympy.core.basic.Basic.doit`, and other simplifications that
        can make the generated code shorter. An example is the `.BoostZMatrix`
        class, which takes :math:`\beta` as input instead of the
        `.FourMomentumSymbol` from which :math:`\beta` is computed.

    .. automethod:: _numpycode
    """

    @abstractmethod
    def _numpycode(self, printer: NumPyPrinter, *args) -> str:
        """Lambdify this `NumPyPrintable` class to NumPy code."""
# Type variable bound to `UnevaluatedExpression`; the class decorators in this
# module use it to express that they return the same class they were given.
DecoratedClass = TypeVar("DecoratedClass", bound=UnevaluatedExpression)
"""`~typing.TypeVar` for decorators like :func:`implement_doit_method`."""
def implement_expr(
    n_args: int,
) -> Callable[[type[DecoratedClass]], type[DecoratedClass]]:
    """Decorator for classes that derive from `UnevaluatedExpression`.

    Combines :func:`implement_new_method` (with ``n_args`` expected positional
    arguments) and :func:`implement_doit_method`, so that the decorated class
    gets both a :meth:`~object.__new__` and a
    :meth:`~sympy.core.basic.Basic.doit` method.
    """

    def decorator(
        decorated_class: type[DecoratedClass],
    ) -> type[DecoratedClass]:
        # First attach __new__, then doit -- same order as applying the two
        # decorators one after the other.
        with_new = implement_new_method(n_args)(decorated_class)
        return implement_doit_method(with_new)

    return decorator
def implement_new_method(
    n_args: int,
) -> Callable[[type[DecoratedClass]], type[DecoratedClass]]:
    """Implement :meth:`UnevaluatedExpression.__new__` on a derived class.

    Implement a :meth:`~object.__new__` method for a class that derives from
    `~sympy.core.expr.Expr` (via `UnevaluatedExpression`).

    The generated constructor

    - requires exactly ``n_args`` positional arguments and raises
      `ValueError` otherwise,
    - sympifies the positional arguments,
    - forwards all remaining keyword arguments (such as ``name``) to
      :meth:`UnevaluatedExpression.__new__`,
    - returns the result of ``evaluate()`` instead of the unevaluated
      instance if ``evaluate=True`` is given.
    """

    def decorator(
        decorated_class: type[DecoratedClass],
    ) -> type[DecoratedClass]:
        def new_method(
            cls: type[DecoratedClass],
            *args: sp.Symbol,
            evaluate: bool = False,
            **hints,
        ) -> DecoratedClass:
            if len(args) != n_args:
                raise ValueError(
                    f"{n_args} parameters expected, got {len(args)}"
                )
            args = sp.sympify(args)
            # Forward the keyword arguments instead of dropping them:
            # __getnewargs_ex__ hands ``name`` back as a keyword when an
            # instance is unpickled, and create_expression forwards it too.
            expr = UnevaluatedExpression.__new__(cls, *args, **hints)
            if evaluate:
                return expr.evaluate()  # type: ignore[return-value]
            return expr

        decorated_class.__new__ = new_method  # type: ignore[assignment]
        return decorated_class

    return decorator
def implement_doit_method(
    decorated_class: type[DecoratedClass],
) -> type[DecoratedClass]:
    """Implement ``doit()`` method for an `UnevaluatedExpression` class.

    Implement a :meth:`~sympy.core.basic.Basic.doit` method for a class that
    derives from `~sympy.core.expr.Expr` (via `UnevaluatedExpression`). A
    :meth:`~sympy.core.basic.Basic.doit` method is an extension of an
    :meth:`~.UnevaluatedExpression.evaluate` method in the sense that it can
    work recursively on deeper expression trees.

    The generated method calls ``self.evaluate()``. With ``deep=True`` (the
    default) it then calls ``doit`` on the result, forwarding any additional
    keyword hints; with ``deep=False`` the evaluated expression is returned
    as is. The decorated class itself is returned.
    """

    @functools.wraps(decorated_class.doit)  # type: ignore[attr-defined]
    def doit_method(
        self: UnevaluatedExpression, deep: bool = True, **hints
    ) -> sp.Expr:
        expr = self.evaluate()
        if deep:
            # Accept and pass on extra hints, so that a call such as
            # ``doit(deep=True, some_hint=...)`` does not fail with a
            # TypeError and the hints reach the nested expressions.
            return expr.doit(**hints)
        return expr

    decorated_class.doit = doit_method  # type: ignore[assignment]
    return decorated_class
def _implement_latex_subscript(  # pyright: reportUnusedFunction=false
    subscript: str,
) -> Callable[[type[UnevaluatedExpression]], type[UnevaluatedExpression]]:
    """Create a class decorator that installs a subscripted ``_latex`` method.

    The installed ``_latex`` prints the ``_momentum`` attribute of the
    instance, wraps it in ``\\left(...\\right)`` if the printer reports that
    it needs multiplication brackets (otherwise in plain braces), and appends
    ``_<subscript>``.
    """

    def decorator(
        decorated_class: type[UnevaluatedExpression],
    ) -> type[UnevaluatedExpression]:
        # pylint: disable=protected-access, unused-argument
        # Copy the metadata of the method that is being replaced (``_latex``),
        # not that of ``doit``: otherwise the new method would carry doit's
        # name, docstring and ``__wrapped__`` reference.
        @functools.wraps(decorated_class._latex)
        def _latex(self: sp.Expr, printer: LatexPrinter, *args) -> str:
            momentum = printer._print(self._momentum)  # type: ignore[attr-defined]
            if printer._needs_mul_brackets(self._momentum):  # type: ignore[attr-defined]
                momentum = Rf"\left({momentum}\right)"
            else:
                momentum = Rf"{{{momentum}}}"
            return f"{momentum}_{subscript}"

        decorated_class._latex = _latex  # type: ignore[assignment]
        return decorated_class

    return decorator


DecoratedExpr = TypeVar("DecoratedExpr", bound=sp.Expr)
"""`~typing.TypeVar` for decorators like :func:`make_commutative`."""
def make_commutative(
    decorated_class: type[DecoratedExpr],
) -> type[DecoratedExpr]:
    """Set commutative and 'extended real' assumptions on expression class.

    Both ``is_commutative`` and ``is_extended_real`` are set to `True` as
    class attributes; the same class object is returned.

    .. seealso:: :doc:`sympy:guides/assumptions`
    """
    for assumption in ("is_commutative", "is_extended_real"):
        setattr(decorated_class, assumption, True)
    return decorated_class
def create_expression(
    cls: type[DecoratedExpr],
    *args,
    evaluate: bool = False,
    name: str | None = None,
    **kwargs,
) -> DecoratedExpr:
    """Helper function for implementing `UnevaluatedExpression.__new__`.

    The positional arguments are sympified first. For subclasses of
    `UnevaluatedExpression`, the instance is built through
    `UnevaluatedExpression.__new__` (passing on ``name`` and ``kwargs``) and,
    if ``evaluate`` is true, its ``evaluate()`` result is returned instead.
    Any other class is constructed through ``sp.Expr.__new__`` with the
    sympified arguments and ``kwargs``.
    """
    sympified_args = sp.sympify(args)
    if not issubclass(cls, UnevaluatedExpression):
        return sp.Expr.__new__(cls, *sympified_args, **kwargs)  # type: ignore[return-value]
    instance = UnevaluatedExpression.__new__(
        cls, *sympified_args, name=name, **kwargs
    )
    return instance.evaluate() if evaluate else instance  # type: ignore[return-value]
def create_symbol_matrix(name: str, m: int, n: int) -> sp.MutableDenseMatrix:
    """Create a `~sympy.matrices.dense.Matrix` with symbols as elements.

    The `~sympy.matrices.expressions.MatrixSymbol` has some issues when one is
    interested in the elements of the matrix. This function instead creates a
    `~sympy.matrices.dense.Matrix` where the elements are
    `~sympy.tensor.indexed.Indexed` instances.

    To convert these `~sympy.tensor.indexed.Indexed` instances to a
    `~sympy.core.symbol.Symbol`, use
    :func:`symplot.substitute_indexed_symbols`.

    >>> create_symbol_matrix("A", m=2, n=3)
    Matrix([
    [A[0, 0], A[0, 1], A[0, 2]],
    [A[1, 0], A[1, 1], A[1, 2]]])
    """
    base = sp.IndexedBase(name, shape=(m, n))
    rows = []
    for i in range(m):
        rows.append([base[i, j] for j in range(n)])
    return sp.Matrix(rows)
@implement_doit_method
class PoolSum(UnevaluatedExpression):
    r"""Sum over indices where the values are taken from a domain set.

    The first argument is the summed expression; every following argument is
    a pair of an index symbol and the values that index runs over.

    >>> i, j, m, n = sp.symbols("i j m n")
    >>> expr = PoolSum(i**m + j**n, (i, (-1, 0, +1)), (j, (2, 4, 5)))
    >>> expr
    PoolSum(i**m + j**n, (i, (-1, 0, 1)), (j, (2, 4, 5)))
    >>> print(sp.latex(expr))
    \sum_{i=-1}^{1} \sum_{j\in\left\{2,4,5\right\}}{i^{m} + j^{n}}
    >>> expr.doit()
    3*(-1)**m + 3*0**m + 3*2**n + 3*4**n + 3*5**n + 3
    """

    precedence = PRECEDENCE["Mul"]

    def __new__(
        cls,
        expression,
        *indices: tuple[sp.Symbol, Iterable[sp.Basic]],
        **hints,
    ) -> PoolSum:
        converted_indices = []
        for idx_symbol, values in indices:
            # Materialize the values so that an empty domain can be rejected.
            domain = tuple(values)
            if not domain:
                raise ValueError(f"No values provided for index {idx_symbol}")
            converted_indices.append((idx_symbol, domain))
        return create_expression(cls, expression, *converted_indices, **hints)

    @property
    def expression(self) -> sp.Expr:
        return self.args[0]  # type: ignore[return-value]

    @property
    def indices(self) -> list[tuple[sp.Symbol, tuple[sp.Float, ...]]]:
        return self.args[1:]  # type: ignore[return-value]

    @property
    def free_symbols(self) -> set[sp.Basic]:
        # The summation indices are bound by the sum, so they are not free.
        all_symbols = super().free_symbols
        bound_symbols = {symbol for symbol, _ in self.indices}
        return all_symbols - bound_symbols

    def evaluate(self) -> sp.Expr:
        domains = {symbol: tuple(values) for symbol, values in self.indices}
        terms = []
        # One term per combination of index values.
        for combination in itertools.product(*domains.values()):
            terms.append(self.expression.subs(zip(domains, combination)))
        return sp.Add(*terms)

    def _latex(self, printer: LatexPrinter, *args) -> str:
        sum_symbols: list[str] = [
            _render_sum_symbol(printer, idx, values)
            for idx, values in dict(self.indices).items()
        ]
        summand = printer._print(self.expression)
        return R" ".join(sum_symbols) + f"{{{summand}}}"

    def cleanup(self) -> sp.Expr | PoolSum:
        """Remove redundant summations, like indices with one or no value.

        Indices that do not appear in the expression are dropped, indices
        with a single value are substituted into the expression, and the
        bare expression is returned if no index remains.

        >>> x, i = sp.symbols("x i")
        >>> PoolSum(x**i, (i, [0, 1, 2])).cleanup().doit()
        x**2 + x + 1
        >>> PoolSum(x, (i, [0, 1, 2])).cleanup()
        x
        >>> PoolSum(x).cleanup()
        x
        >>> PoolSum(x**i, (i, [0])).cleanup()
        1
        """
        single_valued = {}
        remaining = []
        for idx, values in self.indices:
            if idx not in self.expression.free_symbols:
                continue
            n_values = len(values)
            if n_values == 1:
                single_valued[idx] = values[0]
            elif n_values > 1:
                remaining.append((idx, values))
        reduced = self.expression.xreplace(single_valued)
        if remaining:
            return PoolSum(reduced, *remaining)
        return reduced
def _render_sum_symbol(
    printer: LatexPrinter, idx: sp.Symbol, values: Sequence[SupportsFloat]
) -> str:
    """Render the LaTeX summation sign for one index and its values.

    Returns an empty string for no values, ``\\sum_{idx=value}`` for a single
    value, a lower/upper-limit form for a series with unit steps, and a set
    notation listing all (printed) values otherwise.
    """
    n_values = len(values)
    if n_values == 0:
        return ""
    idx = printer._print(idx)
    if n_values == 1:
        return Rf"\sum_{{{idx}={values[0]}}}"
    if _is_regular_series(values):
        # A regular series has at least two entries, so both ends exist.
        lower, *_, upper = sorted(values, key=float)
        return Rf"\sum_{{{idx}={lower}}}^{{{upper}}}"
    listed = ",".join(map(printer._print, values))
    return Rf"\sum_{{{idx}\in\left\{{{listed}\right\}}}}"


def _is_regular_series(values: Sequence[SupportsFloat]) -> bool:
    """Check whether a set of values is a series with unit distances.

    The values are ordered by their float value first; fewer than two values
    never form a series.

    >>> _is_regular_series([0, 1, 2])
    True
    >>> _is_regular_series([-0.5, +0.5])
    True
    >>> _is_regular_series([+0.5, -0.5, 1.5])
    True
    >>> _is_regular_series([-1, +1])
    False
    >>> _is_regular_series([1])
    False
    >>> _is_regular_series([])
    False
    """
    if len(values) <= 1:
        return False
    ordered = sorted(values, key=float)
    return all(
        float(upper) - float(lower) == 1.0
        for lower, upper in zip(ordered, ordered[1:])
    )