diff --git a/notebooks/Grahpviz rewrites.ipynb b/notebooks/Grahpviz rewrites.ipynb
new file mode 100644
index 0000000000..b5c9e2703c
--- /dev/null
+++ b/notebooks/Grahpviz rewrites.ipynb	
@@ -0,0 +1,613 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import pytensor\n",
+    "import pytensor.tensor as pt\n",
+    "from pytensor.graph.fg import FunctionGraph\n",
+    "from pytensor.graph.features import History, FullHistory\n",
+    "from pytensor.graph.rewriting.utils import rewrite_graph"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "FunctionGraph(LogSoftmax{axis=None}(x))"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    " x = pt.scalar(\"x\")\n",
+    "out = pt.log(pt.exp(x) / pt.sum(pt.exp(x)))\n",
+    "\n",
+    "fg = FunctionGraph(outputs=[out])\n",
+    "history = FullHistory()\n",
+    "fg.attach_feature(history)\n",
+    "\n",
+    "rewrite_graph(fg, clone=False, include=(\"canonicalize\", \"stabilize\"))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/svg+xml": [
+       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
+       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
+       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
+       "<!-- Generated by graphviz version 8.1.0 (20230707.2238)\n",
+       " -->\n",
+       "<!-- Title: G Pages: 1 -->\n",
+       "<svg width=\"531pt\" height=\"477pt\"\n",
+       " viewBox=\"0.00 0.00 530.82 476.50\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
+       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 472.5)\">\n",
+       "<title>G</title>\n",
+       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-472.5 526.82,-472.5 526.82,4 -4,4\"/>\n",
+       "<!-- 139904552140112 -->\n",
+       "<g id=\"node1\" class=\"node\">\n",
+       "<title>139904552140112</title>\n",
+       "<ellipse fill=\"#ffaabb\" stroke=\"black\" cx=\"177.82\" cy=\"-377.5\" rx=\"30.69\" ry=\"18\"/>\n",
+       "<text text-anchor=\"middle\" x=\"177.82\" y=\"-372.82\" font-family=\"Times,serif\" font-size=\"14.00\">Exp</text>\n",
+       "</g>\n",
+       "<!-- 139904567218832 -->\n",
+       "<g id=\"node3\" class=\"node\">\n",
+       "<title>139904567218832</title>\n",
+       "<ellipse fill=\"none\" stroke=\"black\" cx=\"103.82\" cy=\"-269.5\" rx=\"103.82\" ry=\"18\"/>\n",
+       "<text text-anchor=\"middle\" x=\"103.82\" y=\"-264.82\" font-family=\"Times,serif\" font-size=\"14.00\">Sum{axes=None}</text>\n",
+       "</g>\n",
+       "<!-- 139904552140112&#45;&gt;139904567218832 -->\n",
+       "<g id=\"edge2\" class=\"edge\">\n",
+       "<title>139904552140112&#45;&gt;139904567218832</title>\n",
+       "<path fill=\"none\" stroke=\"black\" d=\"M147.19,-374.06C125.42,-370.55 97.61,-362.03 83.32,-341.5 74.15,-328.32 78.84,-311.07 86.03,-296.86\"/>\n",
+       "<polygon fill=\"black\" stroke=\"black\" points=\"89.39,-299.04 91.27,-288.61 83.3,-295.59 89.39,-299.04\"/>\n",
+       "<text text-anchor=\"middle\" x=\"173.07\" y=\"-318.82\" font-family=\"Times,serif\" font-size=\"14.00\">Scalar(float64, shape=())</text>\n",
+       "</g>\n",
+       "<!-- 139904803887568 -->\n",
+       "<g id=\"node2\" class=\"node\">\n",
+       "<title>139904803887568</title>\n",
+       "<polygon fill=\"green\" stroke=\"black\" points=\"381.94,-468.5 123.69,-468.5 123.69,-432.5 381.94,-432.5 381.94,-468.5\"/>\n",
+       "<text text-anchor=\"middle\" x=\"252.82\" y=\"-445.82\" font-family=\"Times,serif\" font-size=\"14.00\">name=x Scalar(float64, shape=())</text>\n",
+       "</g>\n",
+       "<!-- 139904803887568&#45;&gt;139904552140112 -->\n",
+       "<g id=\"edge1\" class=\"edge\">\n",
+       "<title>139904803887568&#45;&gt;139904552140112</title>\n",
+       "<path fill=\"none\" stroke=\"black\" d=\"M234.66,-432.31C224.58,-422.76 211.92,-410.78 201.11,-400.55\"/>\n",
+       "<polygon fill=\"black\" stroke=\"black\" points=\"203.94,-398.46 194.27,-394.13 199.13,-403.55 203.94,-398.46\"/>\n",
+       "</g>\n",
+       "<!-- 139904546108816 -->\n",
+       "<g id=\"node4\" class=\"node\">\n",
+       "<title>139904546108816</title>\n",
+       "<ellipse fill=\"#ffaabb\" stroke=\"black\" cx=\"328.82\" cy=\"-323.5\" rx=\"58.05\" ry=\"18\"/>\n",
+       "<text text-anchor=\"middle\" x=\"328.82\" y=\"-318.82\" font-family=\"Times,serif\" font-size=\"14.00\">Exp id=2</text>\n",
+       "</g>\n",
+       "<!-- 139904803887568&#45;&gt;139904546108816 -->\n",
+       "<g id=\"edge3\" class=\"edge\">\n",
+       "<title>139904803887568&#45;&gt;139904546108816</title>\n",
+       "<path fill=\"none\" stroke=\"black\" d=\"M263.39,-432.12C276.28,-410.92 298.19,-374.88 313.07,-350.41\"/>\n",
+       "<polygon fill=\"black\" stroke=\"black\" points=\"316.38,-352.7 318.58,-342.33 310.4,-349.06 316.38,-352.7\"/>\n",
+       "</g>\n",
+       "<!-- 139904546108240 -->\n",
+       "<g id=\"node5\" class=\"node\">\n",
+       "<title>139904546108240</title>\n",
+       "<ellipse fill=\"#ffaabb\" stroke=\"black\" cx=\"215.82\" cy=\"-180.25\" rx=\"54.36\" ry=\"18\"/>\n",
+       "<text text-anchor=\"middle\" x=\"215.82\" y=\"-175.57\" font-family=\"Times,serif\" font-size=\"14.00\">True_div</text>\n",
+       "</g>\n",
+       "<!-- 139904567218832&#45;&gt;139904546108240 -->\n",
+       "<g id=\"edge5\" class=\"edge\">\n",
+       "<title>139904567218832&#45;&gt;139904546108240</title>\n",
+       "<path fill=\"none\" stroke=\"black\" d=\"M104.46,-251.01C105.75,-239.8 109.16,-225.64 117.82,-216.25 128.32,-204.86 142.77,-197.18 157.36,-191.99\"/>\n",
+       "<polygon fill=\"black\" stroke=\"black\" points=\"157.96,-195.16 166.43,-188.79 155.84,-188.48 157.96,-195.16\"/>\n",
+       "<text text-anchor=\"middle\" x=\"213.82\" y=\"-220.2\" font-family=\"Times,serif\" font-size=\"14.00\">1 Scalar(float64, shape=())</text>\n",
+       "</g>\n",
+       "<!-- 139904546108816&#45;&gt;139904546108240 -->\n",
+       "<g id=\"edge4\" class=\"edge\">\n",
+       "<title>139904546108816&#45;&gt;139904546108240</title>\n",
+       "<path fill=\"none\" stroke=\"black\" d=\"M330.6,-305.27C332.01,-282.29 330.96,-241.57 309.82,-216.25 300.57,-205.17 287.46,-197.6 273.95,-192.42\"/>\n",
+       "<polygon fill=\"black\" stroke=\"black\" points=\"275.2,-188.81 264.61,-188.91 272.95,-195.44 275.2,-188.81\"/>\n",
+       "<text text-anchor=\"middle\" x=\"426.82\" y=\"-264.82\" font-family=\"Times,serif\" font-size=\"14.00\">0 Scalar(float64, shape=())</text>\n",
+       "</g>\n",
+       "<!-- 139904546106512 -->\n",
+       "<g id=\"node6\" class=\"node\">\n",
+       "<title>139904546106512</title>\n",
+       "<ellipse fill=\"#ffaabb\" stroke=\"black\" cx=\"215.82\" cy=\"-91\" rx=\"29.64\" ry=\"18\"/>\n",
+       "<text text-anchor=\"middle\" x=\"215.82\" y=\"-86.33\" font-family=\"Times,serif\" font-size=\"14.00\">Log</text>\n",
+       "</g>\n",
+       "<!-- 139904546108240&#45;&gt;139904546106512 -->\n",
+       "<g id=\"edge6\" class=\"edge\">\n",
+       "<title>139904546108240&#45;&gt;139904546106512</title>\n",
+       "<path fill=\"none\" stroke=\"black\" d=\"M215.82,-162.01C215.82,-150.06 215.82,-133.88 215.82,-120.08\"/>\n",
+       "<polygon fill=\"black\" stroke=\"black\" points=\"219.32,-120.2 215.82,-110.2 212.32,-120.2 219.32,-120.2\"/>\n",
+       "<text text-anchor=\"middle\" x=\"305.07\" y=\"-130.95\" font-family=\"Times,serif\" font-size=\"14.00\">Scalar(float64, shape=())</text>\n",
+       "</g>\n",
+       "<!-- 139904546106320 -->\n",
+       "<g id=\"node7\" class=\"node\">\n",
+       "<title>139904546106320</title>\n",
+       "<polygon fill=\"blue\" stroke=\"black\" points=\"313.07,-36 118.57,-36 118.57,0 313.07,0 313.07,-36\"/>\n",
+       "<text text-anchor=\"middle\" x=\"215.82\" y=\"-13.32\" font-family=\"Times,serif\" font-size=\"14.00\">Scalar(float64, shape=())</text>\n",
+       "</g>\n",
+       "<!-- 139904546106512&#45;&gt;139904546106320 -->\n",
+       "<g id=\"edge7\" class=\"edge\">\n",
+       "<title>139904546106512&#45;&gt;139904546106320</title>\n",
+       "<path fill=\"none\" stroke=\"black\" d=\"M215.82,-72.81C215.82,-65.05 215.82,-55.68 215.82,-46.95\"/>\n",
+       "<polygon fill=\"black\" stroke=\"black\" points=\"219.32,-47.03 215.82,-37.03 212.32,-47.03 219.32,-47.03\"/>\n",
+       "</g>\n",
+       "</g>\n",
+       "</svg>\n"
+      ],
+      "text/plain": [
+       "<graphviz.sources.Source at 0x7f3e10a9ec90>"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import graphviz\n",
+    "from pytensor.printing import pydotprint\n",
+    "out = pydotprint(history.start(), str)\n",
+    "gv = graphviz.Source(out.to_string())\n",
+    "gv"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "PosixPath('/home/evelin/Documents/Ricardo/Projects/pytensor/notebooks/widget.js')"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import pathlib\n",
+    "import os\n",
+    "(pathlib.Path().parent / \"widget.js\").absolute()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "e6744349c4d04aac8cd51f08278f3042",
+       "version_major": 2,
+       "version_minor": 1
+      },
+      "text/plain": [
+       "CounterWidget()"
+      ]
+     },
+     "execution_count": 1,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import anywidget\n",
+    "import traitlets\n",
+    "\n",
+    "class CounterWidget(anywidget.AnyWidget):\n",
+    "    # Widget front-end JavaScript code\n",
+    "    _esm = \"\"\"\n",
+    "    function render({ model, el }) {\n",
+    "      let button = document.createElement(\"button\");\n",
+    "      button.innerHTML = `count is ${model.get(\"value\")}`;\n",
+    "      button.addEventListener(\"click\", () => {\n",
+    "        model.set(\"value\", model.get(\"value\") + 1);\n",
+    "        model.save_changes();\n",
+    "      });\n",
+    "      model.on(\"change:value\", () => {\n",
+    "        button.innerHTML = `count is ${model.get(\"value\")}`;\n",
+    "      });\n",
+    "      el.appendChild(button);\n",
+    "    }\n",
+    "    export default { render };\n",
+    "    \"\"\"\n",
+    "    # Stateful property that can be accessed by JavaScript & Python\n",
+    "    value = traitlets.Int(0).tag(sync=True)\n",
+    "    \n",
+    "counter = CounterWidget()\n",
+    "counter"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "89a17dbc5e30488ba64fb084f383a7df",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "IntSlider(value=0, max=5)"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import ipywidgets\n",
+    "ipywidgets.IntSlider(max=5, value=0)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import anywidget\n",
+    "import pathlib\n",
+    "import traitlets\n",
+    "import ipywidgets\n",
+    "\n",
+    "class GraphvizWidget(anywidget.AnyWidget):\n",
+    "    \"\"\"\n",
+    "    Graphviz widget to render multiple graphviz graphs.\n",
+    "\n",
+    "    The index will choose the one that is currently displayed, defaulting to the last one.\n",
+    "    If the index or the graphs change, there will be a re-rendering, with animation.\n",
+    "    \"\"\"\n",
+    "\n",
+    "    _esm = r\"/home/evelin/Documents/Ricardo/Projects/pytensor/notebooks/widget.js\"\n",
+    "    _css = r\"/home/evelin/Documents/Ricardo/Projects/pytensor/notebooks/widget.css\"\n",
+    "    dots = traitlets.List().tag(sync=True)\n",
+    "    index = traitlets.Int(0, allow_none=True).tag(sync=True)\n",
+    "    performance = traitlets.Bool(False).tag(sync=True)\n",
+    "\n",
+    "    \n",
+    "class CounterWidget(anywidget.AnyWidget):\n",
+    "    # Widget front-end JavaScript code\n",
+    "    _esm = \"\"\"\n",
+    "    function render({ model, el }) {\n",
+    "      let button = document.createElement(\"button\");\n",
+    "      button.innerHTML = `count is ${model.get(\"value\")}`;\n",
+    "      button.addEventListener(\"click\", () => {\n",
+    "        model.set(\"value\", model.get(\"value\") + 1);\n",
+    "        model.save_changes();\n",
+    "      });\n",
+    "      model.on(\"change:value\", () => {\n",
+    "        button.innerHTML = `count is ${model.get(\"value\")}`;\n",
+    "      });\n",
+    "      el.appendChild(button);\n",
+    "    }\n",
+    "    export default { render };\n",
+    "    \"\"\"\n",
+    "    # Stateful property that can be accessed by JavaScript & Python\n",
+    "    value = traitlets.Int(0).tag(sync=True)\n",
+    "    \n",
+    "\n",
+    "def graphviz_widget_with_slider(dots: list[str], *, performance: bool = False) -> ipywidgets.VBox:\n",
+    "    n_dots = len(dots)\n",
+    "    graphviz_widget = GraphvizWidget()\n",
+    "    graphviz_widget.dots = dots\n",
+    "    graphviz_widget.performance = performance\n",
+    "    slider_widget = ipywidgets.IntSlider(max=n_dots - 1, value=0)\n",
+    "    ipywidgets.jslink((slider_widget, \"value\"), (graphviz_widget, \"index\"))\n",
+    "#     play_widget = ipywidgets.Play(max=n_dots - 1, repeat=True, interval=4000)\n",
+    "#     ipywidgets.jslink((slider_widget, \"value\"), (play_widget, \"value\"))\n",
+    "#     top = ipywidgets.HBox([play_widget, slider_widget])\n",
+    "    return ipywidgets.VBox([slider_widget, CounterWidget()])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "graphviz_widget_with_slider([out.to_string(), out.to_string()])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import graphviz\n",
+    "from pytensor.printing import pydotprint\n",
+    "out = pydotprint(predict, str)\n",
+    "gv = graphviz.Source(out.to_string())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from IPython.display import Image\n",
+    "Image('./examples/mlp.png', width='80%')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/svg+xml": [
+       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
+       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
+       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
+       "<!-- Generated by graphviz version 8.1.0 (20230707.2238)\n",
+       " -->\n",
+       "<!-- Title: the holy hand grenade Pages: 1 -->\n",
+       "<svg width=\"332pt\" height=\"44pt\"\n",
+       " viewBox=\"0.00 0.00 332.00 44.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
+       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 40)\">\n",
+       "<title>the holy hand grenade</title>\n",
+       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-40 328,-40 328,4 -4,4\"/>\n",
+       "<!-- 1 -->\n",
+       "<g id=\"node1\" class=\"node\">\n",
+       "<title>1</title>\n",
+       "<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
+       "<text text-anchor=\"middle\" x=\"27\" y=\"-13.32\" font-family=\"Times,serif\" font-size=\"14.00\">1</text>\n",
+       "</g>\n",
+       "<!-- 2 -->\n",
+       "<g id=\"node2\" class=\"node\">\n",
+       "<title>2</title>\n",
+       "<ellipse fill=\"none\" stroke=\"black\" cx=\"117\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
+       "<text text-anchor=\"middle\" x=\"117\" y=\"-13.32\" font-family=\"Times,serif\" font-size=\"14.00\">2</text>\n",
+       "</g>\n",
+       "<!-- 1&#45;&gt;2 -->\n",
+       "<g id=\"edge1\" class=\"edge\">\n",
+       "<title>1&#45;&gt;2</title>\n",
+       "<path fill=\"none\" stroke=\"black\" d=\"M54.4,-18C62.06,-18 70.57,-18 78.76,-18\"/>\n",
+       "<polygon fill=\"black\" stroke=\"black\" points=\"78.62,-21.5 88.62,-18 78.62,-14.5 78.62,-21.5\"/>\n",
+       "</g>\n",
+       "<!-- 3 -->\n",
+       "<g id=\"node3\" class=\"node\">\n",
+       "<title>3</title>\n",
+       "<ellipse fill=\"none\" stroke=\"black\" cx=\"207\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
+       "<text text-anchor=\"middle\" x=\"207\" y=\"-13.32\" font-family=\"Times,serif\" font-size=\"14.00\">3</text>\n",
+       "</g>\n",
+       "<!-- 2&#45;&gt;3 -->\n",
+       "<g id=\"edge2\" class=\"edge\">\n",
+       "<title>2&#45;&gt;3</title>\n",
+       "<path fill=\"none\" stroke=\"black\" d=\"M144.4,-18C152.06,-18 160.57,-18 168.76,-18\"/>\n",
+       "<polygon fill=\"black\" stroke=\"black\" points=\"168.62,-21.5 178.62,-18 168.62,-14.5 168.62,-21.5\"/>\n",
+       "</g>\n",
+       "<!-- lob -->\n",
+       "<g id=\"node4\" class=\"node\">\n",
+       "<title>lob</title>\n",
+       "<ellipse fill=\"none\" stroke=\"black\" cx=\"297\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
+       "<text text-anchor=\"middle\" x=\"297\" y=\"-13.32\" font-family=\"Times,serif\" font-size=\"14.00\">lob</text>\n",
+       "</g>\n",
+       "<!-- 3&#45;&gt;lob -->\n",
+       "<g id=\"edge3\" class=\"edge\">\n",
+       "<title>3&#45;&gt;lob</title>\n",
+       "<path fill=\"none\" stroke=\"black\" d=\"M234.4,-18C242.06,-18 250.57,-18 258.76,-18\"/>\n",
+       "<polygon fill=\"black\" stroke=\"black\" points=\"258.62,-21.5 268.62,-18 258.62,-14.5 258.62,-21.5\"/>\n",
+       "</g>\n",
+       "</g>\n",
+       "</svg>\n"
+      ],
+      "text/plain": [
+       "<graphviz.sources.Source at 0x7f188a4d2690>"
+      ]
+     },
+     "execution_count": 11,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import graphviz\n",
+    "src = graphviz.Source('digraph \"the holy hand grenade\" { rankdir=LR; 1 -> 2 -> 3 -> lob }')\n",
+    "src"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "image/svg+xml": [
+       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
+       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
+       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
+       "<!-- Generated by graphviz version 8.1.0 (20230707.2238)\n",
+       " -->\n",
+       "<!-- Title: my_graph Pages: 1 -->\n",
+       "<svg width=\"66pt\" height=\"188pt\"\n",
+       " viewBox=\"0.00 0.00 66.22 188.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
+       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 184)\">\n",
+       "<title>my_graph</title>\n",
+       "<polygon fill=\"yellow\" stroke=\"none\" points=\"-4,4 -4,-184 62.22,-184 62.22,4 -4,4\"/>\n",
+       "<!-- a -->\n",
+       "<g id=\"node1\" class=\"node\">\n",
+       "<title>a</title>\n",
+       "<ellipse fill=\"none\" stroke=\"black\" cx=\"29.11\" cy=\"-162\" rx=\"29.11\" ry=\"18\"/>\n",
+       "<text text-anchor=\"middle\" x=\"29.11\" y=\"-157.32\" font-family=\"Times,serif\" font-size=\"14.00\">Foo</text>\n",
+       "</g>\n",
+       "<!-- b -->\n",
+       "<g id=\"node2\" class=\"node\">\n",
+       "<title>b</title>\n",
+       "<ellipse fill=\"none\" stroke=\"black\" cx=\"29.11\" cy=\"-90\" rx=\"18\" ry=\"18\"/>\n",
+       "<text text-anchor=\"middle\" x=\"29.11\" y=\"-85.33\" font-family=\"Times,serif\" font-size=\"14.00\">b</text>\n",
+       "</g>\n",
+       "<!-- a&#45;&#45;b -->\n",
+       "<g id=\"edge1\" class=\"edge\">\n",
+       "<title>a&#45;&#45;b</title>\n",
+       "<path fill=\"none\" stroke=\"blue\" d=\"M29.11,-143.7C29.11,-132.85 29.11,-118.92 29.11,-108.1\"/>\n",
+       "</g>\n",
+       "<!-- c -->\n",
+       "<g id=\"node3\" class=\"node\">\n",
+       "<title>c</title>\n",
+       "<ellipse fill=\"none\" stroke=\"black\" cx=\"29.11\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n",
+       "<text text-anchor=\"middle\" x=\"29.11\" y=\"-13.32\" font-family=\"Times,serif\" font-size=\"14.00\">c</text>\n",
+       "</g>\n",
+       "<!-- b&#45;&#45;c -->\n",
+       "<g id=\"edge2\" class=\"edge\">\n",
+       "<title>b&#45;&#45;c</title>\n",
+       "<path fill=\"none\" stroke=\"blue\" d=\"M29.11,-71.7C29.11,-60.85 29.11,-46.92 29.11,-36.1\"/>\n",
+       "</g>\n",
+       "</g>\n",
+       "</svg>\n"
+      ],
+      "text/plain": [
+       "<graphviz.sources.Source at 0x7fb75041de50>"
+      ]
+     },
+     "execution_count": 3,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import pydot\n",
+    "import graphviz\n",
+    "\n",
+    "dot_string = \"\"\"graph my_graph {\n",
+    "    bgcolor=\"yellow\";\n",
+    "    a [label=\"Foo\"];\n",
+    "    b [shape=circle];\n",
+    "    a -- b -- c [color=blue];\n",
+    "}\"\"\"\n",
+    "\n",
+    "graphviz.Source(dot_string)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import io\n",
+    "file = io.StringIO()\n",
+    "file.endswith = lambda *args: True\n",
+    "pydotprint(predict, file)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "??pydotprint"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "hide_input": false,
+  "kernelspec": {
+   "display_name": "pytensor",
+   "language": "python",
+   "name": "pytensor"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.5"
+  },
+  "toc": {
+   "base_numbering": 1,
+   "nav_menu": {},
+   "number_sections": true,
+   "sideBar": true,
+   "skip_h1_title": false,
+   "title_cell": "Table of Contents",
+   "title_sidebar": "Contents",
+   "toc_cell": false,
+   "toc_position": {},
+   "toc_section_display": true,
+   "toc_window_display": false
+  },
+  "varInspector": {
+   "cols": {
+    "lenName": 16,
+    "lenType": 16,
+    "lenVar": 40
+   },
+   "kernels_config": {
+    "python": {
+     "delete_cmd_postfix": "",
+     "delete_cmd_prefix": "del ",
+     "library": "var_list.py",
+     "varRefreshCmd": "print(var_dic_list())"
+    },
+    "r": {
+     "delete_cmd_postfix": ") ",
+     "delete_cmd_prefix": "rm(",
+     "library": "var_list.r",
+     "varRefreshCmd": "cat(var_dic_list()) "
+    }
+   },
+   "types_to_exclude": [
+    "module",
+    "function",
+    "builtin_function_or_method",
+    "instance",
+    "_Feature"
+   ],
+   "window_display": false
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/notebooks/widget.css b/notebooks/widget.css
new file mode 100644
index 0000000000..fe74d0dd81
--- /dev/null
+++ b/notebooks/widget.css
@@ -0,0 +1,6 @@
+.graphviz-container {
+  /* position: absolute;
+  top: 0;
+  left: 0; */
+  /* height: 100px; */
+}
diff --git a/notebooks/widget.js b/notebooks/widget.js
new file mode 100644
index 0000000000..45e3f9cfa1
--- /dev/null
+++ b/notebooks/widget.js
@@ -0,0 +1,50 @@
+import * as d3 from "https://cdn.skypack.dev/pin/d3@v7.8.5-eC7TKxlFLay7fmsv0gvu/dist=es2020,mode=imports,min/optimized/d3.js";
+import "https://cdn.skypack.dev/-/d3-graphviz@v5.1.0-TcGnvMu4khUzCpL7Wr2k/dist=es2020,mode=imports,min/optimized/d3-graphviz.js";
+
+export function render({ model, el }) {
+  if (!document.getElementById("graphviz_script")) {
+    const graphviz_script = document.createElement("script");
+    graphviz_script.setAttribute("id", "graphviz_script");
+    graphviz_script.setAttribute("src", "https://unpkg.com/@hpcc-js/wasm/dist/graphviz.umd.js");
+    graphviz_script.setAttribute("type", "javascript/worker");
+
+    document.head.appendChild(graphviz_script);
+  }
+  //   let getCount = () => model.get("dot");
+
+  const div = document.createElement("div");
+  div.classList.add("graphviz-container");
+  const graphContainer = d3.select(div);
+
+  let setDot = () => {
+    const dots = model.get("dots");
+    const performance = model.get("performance");
+    // Use the last dots if there is no index
+    const dot = dots[model.get("index")];
+    // const width = div.clientWidth;
+    // const height = div.clientHeight;
+    const graphviz = graphContainer
+      .graphviz({
+        // Fit graph to that size, so that all is visible
+        fit: true,
+        // Set to be as big as container
+        // width,
+        // height,
+        // Don't animate transitions between shapes for performance
+        tweenPaths: !performance,
+        tweenShapes: !performance,
+        useWorker: true,
+      })
+      .transition(() => d3.transition("t").duration(2000).ease(d3.easeLinear))
+      .renderDot(dot);
+    // If we have made a zoom selection, reset that before transitioning
+    // TODO: figure out how to transition BOTH zoom and dot at once
+    // if (graphviz._zoomSelection) {
+    //   graphviz.resetZoom();
+    // }
+  };
+
+  model.on("change:dots change:index", setDot);
+  el.appendChild(div);
+  requestAnimationFrame(() => setTimeout(setDot, 0));
+}
diff --git a/pytensor/graph/features.py b/pytensor/graph/features.py
index 2bce7f1748..0233869e67 100644
--- a/pytensor/graph/features.py
+++ b/pytensor/graph/features.py
@@ -439,6 +439,170 @@ def revert(self, fgraph, checkpoint):
         self.history[fgraph] = h
 
 
+class FullHistory(Feature):
+    """Keeps track of all changes in FunctionGraph and allows arbitrary back and forth through intermediate states
+
+    .. testcode::
+        import pytensor
+        import pytensor.tensor as pt
+        from pytensor.graph.fg import FunctionGraph
+        from pytensor.graph.features import History, FullHistory
+        from pytensor.graph.rewriting.utils import rewrite_graph
+
+        x = pt.scalar("x")
+        out = pt.log(pt.exp(x) / pt.sum(pt.exp(x)))
+
+        fg = FunctionGraph(outputs=[out])
+        history = FullHistory()
+        fg.attach_feature(history)
+
+        rewrite_graph(fg, clone=False, include=("canonicalize", "stabilize"))
+
+        # Replay rewrites
+        history.start()
+        pytensor.dprint(fg)
+        pytensor.config.optimizer_verbose = True
+        for i in range(3):
+            print(">> ", end="")
+            pytensor.dprint(history.next())
+
+    .. testoutput::
+        Log [id A] 4
+         └─ True_div [id B] 3
+            ├─ Exp [id C] 2
+            │  └─ x [id D]
+            └─ Sum{axes=None} [id E] 1
+               └─ Exp [id F] 0
+                  └─ x [id D]
+        >> MergeOptimizer
+        Log [id A] 3
+         └─ True_div [id B] 2
+            ├─ Exp [id C] 0
+            │  └─ x [id D]
+            └─ Sum{axes=None} [id E] 1
+               └─ Exp [id C] 0
+                  └─ ···
+        >> local_mul_canonizer
+        Log [id A] 1
+         └─ Softmax{axis=None} [id B] 0
+            └─ x [id C]
+        >> local_logsoftmax
+        LogSoftmax{axis=None} [id A] 0
+         └─ x [id B]
+
+
+    .. testcode::
+        # Or in reverse
+        for i in range(3):
+            print(">> ", end="")
+            pytensor.dprint(history.prev())
+        pytensor.config.optimizer_verbose = False
+
+
+    .. testoutput::
+        >> local_logsoftmax
+        Log [id A] 1
+         └─ Softmax{axis=None} [id B] 0
+            └─ x [id C]
+        >> local_mul_canonizer
+        Log [id A] 3
+         └─ True_div [id B] 2
+            ├─ Exp [id C] 0
+            │  └─ x [id D]
+            └─ Sum{axes=None} [id E] 1
+               └─ Exp [id C] 0
+                  └─ ···
+        >> MergeOptimizer
+        Log [id A] 4
+         └─ True_div [id B] 3
+            ├─ Exp [id C] 2
+            │  └─ x [id D]
+            └─ Sum{axes=None} [id E] 1
+               └─ Exp [id F] 0
+                  └─ x [id D]
+
+
+    .. testcode::
+        # Or go to any step
+        pytensor.dprint(history.goto(2))
+
+    .. testoutput::
+        Log [id A] 1
+         └─ Softmax{axis=None} [id B] 0
+            └─ x [id C]
+
+
+    """
+
+    def __init__(self):
+        self.fw = []
+        self.bw = []
+        self.pointer = -1
+        self.fg = None
+
+    def on_attach(self, fgraph):
+        if self.fg is not None:
+            raise ValueError("Full History already attached to another fgraph")
+        self.fg = fgraph
+
+    def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
+        self.bw.append(LambdaExtract(fgraph, node, i, r, reason))
+        self.fw.append(LambdaExtract(fgraph, node, i, new_r, reason))
+        self.pointer += 1
+
+    def goto(self, checkpoint):
+        """
+        Reverts the graph to whatever it was at the provided
+        checkpoint (undoes all replacements). A checkpoint at any
+        given time can be obtained using self.checkpoint().
+
+        """
+        history_len = len(self.bw)
+        pointer = self.pointer
+        assert 0 <= checkpoint <= history_len
+        verbose = config.optimizer_verbose
+
+        # Go backwards
+        while pointer > checkpoint - 1:
+            reverse_fn = self.bw[pointer]
+            if verbose:
+                print(reverse_fn.reason)
+            reverse_fn()
+            pointer -= 1
+
+        # Go forward
+        while pointer < checkpoint - 1:
+            pointer += 1
+            forward_fn = self.fw[pointer]
+            if verbose:
+                print(forward_fn.reason)
+            forward_fn()
+
+        # Remove history changes caused by the foward/backward!
+        self.bw = self.bw[:history_len]
+        self.fw = self.fw[:history_len]
+        self.pointer = pointer
+        return self.fg
+
+    def start(self):
+        return self.goto(0)
+
+    def end(self):
+        return self.goto(len(self.bw))
+
+    def prev(self):
+        if self.pointer < 0:
+            return self.fg
+        else:
+            return self.goto(self.pointer)
+
+    def next(self):
+        if self.pointer >= len(self.bw) - 1:
+            return self.fg
+        else:
+            return self.goto(self.pointer + 2)
+
+
 class Validator(Feature):
     pickle_rm_attr = ["validate", "consistent"]
 
diff --git a/pytensor/printing.py b/pytensor/printing.py
index 4d8cbb96da..298aef395e 100644
--- a/pytensor/printing.py
+++ b/pytensor/printing.py
@@ -1200,7 +1200,7 @@ def __call__(self, *args):
 
 def pydotprint(
     fct,
-    outfile: str | None = None,
+    outfile: Literal[str] | str | None = None,
     compact: bool = True,
     format: str = "png",
     with_ids: bool = False,
@@ -1606,7 +1606,7 @@ def apply_name(node):
         g.add_subgraph(c2)
         g.add_subgraph(c3)
 
-    if not outfile.endswith("." + format):
+    if outfile is not str and not outfile.endswith("." + format):
         outfile += "." + format
 
     if scan_graphs:
@@ -1640,6 +1640,9 @@ def apply_name(node):
                 scan_graphs,
             )
 
+    if outfile is str:
+        return g
+
     if return_image:
         return g.create(prog="dot", format=format)
     else:
diff --git a/tests/graph/test_features.py b/tests/graph/test_features.py
index 475af20b57..e94efba3c8 100644
--- a/tests/graph/test_features.py
+++ b/tests/graph/test_features.py
@@ -1,7 +1,9 @@
 import pytest
 
-from pytensor.graph.basic import Apply, Variable
-from pytensor.graph.features import Feature, NodeFinder, ReplaceValidate
+import pytensor.tensor as pt
+from pytensor.graph import rewrite_graph
+from pytensor.graph.basic import Apply, Variable, equal_computations
+from pytensor.graph.features import Feature, FullHistory, NodeFinder, ReplaceValidate
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op
 from pytensor.graph.type import Type
@@ -119,3 +121,33 @@ def validate(self, *args):
 
         capres = capsys.readouterr()
         assert "rewriting: validate failed on node Op1.0" in capres.out
+
+
+def test_full_history():
+    x = pt.scalar("x")
+    out = pt.log(pt.exp(x) / pt.sum(pt.exp(x)))
+    fg = FunctionGraph(outputs=[out], clone=True, copy_inputs=False)
+    history = FullHistory()
+    fg.attach_feature(history)
+    rewrite_graph(fg, clone=False, include=("canonicalize", "stabilize"))
+
+    history.start()
+    assert equal_computations(fg.outputs, [out])
+
+    history.end()
+    assert equal_computations(fg.outputs, [pt.special.log_softmax(x)])
+
+    history.prev()
+    assert equal_computations(fg.outputs, [pt.log(pt.special.softmax(x))])
+
+    for i in range(10):
+        history.prev()
+    assert equal_computations(fg.outputs, [out])
+
+    history.goto(2)
+    assert equal_computations(fg.outputs, [pt.log(pt.special.softmax(x))])
+
+    for i in range(10):
+        history.next()
+
+    assert equal_computations(fg.outputs, [pt.special.log_softmax(x)])