diff --git a/notebooks/Grahpviz rewrites.ipynb b/notebooks/Grahpviz rewrites.ipynb new file mode 100644 index 0000000000..b5c9e2703c --- /dev/null +++ b/notebooks/Grahpviz rewrites.ipynb @@ -0,0 +1,613 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import pytensor\n", + "import pytensor.tensor as pt\n", + "from pytensor.graph.fg import FunctionGraph\n", + "from pytensor.graph.features import History, FullHistory\n", + "from pytensor.graph.rewriting.utils import rewrite_graph" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "FunctionGraph(LogSoftmax{axis=None}(x))" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + " x = pt.scalar(\"x\")\n", + "out = pt.log(pt.exp(x) / pt.sum(pt.exp(x)))\n", + "\n", + "fg = FunctionGraph(outputs=[out])\n", + "history = FullHistory()\n", + "fg.attach_feature(history)\n", + "\n", + "rewrite_graph(fg, clone=False, include=(\"canonicalize\", \"stabilize\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", + "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", + " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", + "<!-- Generated by graphviz version 8.1.0 (20230707.2238)\n", + " -->\n", + "<!-- Title: G Pages: 1 -->\n", + "<svg width=\"531pt\" height=\"477pt\"\n", + " viewBox=\"0.00 0.00 530.82 476.50\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", + "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 472.5)\">\n", + "<title>G</title>\n", + "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-472.5 526.82,-472.5 526.82,4 -4,4\"/>\n", + "<!-- 139904552140112 -->\n", + "<g id=\"node1\" class=\"node\">\n", + "<title>139904552140112</title>\n", + "<ellipse fill=\"#ffaabb\" stroke=\"black\" cx=\"177.82\" cy=\"-377.5\" rx=\"30.69\" ry=\"18\"/>\n", + "<text text-anchor=\"middle\" x=\"177.82\" y=\"-372.82\" font-family=\"Times,serif\" font-size=\"14.00\">Exp</text>\n", + "</g>\n", + "<!-- 139904567218832 -->\n", + "<g id=\"node3\" class=\"node\">\n", + "<title>139904567218832</title>\n", + "<ellipse fill=\"none\" stroke=\"black\" cx=\"103.82\" cy=\"-269.5\" rx=\"103.82\" ry=\"18\"/>\n", + "<text text-anchor=\"middle\" x=\"103.82\" y=\"-264.82\" font-family=\"Times,serif\" font-size=\"14.00\">Sum{axes=None}</text>\n", + "</g>\n", + "<!-- 139904552140112->139904567218832 -->\n", + "<g id=\"edge2\" class=\"edge\">\n", + "<title>139904552140112->139904567218832</title>\n", + "<path fill=\"none\" stroke=\"black\" d=\"M147.19,-374.06C125.42,-370.55 97.61,-362.03 83.32,-341.5 74.15,-328.32 78.84,-311.07 86.03,-296.86\"/>\n", + "<polygon fill=\"black\" stroke=\"black\" points=\"89.39,-299.04 91.27,-288.61 83.3,-295.59 89.39,-299.04\"/>\n", + "<text text-anchor=\"middle\" x=\"173.07\" y=\"-318.82\" font-family=\"Times,serif\" font-size=\"14.00\">Scalar(float64, shape=())</text>\n", + "</g>\n", + "<!-- 139904803887568 -->\n", + "<g id=\"node2\" class=\"node\">\n", + "<title>139904803887568</title>\n", + "<polygon fill=\"green\" stroke=\"black\" points=\"381.94,-468.5 123.69,-468.5 123.69,-432.5 381.94,-432.5 381.94,-468.5\"/>\n", + "<text text-anchor=\"middle\" x=\"252.82\" y=\"-445.82\" font-family=\"Times,serif\" font-size=\"14.00\">name=x Scalar(float64, shape=())</text>\n", + "</g>\n", + "<!-- 139904803887568->139904552140112 -->\n", + "<g id=\"edge1\" class=\"edge\">\n", + "<title>139904803887568->139904552140112</title>\n", + "<path fill=\"none\" stroke=\"black\" d=\"M234.66,-432.31C224.58,-422.76 211.92,-410.78 201.11,-400.55\"/>\n", + "<polygon fill=\"black\" stroke=\"black\" points=\"203.94,-398.46 194.27,-394.13 199.13,-403.55 203.94,-398.46\"/>\n", + "</g>\n", + "<!-- 139904546108816 -->\n", + "<g id=\"node4\" class=\"node\">\n", + "<title>139904546108816</title>\n", + "<ellipse fill=\"#ffaabb\" stroke=\"black\" cx=\"328.82\" cy=\"-323.5\" rx=\"58.05\" ry=\"18\"/>\n", + "<text text-anchor=\"middle\" x=\"328.82\" y=\"-318.82\" font-family=\"Times,serif\" font-size=\"14.00\">Exp id=2</text>\n", + "</g>\n", + "<!-- 139904803887568->139904546108816 -->\n", + "<g id=\"edge3\" class=\"edge\">\n", + "<title>139904803887568->139904546108816</title>\n", + "<path fill=\"none\" stroke=\"black\" d=\"M263.39,-432.12C276.28,-410.92 298.19,-374.88 313.07,-350.41\"/>\n", + "<polygon fill=\"black\" stroke=\"black\" points=\"316.38,-352.7 318.58,-342.33 310.4,-349.06 316.38,-352.7\"/>\n", + "</g>\n", + "<!-- 139904546108240 -->\n", + "<g id=\"node5\" class=\"node\">\n", + "<title>139904546108240</title>\n", + "<ellipse fill=\"#ffaabb\" stroke=\"black\" cx=\"215.82\" cy=\"-180.25\" rx=\"54.36\" ry=\"18\"/>\n", + "<text text-anchor=\"middle\" x=\"215.82\" y=\"-175.57\" font-family=\"Times,serif\" font-size=\"14.00\">True_div</text>\n", + "</g>\n", + "<!-- 139904567218832->139904546108240 -->\n", + "<g id=\"edge5\" class=\"edge\">\n", + "<title>139904567218832->139904546108240</title>\n", + "<path fill=\"none\" stroke=\"black\" d=\"M104.46,-251.01C105.75,-239.8 109.16,-225.64 117.82,-216.25 128.32,-204.86 142.77,-197.18 157.36,-191.99\"/>\n", + "<polygon fill=\"black\" stroke=\"black\" points=\"157.96,-195.16 166.43,-188.79 155.84,-188.48 157.96,-195.16\"/>\n", + "<text text-anchor=\"middle\" x=\"213.82\" y=\"-220.2\" font-family=\"Times,serif\" font-size=\"14.00\">1 Scalar(float64, shape=())</text>\n", + "</g>\n", + "<!-- 139904546108816->139904546108240 -->\n", + "<g id=\"edge4\" class=\"edge\">\n", + "<title>139904546108816->139904546108240</title>\n", + "<path fill=\"none\" stroke=\"black\" d=\"M330.6,-305.27C332.01,-282.29 330.96,-241.57 309.82,-216.25 300.57,-205.17 287.46,-197.6 273.95,-192.42\"/>\n", + "<polygon fill=\"black\" stroke=\"black\" points=\"275.2,-188.81 264.61,-188.91 272.95,-195.44 275.2,-188.81\"/>\n", + "<text text-anchor=\"middle\" x=\"426.82\" y=\"-264.82\" font-family=\"Times,serif\" font-size=\"14.00\">0 Scalar(float64, shape=())</text>\n", + "</g>\n", + "<!-- 139904546106512 -->\n", + "<g id=\"node6\" class=\"node\">\n", + "<title>139904546106512</title>\n", + "<ellipse fill=\"#ffaabb\" stroke=\"black\" cx=\"215.82\" cy=\"-91\" rx=\"29.64\" ry=\"18\"/>\n", + "<text text-anchor=\"middle\" x=\"215.82\" y=\"-86.33\" font-family=\"Times,serif\" font-size=\"14.00\">Log</text>\n", + "</g>\n", + "<!-- 139904546108240->139904546106512 -->\n", + "<g id=\"edge6\" class=\"edge\">\n", + "<title>139904546108240->139904546106512</title>\n", + "<path fill=\"none\" stroke=\"black\" d=\"M215.82,-162.01C215.82,-150.06 215.82,-133.88 215.82,-120.08\"/>\n", + "<polygon fill=\"black\" stroke=\"black\" points=\"219.32,-120.2 215.82,-110.2 212.32,-120.2 219.32,-120.2\"/>\n", + "<text text-anchor=\"middle\" x=\"305.07\" y=\"-130.95\" font-family=\"Times,serif\" font-size=\"14.00\">Scalar(float64, shape=())</text>\n", + "</g>\n", + "<!-- 139904546106320 -->\n", + "<g id=\"node7\" class=\"node\">\n", + "<title>139904546106320</title>\n", + "<polygon fill=\"blue\" stroke=\"black\" points=\"313.07,-36 118.57,-36 118.57,0 313.07,0 313.07,-36\"/>\n", + "<text text-anchor=\"middle\" x=\"215.82\" y=\"-13.32\" font-family=\"Times,serif\" font-size=\"14.00\">Scalar(float64, shape=())</text>\n", + "</g>\n", + "<!-- 139904546106512->139904546106320 -->\n", + "<g id=\"edge7\" class=\"edge\">\n", + "<title>139904546106512->139904546106320</title>\n", + "<path fill=\"none\" stroke=\"black\" d=\"M215.82,-72.81C215.82,-65.05 215.82,-55.68 215.82,-46.95\"/>\n", + "<polygon fill=\"black\" stroke=\"black\" points=\"219.32,-47.03 215.82,-37.03 212.32,-47.03 219.32,-47.03\"/>\n", + "</g>\n", + "</g>\n", + "</svg>\n" + ], + "text/plain": [ + "<graphviz.sources.Source at 0x7f3e10a9ec90>" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import graphviz\n", + "from pytensor.printing import pydotprint\n", + "out = pydotprint(history.start(), str)\n", + "gv = graphviz.Source(out.to_string())\n", + "gv" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "PosixPath('/home/evelin/Documents/Ricardo/Projects/pytensor/notebooks/widget.js')" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pathlib\n", + "import os\n", + "(pathlib.Path().parent / \"widget.js\").absolute()" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e6744349c4d04aac8cd51f08278f3042", + "version_major": 2, + "version_minor": 1 + }, + "text/plain": [ + "CounterWidget()" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import anywidget\n", + "import traitlets\n", + "\n", + "class CounterWidget(anywidget.AnyWidget):\n", + " # Widget front-end JavaScript code\n", + " _esm = \"\"\"\n", + " function render({ model, el }) {\n", + " let button = document.createElement(\"button\");\n", + " button.innerHTML = `count is ${model.get(\"value\")}`;\n", + " button.addEventListener(\"click\", () => {\n", + " model.set(\"value\", model.get(\"value\") + 1);\n", + " model.save_changes();\n", + " });\n", + " model.on(\"change:value\", () => {\n", + " button.innerHTML = `count is ${model.get(\"value\")}`;\n", + " });\n", + " el.appendChild(button);\n", + " }\n", + " export default { render };\n", + " \"\"\"\n", + " # Stateful property that can be accessed by JavaScript & Python\n", + " value = traitlets.Int(0).tag(sync=True)\n", + " \n", + "counter = CounterWidget()\n", + "counter" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "89a17dbc5e30488ba64fb084f383a7df", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "IntSlider(value=0, max=5)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import ipywidgets\n", + "ipywidgets.IntSlider(max=5, value=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "import anywidget\n", + "import pathlib\n", + "import traitlets\n", + "import ipywidgets\n", + "\n", + "class GraphvizWidget(anywidget.AnyWidget):\n", + " \"\"\"\n", + " Graphviz widget to render multiple graphviz graphs.\n", + "\n", + " The index will choose the one that is currently displayed, defaulting to the last one.\n", + " If the index or the graphs change, there will be a re-rendering, with animation.\n", + " \"\"\"\n", + "\n", + " _esm = r\"/home/evelin/Documents/Ricardo/Projects/pytensor/notebooks/widget.js\"\n", + " _css = r\"/home/evelin/Documents/Ricardo/Projects/pytensor/notebooks/widget.css\"\n", + " dots = traitlets.List().tag(sync=True)\n", + " index = traitlets.Int(0, allow_none=True).tag(sync=True)\n", + " performance = traitlets.Bool(False).tag(sync=True)\n", + "\n", + " \n", + "class CounterWidget(anywidget.AnyWidget):\n", + " # Widget front-end JavaScript code\n", + " _esm = \"\"\"\n", + " function render({ model, el }) {\n", + " let button = document.createElement(\"button\");\n", + " button.innerHTML = `count is ${model.get(\"value\")}`;\n", + " button.addEventListener(\"click\", () => {\n", + " model.set(\"value\", model.get(\"value\") + 1);\n", + " model.save_changes();\n", + " });\n", + " model.on(\"change:value\", () => {\n", + " button.innerHTML = `count is ${model.get(\"value\")}`;\n", + " });\n", + " el.appendChild(button);\n", + " }\n", + " export default { render };\n", + " \"\"\"\n", + " # Stateful property that can be accessed by JavaScript & Python\n", + " value = traitlets.Int(0).tag(sync=True)\n", + " \n", + "\n", + "def graphviz_widget_with_slider(dots: list[str], *, performance: bool = False) -> ipywidgets.VBox:\n", + " n_dots = len(dots)\n", + " graphviz_widget = GraphvizWidget()\n", + " graphviz_widget.dots = dots\n", + " graphviz_widget.performance = performance\n", + " slider_widget = ipywidgets.IntSlider(max=n_dots - 1, value=0)\n", + " ipywidgets.jslink((slider_widget, \"value\"), (graphviz_widget, \"index\"))\n", + "# play_widget = ipywidgets.Play(max=n_dots - 1, repeat=True, interval=4000)\n", + "# ipywidgets.jslink((slider_widget, \"value\"), (play_widget, \"value\"))\n", + "# top = ipywidgets.HBox([play_widget, slider_widget])\n", + " return ipywidgets.VBox([slider_widget, CounterWidget()])" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "graphviz_widget_with_slider([out.to_string(), out.to_string()])" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "import graphviz\n", + "from pytensor.printing import pydotprint\n", + "out = pydotprint(predict, str)\n", + "gv = graphviz.Source(out.to_string())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import Image\n", + "Image('./examples/mlp.png', width='80%')" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", + "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", + " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", + "<!-- Generated by graphviz version 8.1.0 (20230707.2238)\n", + " -->\n", + "<!-- Title: the holy hand grenade Pages: 1 -->\n", + "<svg width=\"332pt\" height=\"44pt\"\n", + " viewBox=\"0.00 0.00 332.00 44.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", + "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 40)\">\n", + "<title>the holy hand grenade</title>\n", + "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-40 328,-40 328,4 -4,4\"/>\n", + "<!-- 1 -->\n", + "<g id=\"node1\" class=\"node\">\n", + "<title>1</title>\n", + "<ellipse fill=\"none\" stroke=\"black\" cx=\"27\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n", + "<text text-anchor=\"middle\" x=\"27\" y=\"-13.32\" font-family=\"Times,serif\" font-size=\"14.00\">1</text>\n", + "</g>\n", + "<!-- 2 -->\n", + "<g id=\"node2\" class=\"node\">\n", + "<title>2</title>\n", + "<ellipse fill=\"none\" stroke=\"black\" cx=\"117\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n", + "<text text-anchor=\"middle\" x=\"117\" y=\"-13.32\" font-family=\"Times,serif\" font-size=\"14.00\">2</text>\n", + "</g>\n", + "<!-- 1->2 -->\n", + "<g id=\"edge1\" class=\"edge\">\n", + "<title>1->2</title>\n", + "<path fill=\"none\" stroke=\"black\" d=\"M54.4,-18C62.06,-18 70.57,-18 78.76,-18\"/>\n", + "<polygon fill=\"black\" stroke=\"black\" points=\"78.62,-21.5 88.62,-18 78.62,-14.5 78.62,-21.5\"/>\n", + "</g>\n", + "<!-- 3 -->\n", + "<g id=\"node3\" class=\"node\">\n", + "<title>3</title>\n", + "<ellipse fill=\"none\" stroke=\"black\" cx=\"207\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n", + "<text text-anchor=\"middle\" x=\"207\" y=\"-13.32\" font-family=\"Times,serif\" font-size=\"14.00\">3</text>\n", + "</g>\n", + "<!-- 2->3 -->\n", + "<g id=\"edge2\" class=\"edge\">\n", + "<title>2->3</title>\n", + "<path fill=\"none\" stroke=\"black\" d=\"M144.4,-18C152.06,-18 160.57,-18 168.76,-18\"/>\n", + "<polygon fill=\"black\" stroke=\"black\" points=\"168.62,-21.5 178.62,-18 168.62,-14.5 168.62,-21.5\"/>\n", + "</g>\n", + "<!-- lob -->\n", + "<g id=\"node4\" class=\"node\">\n", + "<title>lob</title>\n", + "<ellipse fill=\"none\" stroke=\"black\" cx=\"297\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n", + "<text text-anchor=\"middle\" x=\"297\" y=\"-13.32\" font-family=\"Times,serif\" font-size=\"14.00\">lob</text>\n", + "</g>\n", + "<!-- 3->lob -->\n", + "<g id=\"edge3\" class=\"edge\">\n", + "<title>3->lob</title>\n", + "<path fill=\"none\" stroke=\"black\" d=\"M234.4,-18C242.06,-18 250.57,-18 258.76,-18\"/>\n", + "<polygon fill=\"black\" stroke=\"black\" points=\"258.62,-21.5 268.62,-18 258.62,-14.5 258.62,-21.5\"/>\n", + "</g>\n", + "</g>\n", + "</svg>\n" + ], + "text/plain": [ + "<graphviz.sources.Source at 0x7f188a4d2690>" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import graphviz\n", + "src = graphviz.Source('digraph \"the holy hand grenade\" { rankdir=LR; 1 -> 2 -> 3 -> lob }')\n", + "src" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", + "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", + " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", + "<!-- Generated by graphviz version 8.1.0 (20230707.2238)\n", + " -->\n", + "<!-- Title: my_graph Pages: 1 -->\n", + "<svg width=\"66pt\" height=\"188pt\"\n", + " viewBox=\"0.00 0.00 66.22 188.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", + "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 184)\">\n", + "<title>my_graph</title>\n", + "<polygon fill=\"yellow\" stroke=\"none\" points=\"-4,4 -4,-184 62.22,-184 62.22,4 -4,4\"/>\n", + "<!-- a -->\n", + "<g id=\"node1\" class=\"node\">\n", + "<title>a</title>\n", + "<ellipse fill=\"none\" stroke=\"black\" cx=\"29.11\" cy=\"-162\" rx=\"29.11\" ry=\"18\"/>\n", + "<text text-anchor=\"middle\" x=\"29.11\" y=\"-157.32\" font-family=\"Times,serif\" font-size=\"14.00\">Foo</text>\n", + "</g>\n", + "<!-- b -->\n", + "<g id=\"node2\" class=\"node\">\n", + "<title>b</title>\n", + "<ellipse fill=\"none\" stroke=\"black\" cx=\"29.11\" cy=\"-90\" rx=\"18\" ry=\"18\"/>\n", + "<text text-anchor=\"middle\" x=\"29.11\" y=\"-85.33\" font-family=\"Times,serif\" font-size=\"14.00\">b</text>\n", + "</g>\n", + "<!-- a--b -->\n", + "<g id=\"edge1\" class=\"edge\">\n", + "<title>a--b</title>\n", + "<path fill=\"none\" stroke=\"blue\" d=\"M29.11,-143.7C29.11,-132.85 29.11,-118.92 29.11,-108.1\"/>\n", + "</g>\n", + "<!-- c -->\n", + "<g id=\"node3\" class=\"node\">\n", + "<title>c</title>\n", + "<ellipse fill=\"none\" stroke=\"black\" cx=\"29.11\" cy=\"-18\" rx=\"27\" ry=\"18\"/>\n", + "<text text-anchor=\"middle\" x=\"29.11\" y=\"-13.32\" font-family=\"Times,serif\" font-size=\"14.00\">c</text>\n", + "</g>\n", + "<!-- b--c -->\n", + "<g id=\"edge2\" class=\"edge\">\n", + "<title>b--c</title>\n", + "<path fill=\"none\" stroke=\"blue\" d=\"M29.11,-71.7C29.11,-60.85 29.11,-46.92 29.11,-36.1\"/>\n", + "</g>\n", + "</g>\n", + "</svg>\n" + ], + "text/plain": [ + "<graphviz.sources.Source at 0x7fb75041de50>" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pydot\n", + "import graphviz\n", + "\n", + "dot_string = \"\"\"graph my_graph {\n", + " bgcolor=\"yellow\";\n", + " a [label=\"Foo\"];\n", + " b [shape=circle];\n", + " a -- b -- c [color=blue];\n", + "}\"\"\"\n", + "\n", + "graphviz.Source(dot_string)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import io\n", + "file = io.StringIO()\n", + "file.endswith = lambda *args: True\n", + "pydotprint(predict, file)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "??pydotprint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "hide_input": false, + "kernelspec": { + "display_name": "pytensor", + "language": "python", + "name": "pytensor" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + }, + "varInspector": { + "cols": { + "lenName": 16, + "lenType": 16, + "lenVar": 40 + }, + "kernels_config": { + "python": { + "delete_cmd_postfix": "", + "delete_cmd_prefix": "del ", + "library": "var_list.py", + "varRefreshCmd": "print(var_dic_list())" + }, + "r": { + "delete_cmd_postfix": ") ", + "delete_cmd_prefix": "rm(", + "library": "var_list.r", + "varRefreshCmd": "cat(var_dic_list()) " + } + }, + "types_to_exclude": [ + "module", + "function", + "builtin_function_or_method", + "instance", + "_Feature" + ], + "window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/widget.css b/notebooks/widget.css new file mode 100644 index 0000000000..fe74d0dd81 --- /dev/null +++ b/notebooks/widget.css @@ -0,0 +1,6 @@ +.graphviz-container { + /* position: absolute; + top: 0; + left: 0; */ + /* height: 100px; */ +} diff --git a/notebooks/widget.js b/notebooks/widget.js new file mode 100644 index 0000000000..45e3f9cfa1 --- /dev/null +++ b/notebooks/widget.js @@ -0,0 +1,50 @@ +import * as d3 from "https://cdn.skypack.dev/pin/d3@v7.8.5-eC7TKxlFLay7fmsv0gvu/dist=es2020,mode=imports,min/optimized/d3.js"; +import "https://cdn.skypack.dev/-/d3-graphviz@v5.1.0-TcGnvMu4khUzCpL7Wr2k/dist=es2020,mode=imports,min/optimized/d3-graphviz.js"; + +export function render({ model, el }) { + if (!document.getElementById("graphviz_script")) { + const graphviz_script = document.createElement("script"); + graphviz_script.setAttribute("id", "graphviz_script"); + graphviz_script.setAttribute("src", "https://unpkg.com/@hpcc-js/wasm/dist/graphviz.umd.js"); + graphviz_script.setAttribute("type", "javascript/worker"); + + document.head.appendChild(graphviz_script); + } + // let getCount = () => model.get("dot"); + + const div = document.createElement("div"); + div.classList.add("graphviz-container"); + const graphContainer = d3.select(div); + + let setDot = () => { + const dots = model.get("dots"); + const performance = model.get("performance"); + // Use the last dots if there is no index + const dot = dots[model.get("index")]; + // const width = div.clientWidth; + // const height = div.clientHeight; + const graphviz = graphContainer + .graphviz({ + // Fit graph to that size, so that all is visible + fit: true, + // Set to be as big as container + // width, + // height, + // Don't animate transitions between shapes for performance + tweenPaths: !performance, + tweenShapes: !performance, + useWorker: true, + }) + .transition(() => d3.transition("t").duration(2000).ease(d3.easeLinear)) + .renderDot(dot); + // If we have made a zoom selection, reset that before transitioning + // TODO: figure out how to transition BOTH zoom and dot at once + // if (graphviz._zoomSelection) { + // graphviz.resetZoom(); + // } + }; + + model.on("change:dots change:index", setDot); + el.appendChild(div); + requestAnimationFrame(() => setTimeout(setDot, 0)); +} diff --git a/pytensor/graph/features.py b/pytensor/graph/features.py index 2bce7f1748..0233869e67 100644 --- a/pytensor/graph/features.py +++ b/pytensor/graph/features.py @@ -439,6 +439,170 @@ def revert(self, fgraph, checkpoint): self.history[fgraph] = h +class FullHistory(Feature): + """Keeps track of all changes in FunctionGraph and allows arbitrary back and forth through intermediate states + + .. testcode:: + import pytensor + import pytensor.tensor as pt + from pytensor.graph.fg import FunctionGraph + from pytensor.graph.features import History, FullHistory + from pytensor.graph.rewriting.utils import rewrite_graph + + x = pt.scalar("x") + out = pt.log(pt.exp(x) / pt.sum(pt.exp(x))) + + fg = FunctionGraph(outputs=[out]) + history = FullHistory() + fg.attach_feature(history) + + rewrite_graph(fg, clone=False, include=("canonicalize", "stabilize")) + + # Replay rewrites + history.start() + pytensor.dprint(fg) + pytensor.config.optimizer_verbose = True + for i in range(3): + print(">> ", end="") + pytensor.dprint(history.next()) + + .. testoutput:: + Log [id A] 4 + └─ True_div [id B] 3 + ├─ Exp [id C] 2 + │ └─ x [id D] + └─ Sum{axes=None} [id E] 1 + └─ Exp [id F] 0 + └─ x [id D] + >> MergeOptimizer + Log [id A] 3 + └─ True_div [id B] 2 + ├─ Exp [id C] 0 + │ └─ x [id D] + └─ Sum{axes=None} [id E] 1 + └─ Exp [id C] 0 + └─ ··· + >> local_mul_canonizer + Log [id A] 1 + └─ Softmax{axis=None} [id B] 0 + └─ x [id C] + >> local_logsoftmax + LogSoftmax{axis=None} [id A] 0 + └─ x [id B] + + + .. testcode:: + # Or in reverse + for i in range(3): + print(">> ", end="") + pytensor.dprint(history.prev()) + pytensor.config.optimizer_verbose = False + + + .. testoutput:: + >> local_logsoftmax + Log [id A] 1 + └─ Softmax{axis=None} [id B] 0 + └─ x [id C] + >> local_mul_canonizer + Log [id A] 3 + └─ True_div [id B] 2 + ├─ Exp [id C] 0 + │ └─ x [id D] + └─ Sum{axes=None} [id E] 1 + └─ Exp [id C] 0 + └─ ··· + >> MergeOptimizer + Log [id A] 4 + └─ True_div [id B] 3 + ├─ Exp [id C] 2 + │ └─ x [id D] + └─ Sum{axes=None} [id E] 1 + └─ Exp [id F] 0 + └─ x [id D] + + + .. testcode:: + # Or go to any step + pytensor.dprint(history.goto(2)) + + .. testoutput:: + Log [id A] 1 + └─ Softmax{axis=None} [id B] 0 + └─ x [id C] + + + """ + + def __init__(self): + self.fw = [] + self.bw = [] + self.pointer = -1 + self.fg = None + + def on_attach(self, fgraph): + if self.fg is not None: + raise ValueError("Full History already attached to another fgraph") + self.fg = fgraph + + def on_change_input(self, fgraph, node, i, r, new_r, reason=None): + self.bw.append(LambdaExtract(fgraph, node, i, r, reason)) + self.fw.append(LambdaExtract(fgraph, node, i, new_r, reason)) + self.pointer += 1 + + def goto(self, checkpoint): + """ + Reverts the graph to whatever it was at the provided + checkpoint (undoes all replacements). A checkpoint at any + given time can be obtained using self.checkpoint(). + + """ + history_len = len(self.bw) + pointer = self.pointer + assert 0 <= checkpoint <= history_len + verbose = config.optimizer_verbose + + # Go backwards + while pointer > checkpoint - 1: + reverse_fn = self.bw[pointer] + if verbose: + print(reverse_fn.reason) + reverse_fn() + pointer -= 1 + + # Go forward + while pointer < checkpoint - 1: + pointer += 1 + forward_fn = self.fw[pointer] + if verbose: + print(forward_fn.reason) + forward_fn() + + # Remove history changes caused by the foward/backward! + self.bw = self.bw[:history_len] + self.fw = self.fw[:history_len] + self.pointer = pointer + return self.fg + + def start(self): + return self.goto(0) + + def end(self): + return self.goto(len(self.bw)) + + def prev(self): + if self.pointer < 0: + return self.fg + else: + return self.goto(self.pointer) + + def next(self): + if self.pointer >= len(self.bw) - 1: + return self.fg + else: + return self.goto(self.pointer + 2) + + class Validator(Feature): pickle_rm_attr = ["validate", "consistent"] diff --git a/pytensor/printing.py b/pytensor/printing.py index 4d8cbb96da..298aef395e 100644 --- a/pytensor/printing.py +++ b/pytensor/printing.py @@ -1200,7 +1200,7 @@ def __call__(self, *args): def pydotprint( fct, - outfile: str | None = None, + outfile: Literal[str] | str | None = None, compact: bool = True, format: str = "png", with_ids: bool = False, @@ -1606,7 +1606,7 @@ def apply_name(node): g.add_subgraph(c2) g.add_subgraph(c3) - if not outfile.endswith("." + format): + if outfile is not str and not outfile.endswith("." + format): outfile += "." + format if scan_graphs: @@ -1640,6 +1640,9 @@ def apply_name(node): scan_graphs, ) + if outfile is str: + return g + if return_image: return g.create(prog="dot", format=format) else: diff --git a/tests/graph/test_features.py b/tests/graph/test_features.py index 475af20b57..e94efba3c8 100644 --- a/tests/graph/test_features.py +++ b/tests/graph/test_features.py @@ -1,7 +1,9 @@ import pytest -from pytensor.graph.basic import Apply, Variable -from pytensor.graph.features import Feature, NodeFinder, ReplaceValidate +import pytensor.tensor as pt +from pytensor.graph import rewrite_graph +from pytensor.graph.basic import Apply, Variable, equal_computations +from pytensor.graph.features import Feature, FullHistory, NodeFinder, ReplaceValidate from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op from pytensor.graph.type import Type @@ -119,3 +121,33 @@ def validate(self, *args): capres = capsys.readouterr() assert "rewriting: validate failed on node Op1.0" in capres.out + + +def test_full_history(): + x = pt.scalar("x") + out = pt.log(pt.exp(x) / pt.sum(pt.exp(x))) + fg = FunctionGraph(outputs=[out], clone=True, copy_inputs=False) + history = FullHistory() + fg.attach_feature(history) + rewrite_graph(fg, clone=False, include=("canonicalize", "stabilize")) + + history.start() + assert equal_computations(fg.outputs, [out]) + + history.end() + assert equal_computations(fg.outputs, [pt.special.log_softmax(x)]) + + history.prev() + assert equal_computations(fg.outputs, [pt.log(pt.special.softmax(x))]) + + for i in range(10): + history.prev() + assert equal_computations(fg.outputs, [out]) + + history.goto(2) + assert equal_computations(fg.outputs, [pt.log(pt.special.softmax(x))]) + + for i in range(10): + history.next() + + assert equal_computations(fg.outputs, [pt.special.log_softmax(x)])