diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 9677615206..831ab5d1bc 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -107,7 +107,7 @@ jobs: python-version: "3.13" include: - os: "ubuntu-latest" - part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link" + part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link --ignore=pytensor/ipython.py" python-version: "3.12" numpy-version: ">=2.0" fast-compile: 0 diff --git a/pyproject.toml b/pyproject.toml index 9a7827d83e..4201554054 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -118,7 +118,7 @@ versionfile_build = "pytensor/_version.py" tag_prefix = "rel-" [tool.pytest.ini_options] -addopts = "--durations=50 --doctest-modules --ignore=pytensor/link --ignore=pytensor/misc/check_duplicate_key.py" +addopts = "--durations=50 --doctest-modules --ignore=pytensor/link --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/ipython.py" testpaths = ["pytensor/", "tests/"] xfail_strict = true diff --git a/pytensor/graph/features.py b/pytensor/graph/features.py index 3521d3b1ba..7611a380bd 100644 --- a/pytensor/graph/features.py +++ b/pytensor/graph/features.py @@ -438,6 +438,172 @@ def revert(self, fgraph, checkpoint): self.history[fgraph] = h +class FullHistory(Feature): + """Keeps track of all changes in FunctionGraph and allows arbitrary back and forth through intermediate states + + .. testcode:: + import pytensor + import pytensor.tensor as pt + from pytensor.graph.fg import FunctionGraph + from pytensor.graph.features import FullHistory + from pytensor.graph.rewriting.utils import rewrite_graph + + x = pt.scalar("x") + out = pt.log(pt.exp(x) / pt.sum(pt.exp(x))) + + fg = FunctionGraph(outputs=[out]) + history = FullHistory() + fg.attach_feature(history) + + rewrite_graph(fg, clone=False, include=("canonicalize", "stabilize")) + + # Replay rewrites + history.start() + pytensor.dprint(fg) + with pytensor.config.change_flags(optimizer_verbose = True): + for i in range(3): + print(">> ", end="") + pytensor.dprint(history.next()) + + .. testoutput:: + Log [id A] 4 + └─ True_div [id B] 3 + ├─ Exp [id C] 2 + │ └─ x [id D] + └─ Sum{axes=None} [id E] 1 + └─ Exp [id F] 0 + └─ x [id D] + >> MergeOptimizer + Log [id A] 3 + └─ True_div [id B] 2 + ├─ Exp [id C] 0 + │ └─ x [id D] + └─ Sum{axes=None} [id E] 1 + └─ Exp [id C] 0 + └─ ··· + >> local_mul_canonizer + Log [id A] 1 + └─ Softmax{axis=None} [id B] 0 + └─ x [id C] + >> local_logsoftmax + LogSoftmax{axis=None} [id A] 0 + └─ x [id B] + + + .. testcode:: + # Or in reverse + with pytensor.config.change_flags(optimizer_verbose=True): + for i in range(3): + print(">> ", end="") + pytensor.dprint(history.prev()) + + .. testoutput:: + >> local_logsoftmax + Log [id A] 1 + └─ Softmax{axis=None} [id B] 0 + └─ x [id C] + >> local_mul_canonizer + Log [id A] 3 + └─ True_div [id B] 2 + ├─ Exp [id C] 0 + │ └─ x [id D] + └─ Sum{axes=None} [id E] 1 + └─ Exp [id C] 0 + └─ ··· + >> MergeOptimizer + Log [id A] 4 + └─ True_div [id B] 3 + ├─ Exp [id C] 2 + │ └─ x [id D] + └─ Sum{axes=None} [id E] 1 + └─ Exp [id F] 0 + └─ x [id D] + + + .. testcode:: + # Or go to any step + pytensor.dprint(history.goto(2)) + + .. testoutput:: + Log [id A] 1 + └─ Softmax{axis=None} [id B] 0 + └─ x [id C] + + + """ + + def __init__(self, callback=None): + self.fw = [] + self.bw = [] + self.pointer = -1 + self.fg = None + self.callback = callback + + def on_attach(self, fgraph): + if self.fg is not None: + raise ValueError("Full History already attached to another fgraph") + self.fg = fgraph + + def on_change_input(self, fgraph, node, i, r, new_r, reason=None): + self.bw.append(LambdaExtract(fgraph, node, i, r, reason)) + self.fw.append(LambdaExtract(fgraph, node, i, new_r, reason)) + self.pointer += 1 + if self.callback: + self.callback() + + def goto(self, checkpoint): + """ + Reverts the graph to whatever it was at the provided + checkpoint (undoes all replacements). A checkpoint at any + given time can be obtained using self.checkpoint(). + + """ + history_len = len(self.bw) + pointer = self.pointer + assert 0 <= checkpoint <= history_len + verbose = config.optimizer_verbose + + # Go backwards + while pointer > checkpoint - 1: + reverse_fn = self.bw[pointer] + if verbose: + print(reverse_fn.reason) # noqa: T201 + reverse_fn() + pointer -= 1 + + # Go forward + while pointer < checkpoint - 1: + pointer += 1 + forward_fn = self.fw[pointer] + if verbose: + print(forward_fn.reason) # noqa: T201 + forward_fn() + + # Remove history changes caused by the foward/backward! + self.bw = self.bw[:history_len] + self.fw = self.fw[:history_len] + self.pointer = pointer + return self.fg + + def start(self): + return self.goto(0) + + def end(self): + return self.goto(len(self.bw)) + + def prev(self): + if self.pointer < 0: + return self.fg + else: + return self.goto(self.pointer) + + def next(self): + if self.pointer >= len(self.bw) - 1: + return self.fg + else: + return self.goto(self.pointer + 2) + + class Validator(Feature): pickle_rm_attr = ["validate", "consistent"] diff --git a/pytensor/ipython.py b/pytensor/ipython.py new file mode 100644 index 0000000000..33adf5792d --- /dev/null +++ b/pytensor/ipython.py @@ -0,0 +1,202 @@ +import anywidget +import ipywidgets as widgets +import traitlets +from IPython.display import display + +from pytensor.graph import FunctionGraph, Variable, rewrite_graph +from pytensor.graph.features import FullHistory + + +class CodeBlockWidget(anywidget.AnyWidget): + """Widget that displays text content as a monospaced code block.""" + + content = traitlets.Unicode("").tag(sync=True) + + _esm = """ + function render({ model, el }) { + const pre = document.createElement("pre"); + pre.style.backgroundColor = "#f5f5f5"; + pre.style.padding = "10px"; + pre.style.borderRadius = "4px"; + pre.style.overflowX = "auto"; + pre.style.maxHeight = "500px"; + + const code = document.createElement("code"); + code.textContent = model.get("content"); + + pre.appendChild(code); + el.appendChild(pre); + + model.on("change:content", () => { + code.textContent = model.get("content"); + }); + } + export default { render }; + """ + + _css = """ + .jp-RenderedHTMLCommon pre { + font-family: monospace; + white-space: pre; + line-height: 1.4; + } + """ + + +class InteractiveRewrite: + """ + A class that wraps a graph history object with interactive widgets + to navigate through history and display the graph at each step. + + Includes an option to display the reason for the last change. + """ + + def __init__(self, fg, display_reason=True): + """ + Initialize with a history object that has a goto method + and tracks a FunctionGraph. + + Parameters: + ----------- + fg : FunctionGraph (or Variables) + The function graph to track + display_reason : bool, optional + Whether to display the reason for each rewrite + """ + self.history = FullHistory(callback=self._history_callback) + if not isinstance(fg, FunctionGraph): + outs = [fg] if isinstance(fg, Variable) else fg + fg = FunctionGraph(outputs=outs) + fg.attach_feature(self.history) + + self.updating_from_callback = False # Flag to prevent recursion + self.code_widget = CodeBlockWidget(content="") + self.display_reason = display_reason + + if self.display_reason: + self.reason_label = widgets.HTML( + value="", description="", style={"description_width": "initial"} + ) + self.slider_label = widgets.Label(value="") + self.slider = widgets.IntSlider( + value=self.history.pointer, + min=0, + max=0, + step=1, + description="", # Empty description since we're using a separate label + continuous_update=True, + layout=widgets.Layout(width="300px"), + ) + self.prev_button = widgets.Button(description="← Previous") + self.next_button = widgets.Button(description="Next →") + self.slider.observe(self._on_slider_change, names="value") + self.prev_button.on_click(self._on_prev_click) + self.next_button.on_click(self._on_next_click) + + self.rewrite_button = widgets.Button( + description="Apply Rewrites", + button_style="primary", # 'success', 'info', 'warning', 'danger' or '' + tooltip="Apply default rewrites to the current graph", + icon="cogs", # Optional: add an icon (requires font-awesome) + ) + self.rewrite_button.on_click(self._on_rewrite_click) + + self.nav_button_box = widgets.HBox([self.prev_button, self.next_button]) + self.slider_box = widgets.HBox([self.slider_label, self.slider]) + self.control_box = widgets.HBox([self.slider_box, self.rewrite_button]) + + # Update the display with the initial state + self._update_display() + + def _on_slider_change(self, change): + """Handle slider value changes""" + if change["name"] == "value" and not self.updating_from_callback: + self.updating_from_callback = True + index = change["new"] + self.history.goto(index) + self._update_display() + self.updating_from_callback = False + + def _on_prev_click(self, b): + """Go to previous history item""" + if self.slider.value > 0: + self.slider.value -= 1 + + def _on_next_click(self, b): + """Go to next history item""" + if self.slider.value < self.slider.max: + self.slider.value += 1 + + def _on_rewrite_click(self, b): + """Handle rewrite button click""" + self.slider.value = self.slider.max + self.rewrite() + + def display(self): + """Display the full widget interface""" + display( + widgets.VBox( + [ + self.control_box, + self.nav_button_box, + *((self.reason_label,) if self.display_reason else ()), + self.code_widget, + ] + ) + ) + + def _ipython_display_(self): + self.display() + + def _history_callback(self): + """Callback for history updates that prevents recursion""" + if not self.updating_from_callback: + self.updating_from_callback = True + self._update_display() + self.updating_from_callback = False + + def _update_display(self): + """Update the code widget with the current graph and reason""" + # Update the reason label if checkbox is checked + if self.display_reason: + if self.history.pointer == -1: + reason = "" + else: + reason = self.history.fw[self.history.pointer].reason + reason = getattr(reason, "name", str(reason)) + + self.reason_label.value = f""" +