{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hideCode": true,
    "hideOutput": true,
    "hidePrompt": true,
    "jupyter": {
     "source_hidden": true
    },
    "slideshow": {
     "slide_type": "skip"
    },
    "tags": [
     "remove-cell",
     "skip-execution"
    ]
   },
   "outputs": [],
   "source": [
    "# WARNING: advised to install a specific version, e.g. ampform==0.1.2\n",
    "%pip install -q ampform[doc,viz] IPython"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "hideCode": true,
    "hideOutput": true,
    "hidePrompt": true,
    "jupyter": {
     "source_hidden": true
    },
    "slideshow": {
     "slide_type": "skip"
    },
    "tags": [
     "remove-cell"
    ]
   },
   "outputs": [],
   "source": [
    "%config InlineBackend.figure_formats = ['svg']\n",
    "import os\n",
    "\n",
    "from IPython.display import display  # noqa: F401\n",
    "\n",
    "STATIC_WEB_PAGE = {\"EXECUTE_NB\", \"READTHEDOCS\"}.intersection(os.environ)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{autolink-concat}\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Modify amplitude model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    },
    "tags": [
     "hide-cell"
    ]
   },
   "outputs": [],
   "source": [
    "import attr\n",
    "import graphviz\n",
    "import qrules\n",
    "import sympy as sp\n",
    "\n",
    "from ampform import get_builder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since a {attr}`.HelicityModel.expression` is simply a {class}`sympy.Expr <sympy.core.expr.Expr>`, it's relatively easy to modify it. The {class}`.HelicityModel` however also contains other attributes that need to be modified accordingly. In this notebook, we show how to do that for specific use cases using the following example decay:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "remove-output"
    ]
   },
   "outputs": [],
   "source": [
    "result = qrules.generate_transitions(\n",
    "    initial_state=(\"J/psi(1S)\", [-1, +1]),\n",
    "    final_state=[\"gamma\", \"pi0\", \"pi0\"],\n",
    "    allowed_intermediate_particles=[\"f(0)(980)\", \"f(0)(1500)\"],\n",
    "    allowed_interaction_types=[\"strong\", \"EM\"],\n",
    "    formalism=\"helicity\",\n",
    ")\n",
    "model_builder = get_builder(result)\n",
    "original_model = model_builder.formulate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    },
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "dot = qrules.io.asdot(result, collapse_graphs=True)\n",
    "graphviz.Source(dot)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Couple parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can couple parameters renaming them:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "full-width"
    ]
   },
   "outputs": [],
   "source": [
    "renames = {\n",
    "    R\"C_{J/\\psi(1S) \\to {f_{0}(980)}_{0} \\gamma_{+1}; f_{0}(980) \\to \\pi^{0}_{0} \\pi^{0}_{0}}\": (\n",
    "        \"C\"\n",
    "    ),\n",
    "    R\"C_{J/\\psi(1S) \\to {f_{0}(1500)}_{0} \\gamma_{+1}; f_{0}(1500) \\to \\pi^{0}_{0} \\pi^{0}_{0}}\": (\n",
    "        \"C\"\n",
    "    ),\n",
    "}\n",
    "new_model = original_model.rename_symbols(renames)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "new_model.parameter_defaults"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "full-width"
    ]
   },
   "outputs": [],
   "source": [
    "new_model.components[\n",
    "    R\"I_{J/\\psi(1S)_{+1} \\to \\gamma_{+1} \\pi^{0}_{0} \\pi^{0}_{0}}\"\n",
    "]"
   ]
  },
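  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Schematically, the renaming above couples the two resonance contributions. A sketch, where $C_{980}$ and $C_{1500}$ abbreviate the two original coefficients and $A_{980}$, $A_{1500}$ are shorthand (introduced here for illustration only) for the remaining factors of the two amplitudes that are summed coherently:\n",
    "\n",
    "$$\n",
    "C_{980} A_{980} + C_{1500} A_{1500} \\;\\to\\; C \\left(A_{980} + A_{1500}\\right)\n",
    "$$\n",
    "\n",
    "Both contributions now scale with the single parameter $C$, so their relative magnitude and phase can no longer be varied independently through the coefficients."
   ]
  },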
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parameter substitution"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's say we want to express all coefficients as a product $Ce^{i\\phi}$ of magnitude $C$  with phase $\\phi$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "full-width"
    ]
   },
   "outputs": [],
   "source": [
    "original_coefficients = [\n",
    "    par\n",
    "    for par in original_model.parameter_defaults\n",
    "    if par.name.startswith(\"C\")\n",
    "]\n",
    "original_coefficients"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{margin}\n",
    "The attributes {attr}`~.HelicityModel.parameter_defaults` and {attr}`~.HelicityModel.components` are _mutable_ {obj}`dict`s, so these can be modified (even if not set as a whole). This is why we make a copy of them below.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are two things to note now:\n",
    "\n",
    "1. These parameters appear in {attr}`.HelicityModel.expression`, its {attr}`~.HelicityModel.parameter_defaults`, and its  {attr}`~.HelicityModel.components`, so both these attributes should be modified accordingly.\n",
    "2. A {class}`.HelicityModel` is {doc}`immutable <attrs:how-does-it-work>`, so we cannot directly replace its attributes. Instead, we should create a new {class}`.HelicityModel` with substituted attributes using {func}`attrs.evolve`:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following snippet shows how to do all this. It's shown in full, because it could well be you want to perform some completely different substitutions (can be any kinds of {meth}`~sympy.core.basic.Basic.subs`). The overall procedure is comparable, however."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_intensity = original_model.intensity\n",
    "new_amplitudes = dict(original_model.amplitudes)\n",
    "new_parameter_defaults = dict(original_model.parameter_defaults)  # copy!\n",
    "new_components = dict(original_model.components)  # copy!\n",
    "\n",
    "for coefficient in original_coefficients:\n",
    "    decay_description = coefficient.name[3:-1]\n",
    "    magnitude = sp.Symbol(  # coefficient with same name, but real, not complex\n",
    "        coefficient.name,\n",
    "        real=True,\n",
    "        positive=True,\n",
    "    )\n",
    "    phase = sp.Symbol(\n",
    "        Rf\"\\phi_{{{decay_description}}}\",\n",
    "        real=True,\n",
    "    )\n",
    "    replacement = magnitude * sp.exp(sp.I * phase)\n",
    "    display(replacement)\n",
    "    # replace parameter defaults\n",
    "    del new_parameter_defaults[coefficient]\n",
    "    new_parameter_defaults[magnitude] = 1.0\n",
    "    new_parameter_defaults[phase] = 0.0\n",
    "    # replace parameters in expression\n",
    "    new_intensity = new_intensity.subs(\n",
    "        coefficient, replacement, simultaneous=True\n",
    "    )\n",
    "    # replace parameters in each component\n",
    "    new_amplitudes = {\n",
    "        key: old_expression.subs(coefficient, replacement, simultaneous=True)\n",
    "        for key, old_expression in new_amplitudes.items()\n",
    "    }\n",
    "    new_components = {\n",
    "        key: old_expression.subs(coefficient, replacement, simultaneous=True)\n",
    "        for key, old_expression in new_components.items()\n",
    "    }\n",
    "\n",
    "# create new model from the old\n",
    "new_model = attr.evolve(\n",
    "    original_model,\n",
    "    intensity=new_intensity,\n",
    "    amplitudes=new_amplitudes,\n",
    "    parameter_defaults=new_parameter_defaults,\n",
    "    components=new_components,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    },
    "tags": [
     "hide-cell"
    ]
   },
   "outputs": [],
   "source": [
    "assert new_model != original_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As can be seen, the {attr}`~.HelicityModel.parameter_defaults` have bene updated, as have the {attr}`~.HelicityModel.components`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "full-width"
    ]
   },
   "outputs": [],
   "source": [
    "dict(new_model.parameter_defaults)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "full-width"
    ]
   },
   "outputs": [],
   "source": [
    "new_model.components[\n",
    "    R\"A_{J/\\psi(1S)_{-1} \\to f_{0}(980)_{0} \\gamma_{-1}; f_{0}(980)_{0} \\to\"\n",
    "    R\" \\pi^{0}_{0} \\pi^{0}_{0}}\"\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Also note that the new model reduces to the old once we replace the parameters with their suggested default values:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluated_expr = new_model.expression.subs(new_model.parameter_defaults).doit()\n",
    "evaluated_expr"
   ]
  },
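  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a single coefficient, this reduction can be checked by hand. A sketch, assuming that the original complex coefficients default to $1$ (an assumption, not shown explicitly in this notebook): inserting the defaults $C=1$ and $\\phi=0$ chosen in the substitution loop gives\n",
    "\n",
    "$$\n",
    "C e^{i\\phi} \\,\\big|_{C=1,\\,\\phi=0} = 1 \\cdot e^{0} = 1\n",
    "$$\n",
    "\n",
    "so, under that assumption, each substituted coefficient evaluates to the same value as the original one. For general values, $\\left|C e^{i\\phi}\\right| = C$ because $C$ is declared positive and $\\phi$ real, so $C$ and $\\phi$ can be read as the modulus and argument of the original complex coefficient."
   ]
  },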
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    },
    "tags": [
     "hide-cell"
    ]
   },
   "outputs": [],
   "source": [
    "assert (\n",
    "    original_model.expression.subs(original_model.parameter_defaults).doit()\n",
    "    == evaluated_expr\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
