{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hideCode": true,
    "hideOutput": true,
    "hidePrompt": true,
    "jupyter": {
     "source_hidden": true
    },
    "slideshow": {
     "slide_type": "skip"
    },
    "tags": [
     "remove-cell",
     "skip-execution"
    ]
   },
   "outputs": [],
   "source": [
    "# WARNING: advised to install a specific version, e.g. ampform==0.1.2\n",
    "%pip install -q ampform[doc,viz] IPython"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hideCode": true,
    "hideOutput": true,
    "hidePrompt": true,
    "jupyter": {
     "source_hidden": true
    },
    "slideshow": {
     "slide_type": "skip"
    },
    "tags": [
     "remove-cell"
    ]
   },
   "outputs": [],
   "source": [
    "%config InlineBackend.figure_formats = ['svg']\n",
    "import os\n",
    "\n",
    "from IPython.display import display  # noqa: F401\n",
    "\n",
    "STATIC_WEB_PAGE = {\"EXECUTE_NB\", \"READTHEDOCS\"}.intersection(os.environ)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{autolink-concat}\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "skip"
    },
    "tags": []
   },
   "source": [
    "# Inspect model interactively"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we illustrate how to interactively inspect a {class}`.HelicityModel`. The procedure should in fact work for any {class}`sympy.Expr <sympy.core.expr.Expr>`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create amplitude model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we create a {class}`.HelicityModel`. We could also have used {mod}`pickle` to {func}`~pickle.load` the {class}`.HelicityModel` that we created in {doc}`/usage/amplitude`, but the cell below allows this notebook to run independently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import qrules\n",
    "\n",
    "from ampform import get_builder\n",
    "from ampform.dynamics.builder import (\n",
    "    create_non_dynamic_with_ff,\n",
    "    create_relativistic_breit_wigner_with_ff,\n",
    ")\n",
    "\n",
    "reaction = qrules.generate_transitions(\n",
    "    initial_state=(\"J/psi(1S)\", [-1, +1]),\n",
    "    final_state=[\"gamma\", \"pi0\", \"pi0\"],\n",
    "    allowed_intermediate_particles=[\"f(0)(980)\", \"f(0)(1500)\"],\n",
    "    allowed_interaction_types=[\"strong\", \"EM\"],\n",
    "    formalism=\"canonical-helicity\",\n",
    ")\n",
    "model_builder = get_builder(reaction)\n",
    "initial_state_particle = reaction.initial_state[-1]\n",
    "model_builder.set_dynamics(initial_state_particle, create_non_dynamic_with_ff)\n",
    "for name in reaction.get_intermediate_particles().names:\n",
    "    model_builder.set_dynamics(name, create_relativistic_breit_wigner_with_ff)\n",
    "model = model_builder.formulate()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this case, {ref}`as we saw <usage/amplitude:Mathematical formula>`, the overall model contains just one intensity term $I = |\\sum_i A_i|^2$, where $\\sum_i A_i$ is a coherent sum of amplitudes. We can extract $\\sum_i A_i$ as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sympy as sp\n",
    "\n",
    "amplitude = model.expression.args[0].args[0].args[0]\n",
    "assert isinstance(amplitude, sp.Add)"
   ]
  },
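  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Substitute the parameters that we do not want to vary, that is, everything except the masses and widths (symbols whose names start with `m_` or `Gamma_`), with the values in the provided {attr}`~.HelicityModel.parameter_defaults`:"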
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "substitutions = {\n",
    "    symbol: value\n",
    "    for symbol, value in model.parameter_defaults.items()\n",
    "    if not symbol.name.startswith(\"Gamma_\")\n",
    "    and not symbol.name.startswith(\"m_\")\n",
    "}\n",
    "amplitude = amplitude.doit().subs(substitutions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Lambdify"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now need to identify the {class}`~sympy.core.symbol.Symbol` over which the amplitude is to be plotted. The remaining symbols will be turned into slider parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_symbol = sp.Symbol(\"m_12\", real=True)\n",
    "slider_symbols = sorted(amplitude.free_symbols, key=lambda s: s.name)\n",
    "slider_symbols.remove(plot_symbol)\n",
    "slider_symbols"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, {func}`~sympy.utilities.lambdify.lambdify` the expression:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np_amplitude = sp.lambdify(\n",
    "    (plot_symbol, *slider_symbols),\n",
    "    amplitude,\n",
    "    \"numpy\",\n",
    ")"
   ]
  },
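  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We also have to define functions that compute what we want to plot. The amplitude itself won't do, because it is complex-valued and we can only plot real values. We therefore plot the intensity $|A|^2$ and, for the Argand plot, the real and imaginary parts of the amplitude:"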
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ":::{margin}\n",
    "\n",
    "See {doc}`mpl_interactions:examples/plot` and {doc}`mpl_interactions:examples/scatter` for why these functions are constructed this way.\n",
    "\n",
    ":::"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def intensity(plot_variable, **kwargs):\n",
    "    values = np_amplitude(plot_variable, **kwargs)\n",
    "    return np.abs(values) ** 2\n",
    "\n",
    "\n",
    "def argand(**kwargs):\n",
    "    # plot_domain is a global array that is defined in the next section\n",
    "    values = np_amplitude(plot_domain, **kwargs)\n",
    "    points = np.array([values.real, values.imag])\n",
    "    return points.T  # shape (n_points, 2): columns are Re(A) and Im(A)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare sliders"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ":::{tip}\n",
    "\n",
    "This procedure has been extracted to the façade function {func}`symplot.prepare_sliders`.\n",
    "\n",
    ":::"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We also need to define the domain over which to plot, as well as sliders for each of the remaining symbols. The function {func}`.create_slider` helps create an [ipywidgets slider](https://ipywidgets.readthedocs.io/en/stable/examples/Widget%20List.html) for each {class}`~sympy.core.symbol.Symbol`:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    ":::{margin}\n",
    "\n",
    "{doc}`/usage/symplot` is not a published module, just a set of helper functions and classes provided for this documentation. Since the procedure sketched on this page is quite general, this module could be published as a separate interactive plotting package for {mod}`sympy` later on.\n",
    "\n",
    ":::"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from symplot import create_slider\n",
    "\n",
    "plot_domain = np.linspace(0.2, 2.5, num=400)\n",
    "sliders_mapping = {\n",
    "    symbol.name: create_slider(symbol) for symbol in slider_symbols\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the name of a {class}`~sympy.core.symbol.Symbol` is not a valid name for a Python variable (see {meth}`str.isidentifier`), {func}`~sympy.utilities.lambdify.lambdify` 'dummifies' it, so that it can be used as an argument of the lambdified function. Since {func}`~mpl_interactions.pyplot.interactive_plot` uses these (dummified) argument names as identifiers for the sliders, we need a mapping between the dummified argument names and the original symbol names. This can be constructed with the help of {func}`inspect.signature`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import inspect\n",
    "\n",
    "symbols_names = [s.name for s in (plot_symbol, *slider_symbols)]\n",
    "arg_names = list(inspect.signature(np_amplitude).parameters)\n",
    "arg_to_symbol = dict(zip(arg_names, symbols_names))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The class {class}`.SliderKwargs` comes in as a handy manager for this {obj}`dict` of sliders:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "full-width"
    ]
   },
   "outputs": [],
   "source": [
    "from symplot import SliderKwargs\n",
    "\n",
    "sliders = SliderKwargs(sliders_mapping, arg_to_symbol)\n",
    "sliders"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "{class}`.SliderKwargs` also provides convenient methods for setting ranges and initial values for the sliders."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_steps = 100\n",
    "sliders.set_ranges(\n",
    "    {\n",
    "        \"m_f(0)(980)\": (0.3, 1.8, n_steps),\n",
    "        \"m_f(0)(1500)\": (0.3, 1.8, n_steps),\n",
    "        \"Gamma_f(0)(980)\": (0.01, 1, n_steps),\n",
    "        \"Gamma_f(0)(1500)\": (0.01, 1, n_steps),\n",
    "        \"m_1\": (0.01, 1, n_steps),\n",
    "        \"m_2\": (0.01, 1, n_steps),\n",
    "        \"phi_1+2\": (0, 2 * np.pi, 40),\n",
    "        \"theta_1+2\": (-np.pi, np.pi, 40),\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The method {meth}`~.SliderKwargs.set_values` is designed in particular for {attr}`~.HelicityModel.parameter_defaults`, but also works well with both argument names (that may have been dummified) and the original symbol names:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import qrules\n",
    "\n",
    "pdg = qrules.load_pdg()\n",
    "sliders.set_values(model.parameter_defaults)\n",
    "sliders.set_values(\n",
    "    {  # symbol names\n",
    "        \"phi_1+2\": 0,\n",
    "        \"theta_1+2\": np.pi / 2,\n",
    "    },\n",
    "    # argument names\n",
    "    m_1=pdg[\"pi0\"].mass,\n",
    "    m_2=pdg[\"pi0\"].mass,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    },
    "tags": [
     "remove-cell"
    ]
   },
   "outputs": [],
   "source": [
    "if STATIC_WEB_PAGE:\n",
    "    # Concatenate flipped domain for reverse animation\n",
    "    domain = np.linspace(0.8, 2.2, 100)\n",
    "    domain = np.concatenate((domain, np.flip(domain[1:])))\n",
    "    sliders._sliders[\"m_f(0)(980)\"] = domain"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Interactive Argand plot"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we can use {doc}`mpl-interactions <mpl_interactions:index>` to plot the functions defined above as a function of the parameter values:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    ":::{margin}\n",
    "\n",
    "Interactive {mod}`~matplotlib.widgets` do not render well on web pages, so run this notebook on Binder or locally in Jupyter Lab to see the full potential of {doc}`mpl-interactions <mpl_interactions:index>`!\n",
    "\n",
    ":::"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    },
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "%matplotlib widget\n",
    "import matplotlib.pyplot as plt\n",
    "import mpl_interactions.ipyplot as iplt\n",
    "\n",
    "# Create figure\n",
    "fig, axes = plt.subplots(\n",
    "    1, 2, figsize=0.9 * np.array((8, 3.8)), tight_layout=True\n",
    ")\n",
    "fig.canvas.toolbar_visible = False\n",
    "fig.canvas.header_visible = False\n",
    "fig.canvas.footer_visible = False\n",
    "fig.suptitle(R\"$J/\\psi \\to \\gamma f_0, f_0 \\to \\pi^0\\pi^0$\")\n",
    "ax_intensity, ax_argand = axes\n",
    "m_label = R\"$m_{\\pi^0\\pi^0}$ (GeV)\"\n",
    "ax_intensity.set_xlabel(m_label)\n",
    "ax_intensity.set_ylabel(\"$I = |A|^2$\")\n",
    "ax_argand.set_xlabel(\"Re($A$)\")\n",
    "ax_argand.set_ylabel(\"Im($A$)\")\n",
    "\n",
    "# Fill plots\n",
    "controls = iplt.plot(\n",
    "    plot_domain,\n",
    "    intensity,\n",
    "    **sliders,\n",
    "    slider_formats={slider: \"{:.3f}\" for slider in arg_names},\n",
    "    ylim=\"auto\",\n",
    "    ax=ax_intensity,\n",
    ")\n",
    "iplt.scatter(\n",
    "    argand,\n",
    "    controls=controls,\n",
    "    xlim=\"auto\",\n",
    "    ylim=\"auto\",\n",
    "    parametric=True,\n",
    "    c=plot_domain,\n",
    "    s=1,\n",
    "    ax=ax_argand,\n",
    ")\n",
    "plt.colorbar(label=m_label, ax=ax_argand, aspect=30, pad=0.01)\n",
    "plt.winter()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    ":::{margin}\n",
    "\n",
    "This figure is an animation over **just one of the parameters**. Run the notebook itself to play around with all parameters!\n",
    "\n",
    ":::"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    },
    "tags": [
     "remove-input"
    ]
   },
   "outputs": [],
   "source": [
    "# Export for Read the Docs\n",
    "if STATIC_WEB_PAGE:\n",
    "    from IPython.display import Image\n",
    "\n",
    "    output_path = \"animation.gif\"\n",
    "    symbol_to_arg = dict(zip(symbols_names, arg_names))\n",
    "    arg_name = symbol_to_arg[\"m_f(0)(980)\"]\n",
    "    controls.save_animation(output_path, fig, arg_name, fps=25)\n",
    "    with open(output_path, \"rb\") as f:\n",
    "        display(Image(data=f.read(), format=\"png\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ":::{tip}\n",
    "\n",
    "See {doc}`/usage/dynamics/k-matrix` for why $\\boldsymbol{K}$-matrix dynamics are better than simple Breit-Wigners when resonances are close to each other.\n",
    "\n",
    ":::"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{toctree}\n",
    ":hidden:\n",
    "symplot\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
