Source code for ampform.sympy.math

# cspell:ignore cmath compwa csqrt lambdifier
# pylint: disable=no-member, protected-access, unused-argument, W0223
# https://stackoverflow.com/a/22224042
"""A collection of basic math operations, used in `ampform.dynamics`."""
from __future__ import annotations

from typing import overload

import sympy as sp
from sympy.plotting.experimental_lambdify import Lambdifier
from sympy.printing.numpy import NumPyPrinter
from sympy.printing.printer import Printer
from sympy.printing.pycode import PythonCodePrinter

from . import NumPyPrintable, create_expression, make_commutative


@make_commutative
class ComplexSqrt(NumPyPrintable):
    """Square root that yields positive imaginary values for negative input.

    A special version of :func:`~sympy.functions.elementary.miscellaneous.sqrt`
    that renders nicely as LaTeX and can be used as a handle for lambdify
    printers. See :doc:`compwa-org:report/000`, :doc:`compwa-org:report/001`,
    and :doc:`sympy:modules/printing` for how to implement a custom
    :func:`~sympy.utilities.lambdify.lambdify` printer.
    """

    @overload
    def __new__(cls, x: sp.Number, *args, **kwargs) -> sp.Expr:  # type: ignore[misc]
        ...

    @overload
    def __new__(cls, x: sp.Expr, *args, **kwargs) -> ComplexSqrt:
        ...

    def __new__(cls, x, *args, **kwargs):
        # The argument is sympified first. For anything that is not a
        # `sympy.Number`, the object built by `create_expression` is returned
        # as is; for a `sympy.Number`, its symbolic definition is returned.
        radicand = sp.sympify(x)
        instance = create_expression(cls, radicand, *args, **kwargs)
        if not isinstance(radicand, sp.Number):
            return instance
        return instance.get_definition()

    def _latex(self, printer: Printer, *args) -> str:
        # Rendered as a root sign that carries an upright "c" as its index.
        radicand = printer._print(self.args[0])
        return Rf"\sqrt[\mathrm{{c}}]{{{radicand}}}"

    def _numpycode(self, printer: NumPyPrinter, *args) -> str:
        # NumPy code is obtained by printing the symbolic definition.
        return self.__print_complex(printer)

    def _pythoncode(self, printer: PythonCodePrinter, *args) -> str:
        # The emitted snippet refers to `cmath.sqrt` under the alias `csqrt`,
        # so that import is registered on the printer.
        printer.module_imports["cmath"].add("sqrt as csqrt")
        arg = printer._print(self.args[0])
        fragments = (
            f"(((1j*sqrt(-{arg}))",
            f" if isinstance({arg}, (float, int)) and ({arg} < 0)",
            f" else (csqrt({arg}))))",
        )
        return "".join(fragments)

    def __print_complex(self, printer: Printer) -> str:
        # Delegate to the printer, handing it the symbolic definition.
        return printer._print(self.get_definition())

    def get_definition(self) -> sp.Piecewise:
        """Get a symbolic definition for this expression class.

        The definition is a :class:`~sympy.functions.elementary.piecewise.Piecewise`
        built with the branch ``I * sqrt(-x)`` under the condition ``x < 0``
        and the branch ``sqrt(x)`` otherwise, where ``x`` is the first argument
        of this expression.

        .. note:: This class is `.NumPyPrintable`, so should not have an
            :meth:`~.UnevaluatedExpression.evaluate` method (in order to block
            :meth:`~sympy.core.basic.Basic.doit`). This method serves as an
            equivalent to that.
        """
        radicand: sp.Expr = self.args[0]  # type: ignore[assignment]
        negative_branch = (sp.I * sp.sqrt(-radicand), radicand < 0)
        fallback_branch = (sp.sqrt(radicand), True)
        return sp.Piecewise(negative_branch, fallback_branch)
# Map the name "ComplexSqrt" to "sqrt" in the `builtin_functions_different`
# table of the Lambdifier imported from `sympy.plotting.experimental_lambdify`.
# NOTE(review): presumably this lets SymPy's plotting lambdifier translate
# ComplexSqrt into a plain `sqrt` call -- confirm against the Lambdifier code.
Lambdifier.builtin_functions_different["ComplexSqrt"] = "sqrt"