From 16ea1711a2b96141297b1fb63c103367bba35766 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 17 Dec 2024 10:24:07 +0000 Subject: [PATCH 01/26] feat: Tracing prototype --- examples/tracing.ipynb | 528 ++++++++++++++++++++++ guppylang/decorator.py | 99 ++-- guppylang/definition/struct.py | 4 +- guppylang/definition/traced.py | 146 ++++++ guppylang/definition/ty.py | 8 +- guppylang/error.py | 8 +- guppylang/module.py | 22 +- guppylang/std/_internal/compiler/array.py | 50 ++ guppylang/tracing/__init__.py | 0 guppylang/tracing/function.py | 138 ++++++ guppylang/tracing/object.py | 256 +++++++++++ guppylang/tracing/state.py | 54 +++ guppylang/tracing/unpacking.py | 164 +++++++ guppylang/tracing/util.py | 63 +++ guppylang/tys/builtin.py | 8 + tests/integration/test_examples.py | 5 + 16 files changed, 1486 insertions(+), 67 deletions(-) create mode 100644 examples/tracing.ipynb create mode 100644 guppylang/definition/traced.py create mode 100644 guppylang/tracing/__init__.py create mode 100644 guppylang/tracing/function.py create mode 100644 guppylang/tracing/object.py create mode 100644 guppylang/tracing/state.py create mode 100644 guppylang/tracing/unpacking.py create mode 100644 guppylang/tracing/util.py diff --git a/examples/tracing.ipynb b/examples/tracing.ipynb new file mode 100644 index 00000000..f7509b73 --- /dev/null +++ b/examples/tracing.ipynb @@ -0,0 +1,528 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fccf55ad-3db2-40a4-89c7-8c5e08f3397c", + "metadata": {}, + "source": [ + "# Tracing Demo\n", + "\n", + "This notebook showcases the new experimental tracing mode in Guppy." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "8f121875-45f5-4cb8-aeef-2c56542a845e", + "metadata": {}, + "outputs": [], + "source": [ + "from guppylang import GuppyModule, qubit\n", + "from guppylang.decorator import guppy\n", + "from guppylang.std.builtins import array\n", + "from guppylang.std.quantum import h, cx, rz, measure\n", + "from guppylang.std.angles import angle, pi\n", + "from guppylang.std import quantum\n", + "from guppylang.std import angles\n", + "\n", + "from hugr.hugr.render import DotRenderer" + ] + }, + { + "cell_type": "markdown", + "id": "d03f3ee4-3ec7-43be-beae-8d832aef2e31", + "metadata": {}, + "source": [ + "## Intro\n", + "\n", + "Traced functions are executed using the Python interpreter which drives the Hugr generation. Thus, traced functions can contain arbitrary Python code, but everything is evaluated at compile-time. The result is a flat Hugr program:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "3404519a-e790-4028-87f8-a9d25cee9eab", + "metadata": {}, + "outputs": [], + "source": [ + "module = GuppyModule(\"test\")\n", + "module.load_all(quantum)\n", + "\n", + "@guppy.traced(module)\n", + "def ladder() -> array[qubit, 5]:\n", + " qs = [qubit() for _ in range(5)]\n", + " for q in qs:\n", + " h(q)\n", + " for q1, q2 in zip(qs[1:], qs[:-1]):\n", + " cx(q1, q2)\n", + " return qs\n", + "\n", + "DotRenderer().render(module.compile().module);" + ] + }, + { + "cell_type": "markdown", + "id": "427aea92-dc18-4094-850e-1c5a6a4338ca", + "metadata": {}, + "source": [ + "Traced functions can be called from regular Guppy functions and vice versa." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "9349ad6c-f131-48cc-941e-b259673b4138", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "raises-exception" + ] + }, + "outputs": [], + "source": [ + "module = GuppyModule(\"test\")\n", + "module.load_all(quantum)\n", + "\n", + "@guppy(module)\n", + "def regular1() -> tuple[qubit, qubit]:\n", + " q1 = qubit()\n", + " h(q1)\n", + " q2 = traced(q1)\n", + " return q1, q2\n", + "\n", + "@guppy.traced(module)\n", + "def traced(q: qubit) -> qubit:\n", + " r = regular2()\n", + " cx(q, r)\n", + " return r\n", + "\n", + "@guppy(module)\n", + "def regular2() -> qubit:\n", + " q = qubit()\n", + " h(q)\n", + " return q\n", + "\n", + "module.compile();" + ] + }, + { + "cell_type": "markdown", + "id": "0465ad90-9fa9-4e07-97fe-3e37266429e1", + "metadata": {}, + "source": [ + "Traced functions can even call out to non-Guppy functions and pass qubits along as data:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "1f5d84b6-40af-43c0-b229-80ab5af730d9", + "metadata": {}, + "outputs": [], + "source": [ + "module = GuppyModule(\"test\")\n", + "module.load_all(quantum)\n", + "\n", + "@guppy.traced(module)\n", + "def foo() -> qubit:\n", + " q = qubit()\n", + " bar(q)\n", + " return q\n", + "\n", + "def bar(q):\n", + " \"\"\"Regular Python function, no type annotations needed\"\"\"\n", + " h(q)\n", + "\n", + "module.compile();" + ] + }, + { + "cell_type": "markdown", + "id": "f1aca673-7796-4caf-928e-d19cdb03dfc0", + "metadata": {}, + "source": [ + "## Arithmetic\n", + "\n", + "Traced functions can do arbitrary arithmetic on their inputs:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "a60d6b2c-c2fe-4ea0-9905-1f9c1fbc3f2e", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "module = GuppyModule(\"test\")\n", + "module.load_all(quantum)\n", + "\n", + "@guppy.traced(module)\n", + "def foo(q: qubit, x: float) -> None:\n", + " x = x * 2\n", + " #print(x) # What is x?\n", + " rz(q, angle(x))\n", + "\n", + "module.compile();" + ] + }, + { + "cell_type": "markdown", + "id": "38e14bc8-15af-4787-b14b-acffe673ee7d", + "metadata": {}, + "source": [ + "However, we are not allowed to branch conditioned on the value of `x`:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "df4b2283-f948-44d5-bcd9-77d6b2b0698e", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "raises-exception" + ] + }, + "outputs": [ + { + "ename": "ValueError", + "evalue": "Branching on a dynamic value is not allowed during tracing. Try using a regular guppy function", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", + "Cell \u001b[0;32mIn[6], line 10\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m x \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m3\u001b[39m:\n\u001b[1;32m 8\u001b[0m rz(q, angle(x))\n\u001b[0;32m---> 10\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompile\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m;\n", + "Cell \u001b[0;32mIn[6], line 7\u001b[0m, in \u001b[0;36mfoo\u001b[0;34m(q, x)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;129m@guppy\u001b[39m\u001b[38;5;241m.\u001b[39mtraced(module)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfoo\u001b[39m(q: qubit, x: \u001b[38;5;28mfloat\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 6\u001b[0m x \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m2\u001b[39m\n\u001b[0;32m----> 7\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m x \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m3\u001b[39m:\n\u001b[1;32m 8\u001b[0m rz(q, angle(x))\n", + "\u001b[0;31mValueError\u001b[0m: Branching on a dynamic value is not allowed during tracing. Try using a regular guppy function" + ] + } + ], + "source": [ + "module = GuppyModule(\"test\")\n", + "module.load_all(quantum)\n", + "\n", + "@guppy.traced(module)\n", + "def foo(q: qubit, x: float) -> None:\n", + " x = x * 2\n", + " if x > 3:\n", + " rz(q, angle(x))\n", + "\n", + "module.compile();" + ] + }, + { + "cell_type": "markdown", + "id": "588055a5-eea0-4bdf-ac89-e2d1b9a4c292", + "metadata": {}, + "source": [ + "Similarly, we can't branch on measurement results inside the tracing context:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "0b4fc5ae-7ae5-4230-a640-0f38ddb167ad", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "raises-exception" + ] + }, + "outputs": [ + { + "ename": "ValueError", + "evalue": "Branching on a dynamic value is not allowed during tracing. Try using a regular guppy function", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", + "Cell \u001b[0;32mIn[7], line 9\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m measure(q):\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mXXX\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 9\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompile\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m;\n", + "Cell \u001b[0;32mIn[7], line 6\u001b[0m, in \u001b[0;36mfoo\u001b[0;34m(q)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;129m@guppy\u001b[39m\u001b[38;5;241m.\u001b[39mtraced(module)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfoo\u001b[39m(q: qubit) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m----> 6\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m measure(q):\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mXXX\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "\u001b[0;31mValueError\u001b[0m: Branching on a dynamic value is not allowed during tracing. Try using a regular guppy function" + ] + } + ], + "source": [ + "module = GuppyModule(\"test\")\n", + "module.load_all(quantum)\n", + "\n", + "@guppy.traced(module)\n", + "def foo(q: qubit) -> None:\n", + " if measure(q):\n", + " print(\"XXX\")\n", + "\n", + "module.compile();" + ] + }, + { + "cell_type": "markdown", + "id": "4171a7d1-7322-4e7e-8f2a-d85e3bc17645", + "metadata": {}, + "source": [ + "Also, we can't use regular inputs to control the size of registers:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "0bc00ed5-2b3f-473b-8ed1-932171a503be", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "raises-exception" + ] + }, + "outputs": [ + { + "ename": "TypeError", + "evalue": "'GuppyObject' object cannot be interpreted as an integer", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", + "Cell \u001b[0;32mIn[8], line 8\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;129m@guppy\u001b[39m\u001b[38;5;241m.\u001b[39mtraced(module)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfoo\u001b[39m(n: \u001b[38;5;28mint\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 6\u001b[0m qs \u001b[38;5;241m=\u001b[39m [qubit() \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(n)]\n\u001b[0;32m----> 8\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompile\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m;\n", + "Cell \u001b[0;32mIn[8], line 6\u001b[0m, in \u001b[0;36mfoo\u001b[0;34m(n)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;129m@guppy\u001b[39m\u001b[38;5;241m.\u001b[39mtraced(module)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfoo\u001b[39m(n: \u001b[38;5;28mint\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m----> 6\u001b[0m qs \u001b[38;5;241m=\u001b[39m [qubit() \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28;43mrange\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mn\u001b[49m\u001b[43m)\u001b[49m]\n", + "\u001b[0;31mTypeError\u001b[0m: 'GuppyObject' object cannot be interpreted as an integer" + ] + } + ], + "source": [ + "module = GuppyModule(\"test\")\n", + "module.load_all(quantum)\n", + "\n", + "@guppy.traced(module)\n", + "def foo(n: int) -> None:\n", + " qs = [qubit() for _ in range(n)]\n", + "\n", + "module.compile();" + ] + }, + { + "cell_type": "markdown", + "id": "6baa993a-d64d-42a6-a342-03f850721f56", + "metadata": {}, + "source": [ + "## Arrays and Lists\n", + "\n", + "Input arrays can be used like regular lists." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "16eb41ad-a3af-4731-8dab-c76b6d4273f6", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "module = GuppyModule(\"test\")\n", + "module.load_all(quantum)\n", + "\n", + "@guppy.declare(module)\n", + "def foo(qs: array[qubit, 10]) -> None: ...\n", + "\n", + "@guppy.traced(module)\n", + "def bar(qs: array[qubit, 10]) -> None:\n", + " # Arrays are iterable in the Python context\n", + " for q1, q2 in zip(qs[1:], qs[:-1]):\n", + " cx(q1, q2)\n", + " [start, *_, end] = qs\n", + " cx(start, end)\n", + "\n", + " # Regular Guppy functions can be called with Python lists\n", + " # as long as the lengths match up\n", + " rev = [q for q in reversed(qs)]\n", + " foo(rev)\n", + "\n", + "module.compile();" + ] + }, + { + "cell_type": "markdown", + "id": "c134a7e5-65b1-48fe-872d-d145b88ce75f", + "metadata": {}, + "source": [ + "## Safety" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "1b278240-35dc-479e-9a0b-2864368b27a9", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "raises-exception" + ] + }, + "outputs": [ + { + "ename": "ValueError", + "evalue": "Value with linear type `qubit` was already used\n\nPrevious use occurred in :6 as an argument to `cx`", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", + "Cell \u001b[0;32mIn[10], line 8\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;129m@guppy\u001b[39m\u001b[38;5;241m.\u001b[39mtraced(module)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mbad\u001b[39m(q: qubit) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 6\u001b[0m cx(q, q)\n\u001b[0;32m----> 8\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompile\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[10], line 6\u001b[0m, in \u001b[0;36mbad\u001b[0;34m(q)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;129m@guppy\u001b[39m\u001b[38;5;241m.\u001b[39mtraced(module)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mbad\u001b[39m(q: qubit) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m----> 6\u001b[0m \u001b[43mcx\u001b[49m\u001b[43m(\u001b[49m\u001b[43mq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mq\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[0;31mValueError\u001b[0m: Value with linear type `qubit` was already used\n\nPrevious use occurred in :6 as an argument to `cx`" + ] + } + ], + "source": [ + "module = GuppyModule(\"test\")\n", + "module.load_all(quantum)\n", + "\n", + "@guppy.traced(module)\n", + "def bad(q: qubit) -> None:\n", + " cx(q, q)\n", + "\n", + "module.compile()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "8d50bfe3-125f-4737-9a85-7aba03a53684", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "raises-exception" + ] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Error: Linearity violation in function return (at :5:0)\n", + " | \n", + "3 | \n", + "4 | @guppy.traced(module)\n", + "5 | def bad(q: qubit) -> qubit:\n", + " | ^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + "6 | return q\n", + " | ^^^^^^^^^^^^\n", + "\n", + "Value with linear type `qubit` was already used\n", + "\n", + "Previous use occurred in :8\n", + "\n", + "Guppy compilation failed due to 1 previous error\n" + ] + } + ], + "source": [ + "module = GuppyModule(\"test\")\n", + "module.load_all(quantum)\n", + "\n", + "@guppy.traced(module)\n", + "def bad(q: qubit) -> qubit:\n", + " return q\n", + "\n", + "module.compile()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "a1e3e917-ce2a-46cc-9a3c-cc980255d8c3", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "raises-exception" + ] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Error: Linearity violation in function return (at :5:0)\n", + " | \n", + "3 | \n", + "4 | @guppy.traced(module)\n", + "5 | def bad(q: qubit) -> None:\n", + " | ^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " | ...\n", + "7 | cx(tmp, q)\n", + " | ^^^^^^^^^^^^^^\n", + "\n", + "Value with linear type `qubit` is leaked by this function\n", + "\n", + "Guppy compilation failed due to 1 previous error\n" + ] + } + ], + "source": [ + "module = GuppyModule(\"test\")\n", + "module.load_all(quantum)\n", + "\n", + "@guppy.traced(module)\n", + "def bad(q: qubit) -> None:\n", + " tmp = qubit()\n", + " cx(tmp, q)\n", + "\n", + "module.compile()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "uv_guppy", + "language": "python", + "name": "uv_guppy" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/guppylang/decorator.py b/guppylang/decorator.py index 6d701b5a..dd10f467 100644 --- a/guppylang/decorator.py +++ b/guppylang/decorator.py @@ -13,7 +13,7 @@ import guppylang from guppylang.ast_util import annotate_location -from guppylang.definition.common import DefId, Definition +from guppylang.definition.common import DefId from guppylang.definition.const import RawConstDef from guppylang.definition.custom import ( CustomCallChecker, @@ -23,18 +23,15 @@ OpCompiler, RawCustomFunctionDef, ) -from guppylang.definition.declaration import RawFunctionDecl from guppylang.definition.extern import RawExternDef from guppylang.definition.function import ( CompiledFunctionDef, RawFunctionDef, ) from guppylang.definition.parameter import ConstVarDef, TypeVarDef -from guppylang.definition.pytket_circuits import ( - RawLoadPytketDef, - RawPytketDef, -) +from guppylang.definition.pytket_circuits import RawLoadPytketDef from guppylang.definition.struct import RawStructDef +from guppylang.definition.traced import RawTracedFunctionDef from guppylang.definition.ty import OpaqueTypeDef, TypeDef from guppylang.error import MissingModuleError, pretty_errors from guppylang.ipython_inspect import ( @@ -51,6 +48,7 @@ sphinx_running, ) from guppylang.span import Loc, SourceMap, Span +from guppylang.tracing.object import GuppyDefinition from guppylang.tys.arg import Argument from guppylang.tys.param import Parameter from guppylang.tys.subst import Inst @@ -60,13 +58,8 @@ T = TypeVar("T") Decorator = Callable[[S], T] -FuncDefDecorator = Decorator[PyFunc, RawFunctionDef] -FuncDeclDecorator = Decorator[PyFunc, RawFunctionDecl] -CustomFuncDecorator = Decorator[PyFunc, RawCustomFunctionDef] -PytketDecorator = Decorator[PyFunc, RawPytketDef] -ClassDecorator = Decorator[PyClass, PyClass] -OpaqueTypeDecorator = Decorator[PyClass, OpaqueTypeDef] -StructDecorator = Decorator[PyClass, RawStructDef] +FuncDecorator = Decorator[PyFunc, GuppyDefinition] +ClassDecorator = Decorator[PyClass, GuppyDefinition] _JUPYTER_NOTEBOOK_MODULE = "" @@ -101,21 +94,22 @@ def __init__(self) -> None: self._sources = SourceMap() @overload - def __call__(self, arg: PyFunc) -> RawFunctionDef: ... + def __call__(self, arg: PyFunc) -> GuppyDefinition: ... @overload - def __call__(self, arg: GuppyModule) -> FuncDefDecorator: ... + def __call__(self, arg: GuppyModule) -> FuncDecorator: ... @pretty_errors - def __call__(self, arg: PyFunc | GuppyModule) -> FuncDefDecorator | RawFunctionDef: + def __call__(self, arg: PyFunc | GuppyModule) -> FuncDecorator | GuppyDefinition: """Decorator to annotate Python functions as Guppy code. Optionally, the `GuppyModule` in which the function should be placed can be passed to the decorator. """ - def dec(f: Callable[..., Any], module: GuppyModule) -> RawFunctionDef: - return module.register_func_def(f) + def dec(f: Callable[..., Any], module: GuppyModule) -> GuppyDefinition: + defn = module.register_func_def(f) + return GuppyDefinition(defn) return self._with_optional_module(dec, arg) @@ -173,6 +167,15 @@ def _get_python_caller(self, fn: PyFunc | None = None) -> ModuleIdentifier: module_path, module.__name__ if module else module_path.name, module ) + @pretty_errors + def traced(self, arg: PyFunc | GuppyModule) -> FuncDecorator | GuppyDefinition: + def dec(f: Callable[..., Any], module: GuppyModule) -> GuppyDefinition: + defn = RawTracedFunctionDef(DefId.fresh(module), f.__name__, None, f, {}) + module.register_def(defn) + return GuppyDefinition(defn) + + return self._with_optional_module(dec, arg) + def init_module(self, import_builtins: bool = True) -> None: """Manually initialises a Guppy module for the current Python file. @@ -188,7 +191,7 @@ def init_module(self, import_builtins: bool = True) -> None: @pretty_errors def extend_type( self, defn: TypeDef, module: GuppyModule | None = None - ) -> ClassDecorator: + ) -> Callable[[type], type]: """Decorator to add new instance functions to a type.""" mod = module or self.get_module() mod._instance_func_buffer = {} @@ -208,7 +211,7 @@ def type( bound: ht.TypeBound | None = None, params: Sequence[Parameter] | None = None, module: GuppyModule | None = None, - ) -> OpaqueTypeDecorator: + ) -> ClassDecorator: """Decorator to annotate a class definitions as Guppy types. Requires the static Hugr translation of the type. Additionally, the type can be @@ -224,7 +227,7 @@ def type( mk_hugr_ty = (lambda _: hugr_ty) if isinstance(hugr_ty, ht.Type) else hugr_ty - def dec(c: type) -> OpaqueTypeDef: + def dec(c: type) -> GuppyDefinition: defn = OpaqueTypeDef( DefId.fresh(mod), name or c.__name__, @@ -236,14 +239,14 @@ def dec(c: type) -> OpaqueTypeDef: ) mod.register_def(defn) mod._register_buffered_instance_funcs(defn) - return defn + return GuppyDefinition(defn) return dec @property def struct( self, - ) -> Callable[[PyClass | GuppyModule], StructDecorator | RawStructDef]: + ) -> Callable[[PyClass | GuppyModule], ClassDecorator | GuppyDefinition]: """Decorator to define a new struct.""" # Note that this is a property. Thus, the code below is executed *before* # the members of the decorated class are executed. @@ -263,7 +266,7 @@ def struct( frame = get_calling_frame() python_scope = frame.f_globals | frame.f_locals if frame else {} - def dec(cls: type, module: GuppyModule) -> RawStructDef: + def dec(cls: type, module: GuppyModule) -> GuppyDefinition: defn = RawStructDef( DefId.fresh(module), cls.__name__, None, cls, python_scope ) @@ -276,9 +279,9 @@ def dec(cls: type, module: GuppyModule) -> RawStructDef: implicit_module._instance_func_buffer = None if not implicit_module_existed: self._modules.pop(caller_id) - return defn + return GuppyDefinition(defn) - def higher_dec(arg: GuppyModule | PyClass) -> StructDecorator | RawStructDef: + def higher_dec(arg: GuppyModule | PyClass) -> ClassDecorator | GuppyDefinition: if isinstance(arg, GuppyModule): arg._instance_func_buffer = {} return self._with_optional_module(dec, arg) @@ -317,7 +320,7 @@ def custom( higher_order_value: bool = True, name: str = "", module: GuppyModule | None = None, - ) -> CustomFuncDecorator: + ) -> FuncDecorator: """Decorator to add custom typing or compilation behaviour to function decls. Optionally, usage of the function as a higher-order value can be disabled. In @@ -326,7 +329,7 @@ def custom( """ mod = module or self.get_module() - def dec(f: PyFunc) -> RawCustomFunctionDef: + def dec(f: PyFunc) -> GuppyDefinition: call_checker = checker or DefaultCallChecker() func = RawCustomFunctionDef( DefId.fresh(mod), @@ -338,7 +341,7 @@ def dec(f: PyFunc) -> RawCustomFunctionDef: higher_order_value, ) mod.register_def(func) - return func + return GuppyDefinition(func) return dec @@ -349,7 +352,7 @@ def hugr_op( higher_order_value: bool = True, name: str = "", module: GuppyModule | None = None, - ) -> CustomFuncDecorator: + ) -> FuncDecorator: """Decorator to annotate function declarations as HUGR ops. Args: @@ -364,22 +367,23 @@ def hugr_op( return self.custom(OpCompiler(op), checker, higher_order_value, name, module) @overload - def declare(self, arg: GuppyModule) -> RawFunctionDecl: ... + def declare(self, arg: GuppyModule) -> GuppyDefinition: ... @overload - def declare(self, arg: PyFunc) -> FuncDeclDecorator: ... + def declare(self, arg: PyFunc) -> FuncDecorator: ... - def declare(self, arg: GuppyModule | PyFunc) -> FuncDeclDecorator | RawFunctionDecl: + def declare(self, arg: GuppyModule | PyFunc) -> FuncDecorator | GuppyDefinition: """Decorator to declare functions""" - def dec(f: Callable[..., Any], module: GuppyModule) -> RawFunctionDecl: - return module.register_func_decl(f) + def dec(f: Callable[..., Any], module: GuppyModule) -> GuppyDefinition: + defn = module.register_func_decl(f) + return GuppyDefinition(defn) return self._with_optional_module(dec, arg) def constant( self, name: str, ty: str, value: hv.Value, module: GuppyModule | None = None - ) -> RawConstDef: + ) -> GuppyDefinition: """Adds a constant to a module, backed by a `hugr.val.Value`.""" module = module or self.get_module() type_ast = _parse_expr_string( @@ -387,7 +391,7 @@ def constant( ) defn = RawConstDef(DefId.fresh(module), name, None, type_ast, value) module.register_def(defn) - return defn + return GuppyDefinition(defn) def extern( self, @@ -396,7 +400,7 @@ def extern( symbol: str | None = None, constant: bool = True, module: GuppyModule | None = None, - ) -> RawExternDef: + ) -> GuppyDefinition: """Adds an extern symbol to a module.""" module = module or self.get_module() type_ast = _parse_expr_string( @@ -406,7 +410,7 @@ def extern( DefId.fresh(module), name, None, symbol or name, constant, type_ast ) module.register_def(defn) - return defn + return GuppyDefinition(defn) def load(self, m: ModuleType | GuppyModule) -> None: caller = self._get_python_caller() @@ -434,10 +438,10 @@ def get_module( elif id.name == _JUPYTER_NOTEBOOK_MODULE: globs = get_ipython_globals() if globs: - defs: dict[str, Definition | ModuleType] = {} + defs: dict[str, GuppyDefinition | ModuleType] = {} for x, value in globs.items(): - if isinstance(value, Definition): - other_module = value.id.module + if isinstance(value, GuppyDefinition): + other_module = value.wrapped.id.module if other_module and other_module != module: defs[x] = value elif isinstance(value, ModuleType): @@ -485,7 +489,7 @@ def registered_modules(self) -> KeysView[ModuleIdentifier]: @pretty_errors def pytket( self, input_circuit: Any, module: GuppyModule | None = None - ) -> PytketDecorator: + ) -> FuncDecorator: """Adds a pytket circuit function definition with explicit signature.""" err_msg = "Only pytket circuits can be passed to guppy.pytket" try: @@ -499,15 +503,16 @@ def pytket( mod = module or self.get_module() - def func(f: PyFunc) -> RawPytketDef: - return mod.register_pytket_func(f, input_circuit) + def func(f: PyFunc) -> GuppyDefinition: + defn = mod.register_pytket_func(f, input_circuit) + return GuppyDefinition(defn) return func @pretty_errors def load_pytket( self, name: str, input_circuit: Any, module: GuppyModule | None = None - ) -> RawLoadPytketDef: + ) -> GuppyDefinition: """Adds a pytket circuit function definition with implicit signature.""" err_msg = "Only pytket circuits can be passed to guppy.load_pytket" try: @@ -523,7 +528,7 @@ def load_pytket( span = _find_load_call(self._sources) defn = RawLoadPytketDef(DefId.fresh(module), name, None, span, input_circuit) mod.register_def(defn) - return defn + return GuppyDefinition(defn) class _GuppyDummy: diff --git a/guppylang/definition/struct.py b/guppylang/definition/struct.py index ccbc3a51..c888e9b1 100644 --- a/guppylang/definition/struct.py +++ b/guppylang/definition/struct.py @@ -18,7 +18,6 @@ CheckableDef, CompiledDef, DefId, - Definition, ParsableDef, UnknownSourceError, ) @@ -33,6 +32,7 @@ from guppylang.error import GuppyError, InternalGuppyError from guppylang.ipython_inspect import find_ipython_def, is_running_ipython from guppylang.span import SourceMap +from guppylang.tracing.object import GuppyDefinition from guppylang.tys.arg import Argument from guppylang.tys.param import Parameter, check_all_args from guppylang.tys.parsing import type_from_ast @@ -134,7 +134,7 @@ def parse(self, globals: Globals, sources: SourceMap) -> "ParsedStructDef": # Ensure that all function definitions are Guppy functions case _, ast.FunctionDef(name=name) as node: v = getattr(self.python_class, name) - if not isinstance(v, Definition): + if not isinstance(v, GuppyDefinition): raise GuppyError(NonGuppyMethodError(node, self.name, name)) used_func_names[name] = node if name in used_field_names: diff --git a/guppylang/definition/traced.py b/guppylang/definition/traced.py new file mode 100644 index 00000000..25556b2f --- /dev/null +++ b/guppylang/definition/traced.py @@ -0,0 +1,146 @@ +import ast +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any + +import hugr.build.function as hf +import hugr.tys as ht +from hugr import Wire +from hugr.build.dfg import DefinitionBuilder, OpVar +from hugr.package import FuncDefnPointer + +from guppylang.ast_util import AstNode, with_loc +from guppylang.checker.core import Context, Globals, PyScope +from guppylang.checker.expr_checker import ( + check_call, + synthesize_call, +) +from guppylang.checker.func_checker import ( + check_signature, +) +from guppylang.compiler.core import CompiledGlobals, DFContainer +from guppylang.definition.common import ( + CompilableDef, + ParsableDef, +) +from guppylang.definition.function import parse_py_func +from guppylang.definition.value import CallableDef, CallReturnWires, CompiledCallableDef +from guppylang.nodes import GlobalCall +from guppylang.span import SourceMap +from guppylang.tys.subst import Inst, Subst +from guppylang.tys.ty import FunctionType, Type, type_to_row + +PyFunc = Callable[..., Any] + + +@dataclass(frozen=True) +class RawTracedFunctionDef(ParsableDef): + python_func: PyFunc + python_scope: PyScope + + description: str = field(default="function", init=False) + + def parse(self, globals: Globals, sources: SourceMap) -> "TracedFunctionDef": + """Parses and checks the user-provided signature of the function.""" + func_ast, _docstring = parse_py_func(self.python_func, sources) + ty = check_signature(func_ast, globals.with_python_scope(self.python_scope)) + return TracedFunctionDef( + self.id, self.name, func_ast, ty, self.python_func, self.python_scope + ) + + def compile(self) -> FuncDefnPointer: + from guppylang.decorator import guppy + + return guppy.compile_function(self) + + +@dataclass(frozen=True) +class TracedFunctionDef(RawTracedFunctionDef, CallableDef, CompilableDef): + python_func: PyFunc + ty: FunctionType + + def check_call( + self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context + ) -> tuple[ast.expr, Subst]: + """Checks the return type of a function call against a given type.""" + # Use default implementation from the expression checker + args, subst, inst = check_call(self.ty, args, ty, node, ctx) + node = with_loc(node, GlobalCall(def_id=self.id, args=args, type_args=inst)) + return node, subst + + def synthesize_call( + self, args: list[ast.expr], node: AstNode, ctx: Context + ) -> tuple[ast.expr, Type]: + """Synthesizes the return type of a function call.""" + # Use default implementation from the expression checker + args, ty, inst = synthesize_call(self.ty, args, node, ctx) + node = with_loc(node, GlobalCall(def_id=self.id, args=args, type_args=inst)) + return node, ty + + def compile_outer( + self, module: DefinitionBuilder[OpVar] + ) -> "CompiledTracedFunctionDef": + """Adds a Hugr `FuncDefn` node for this function to the Hugr. + + Note that we don't compile the function body at this point since we don't have + access to the other compiled functions yet. The body is compiled later in + `CompiledFunctionDef.compile_inner()`. + """ + func_type = self.ty.to_hugr_poly() + func_def = module.define_function( + self.name, func_type.body.input, func_type.body.output, func_type.params + ) + return CompiledTracedFunctionDef( + self.id, + self.name, + self.defined_at, + self.ty, + self.python_func, + self.python_scope, + func_def, + ) + + +@dataclass(frozen=True) +class CompiledTracedFunctionDef(TracedFunctionDef, CompiledCallableDef): + func_def: hf.Function + + def load_with_args( + self, + type_args: Inst, + dfg: DFContainer, + globals: CompiledGlobals, + node: AstNode, + ) -> Wire: + """Loads the function as a value into a local Hugr dataflow graph.""" + func_ty: ht.FunctionType = self.ty.instantiate(type_args).to_hugr() + type_args: list[ht.TypeArg] = [arg.to_hugr() for arg in type_args] + return dfg.builder.load_function(self.func_def, func_ty, type_args) + + def compile_call( + self, + args: list[Wire], + type_args: Inst, + dfg: DFContainer, + globals: CompiledGlobals, + node: AstNode, + ) -> CallReturnWires: + """Compiles a call to the function.""" + func_ty: ht.FunctionType = self.ty.instantiate(type_args).to_hugr() + type_args: list[ht.TypeArg] = [arg.to_hugr() for arg in type_args] + num_returns = len(type_to_row(self.ty.output)) + call = dfg.builder.call( + self.func_def, *args, instantiation=func_ty, type_args=type_args + ) + return CallReturnWires( + regular_returns=list(call[:num_returns]), + inout_returns=list(call[num_returns:]), + ) + + def compile_inner(self, globals: CompiledGlobals) -> None: + """Compiles the body of the function by tracing it.""" + from guppylang.tracing.function import trace_function + + trace_function( + self.python_func, self.ty, self.func_def, globals, self.defined_at + ) diff --git a/guppylang/definition/ty.py b/guppylang/definition/ty.py index 123b99c6..e50fe4fd 100644 --- a/guppylang/definition/ty.py +++ b/guppylang/definition/ty.py @@ -1,7 +1,7 @@ from abc import abstractmethod from collections.abc import Callable, Sequence from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from hugr import tys @@ -41,12 +41,6 @@ class OpaqueTypeDef(TypeDef, CompiledDef): to_hugr: Callable[[Sequence[Argument]], tys.Type] bound: tys.TypeBound | None = None - def __getitem__(self, item: Any) -> "OpaqueTypeDef": - """Dummy implementation to allow generic instantiations in type signatures that - are evaluated by the Python interpreter. - """ - return self - def check_instantiate( self, args: Sequence[Argument], globals: "Globals", loc: AstNode | None = None ) -> OpaqueType: diff --git a/guppylang/error.py b/guppylang/error.py index d4f24edf..e60bdc2e 100644 --- a/guppylang/error.py +++ b/guppylang/error.py @@ -56,9 +56,13 @@ def ipython_excepthook( ) -> Any: return hook(etype, value, tb) - ipython_shell.set_custom_exc((GuppyError,), ipython_excepthook) + old_hook = ipython_shell.CustomTB + old_exc_tuple = ipython_shell.custom_exceptions + ipython_shell.set_custom_exc((Exception,), ipython_excepthook) yield - ipython_shell.set_custom_exc((), None) + ipython_shell.set_custom_exc( + old_exc_tuple, lambda shell, *args, **kwargs: old_hook(*args, **kwargs) + ) except NameError: pass else: diff --git a/guppylang/module.py b/guppylang/module.py index 64c02294..c73302b9 100644 --- a/guppylang/module.py +++ b/guppylang/module.py @@ -32,6 +32,7 @@ from guppylang.definition.ty import TypeDef from guppylang.error import pretty_errors from guppylang.experimental import enable_experimental_features +from guppylang.tracing.object import GuppyDefinition if TYPE_CHECKING: from hugr import Hugr, ops @@ -130,11 +131,11 @@ def load( ] while imports: alias, imp = imports.pop() - if isinstance(imp, Definition): - module = imp.id.module + if isinstance(imp, GuppyDefinition): + module = imp.wrapped.id.module assert module is not None module.check() - names[alias or imp.name] = imp.id + names[alias or imp.wrapped.name] = imp.wrapped.id modules.add(module) elif isinstance(imp, GuppyModule): imp.check() @@ -182,7 +183,7 @@ def load_all(self, mod: GuppyModule | ModuleType) -> None: mod.check() self.load( *( - defn + GuppyDefinition(defn) for defn in mod._globals.defs.values() if not defn.name.startswith("_") ) @@ -351,11 +352,14 @@ def compile(self) -> ModulePointer: graph.metadata["name"] = self.name # Lower definitions to Hugr - ctx = CompiledGlobals( - checked_defs, graph, self._imported_globals | self._globals - ) - for defn in self._checked_defs.values(): - ctx.compile(defn) + from guppylang.tracing.state import set_tracing_globals + + with set_tracing_globals(self._globals | self._imported_globals): + ctx = CompiledGlobals( + checked_defs, graph, self._imported_globals | self._globals + ) + for defn in self._checked_defs.values(): + ctx.compile(defn) # TODO: Currently we just include a hardcoded list of extensions. We should # compute this dynamically from the imported dependencies instead. diff --git a/guppylang/std/_internal/compiler/array.py b/guppylang/std/_internal/compiler/array.py index b6d90b12..ea1048b2 100644 --- a/guppylang/std/_internal/compiler/array.py +++ b/guppylang/std/_internal/compiler/array.py @@ -2,6 +2,8 @@ from __future__ import annotations +from typing import TYPE_CHECKING, TypeVar + import hugr.std from hugr import Wire, ops from hugr import tys as ht @@ -19,6 +21,10 @@ ) from guppylang.tys.arg import ConstArg, TypeArg +if TYPE_CHECKING: + from hugr.build.dfg import DfBase + + # ------------------------------------------------------ # --------------- std.array operations ----------------- # ------------------------------------------------------ @@ -72,6 +78,26 @@ def array_set(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp: ) +def array_pop(elem_ty: ht.Type, length: int, from_left: bool) -> ops.ExtOp: + """Returns an operation that pops an element from the left of an array.""" + assert length > 0 + length_arg = ht.BoundedNatArg(length) + arr_ty = array_type(elem_ty, length_arg) + popped_arr_ty = array_type(elem_ty, ht.BoundedNatArg(length - 1)) + op = "pop_left" if from_left else "pop_right" + return _instantiate_array_op( + op, elem_ty, length_arg, [arr_ty], [ht.Option(elem_ty, popped_arr_ty)] + ) + + +def array_discard_empty(elem_ty: ht.Type) -> ops.ExtOp: + """Returns an operation that discards an array of length zero.""" + arr_ty = array_type(elem_ty, ht.BoundedNatArg(0)) + return hugr.std.PRELUDE.get_op("discard_empty").instantiate( + [ht.TypeTypeArg(elem_ty)], ht.FunctionType([arr_ty], []) + ) + + def array_map(elem_ty: ht.Type, length: ht.TypeArg, new_elem_ty: ht.Type) -> ops.ExtOp: """Returns an operation that maps a function across an array.""" # TODO @@ -97,6 +123,30 @@ def array_repeat(elem_ty: ht.Type, length: ht.TypeArg) -> ops.ExtOp: # ------------------------------------------------------ +P = TypeVar("P", bound=ops.DfParentOp) + + +def unpack_array(builder: DfBase[P], array: Wire) -> list[Wire]: + """ """ + # TODO: This should be an op + array_ty = builder.hugr.port_type(array.out_port()) + assert isinstance(array_ty, ht.ExtType) + err = "Internal error: array unpacking failed" + match array_ty.args: + case [ht.BoundedNatArg(length), ht.TypeTypeArg(elem_ty)]: + elems = [] + for i in range(length): + res = builder.add_op( + array_pop(elem_ty, length - i, from_left=True), array + ) + elem, array = build_unwrap(builder, res, err) + elems.append(elem) + builder.add_op(array_discard_empty(elem_ty), array) + return elems + case _: + raise InternalGuppyError("Invalid array type args") + + class ArrayCompiler(CustomCallCompiler): """Base class for custom array op compilers.""" diff --git a/guppylang/tracing/__init__.py b/guppylang/tracing/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/guppylang/tracing/function.py b/guppylang/tracing/function.py new file mode 100644 index 00000000..a9ce4696 --- /dev/null +++ b/guppylang/tracing/function.py @@ -0,0 +1,138 @@ +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, ClassVar + +from hugr import ops +from hugr.build.dfg import DfBase + +from guppylang.ast_util import AstNode, with_loc, with_type +from guppylang.cfg.builder import tmp_vars +from guppylang.checker.core import Context, Locals, Variable +from guppylang.checker.errors.type_errors import TypeMismatchError +from guppylang.compiler.core import CompiledGlobals, DFContainer +from guppylang.compiler.expr_compiler import ExprCompiler +from guppylang.definition.value import CompiledCallableDef +from guppylang.diagnostic import Error +from guppylang.error import GuppyError, exception_hook +from guppylang.nodes import PlaceNode +from guppylang.tracing.object import GuppyObject +from guppylang.tracing.state import ( + TracingState, + get_tracing_globals, + get_tracing_state, + set_tracing_state, +) +from guppylang.tracing.unpacking import ( + P, + guppy_object_from_py, + repack_guppy_object, + unpack_guppy_object, + update_packed_value, +) +from guppylang.tracing.util import tracing_except_hook +from guppylang.tys.ty import FunctionType, InputFlags, type_to_row + + +@dataclass(frozen=True) +class TracingReturnLinearityViolationError(Error): + title: ClassVar[str] = "Linearity violation in function return" + message: ClassVar[str] = "{msg}" + msg: str + + +def trace_function( + python_func: Callable[..., Any], + ty: FunctionType, + builder: DfBase[P], + globals: CompiledGlobals, + node: AstNode, +) -> None: + """Kicks off tracing of a function.""" + state = TracingState(globals, DFContainer(builder, {}), node) + with set_tracing_state(state): + inputs = [ + unpack_guppy_object(GuppyObject(inp.ty, wire), builder) + for wire, inp in zip(builder.inputs(), ty.inputs, strict=False) + ] + + with exception_hook(tracing_except_hook): + py_out = python_func(*inputs) + + try: + out_obj = repack_guppy_object(py_out, builder) + except ValueError as err: + # Linearity violation in the return statement + raise GuppyError( + TracingReturnLinearityViolationError(node, str(err)) + ) from None + + # Check that the output type is correct + if out_obj._ty != ty.output: + err = TypeMismatchError(node, ty.output, out_obj._ty, "return value") + raise GuppyError(err) + + # Unpack regular returns + out_tys = type_to_row(out_obj._ty) + if len(out_tys) > 1: + regular_returns = list( + builder.add_op(ops.UnpackTuple(), out_obj._use_wire(None)).outputs() + ) + elif len(out_tys) > 0: + regular_returns = [out_obj._use_wire(None)] + else: + regular_returns = [] + + # Compute the inout extra outputs + try: + inout_returns = [ + repack_guppy_object(inout_obj, builder)._use_wire(None) + for inout_obj, inp in zip(inputs, ty.inputs, strict=False) + if InputFlags.Inout in inp.flags + ] + except ValueError as err: + raise GuppyError( + TracingReturnLinearityViolationError(node, str(err)) + ) from None + + # Check that all allocated linear objects have been used + if state.unused_objs: + unused = state.allocated_objs[state.unused_objs.pop()] + err = f"Value with linear type `{unused._ty}` is leaked by this function" + raise GuppyError(TracingReturnLinearityViolationError(node, err)) from None + + builder.set_outputs(*regular_returns, *inout_returns) + + +def trace_call(func: CompiledCallableDef, *args: Any) -> Any: + state = get_tracing_state() + globals = get_tracing_globals() + + # Try to turn args into `GuppyObjects` + args_objs = [ + guppy_object_from_py(arg, state.dfg.builder, state.node) for arg in args + ] + + # Create dummy variables and bind the objects to them + arg_vars = [Variable(next(tmp_vars), obj._ty, None) for obj in args_objs] + locals = Locals({var.name: var.ty for var in arg_vars}) + for obj, var in zip(args_objs, arg_vars, strict=False): + state.dfg[var] = obj._use_wire(func) + + # Check call + arg_exprs = [ + with_loc(func.defined_at, with_type(var.ty, PlaceNode(var))) for var in arg_vars + ] + call_node, ret_ty = func.synthesize_call( + arg_exprs, func.defined_at, Context(globals, locals, {}) + ) + + # Compile call + ret_wire = ExprCompiler(state.globals).compile(call_node, state.dfg) + + # Update inouts + for inp, arg, var in zip(func.ty.inputs, args, arg_vars, strict=False): + if InputFlags.Inout in inp.flags: + inout_wire = state.dfg[var] + update_packed_value(arg, GuppyObject(inp.ty, inout_wire), state.dfg.builder) + + return GuppyObject(ret_ty, ret_wire) diff --git a/guppylang/tracing/object.py b/guppylang/tracing/object.py new file mode 100644 index 00000000..1b252e77 --- /dev/null +++ b/guppylang/tracing/object.py @@ -0,0 +1,256 @@ +import inspect +import itertools +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, NamedTuple + +from hugr import Wire, ops + +from guppylang.definition.common import DefId, Definition +from guppylang.definition.function import RawFunctionDef +from guppylang.definition.traced import CompiledTracedFunctionDef +from guppylang.definition.ty import TypeDef +from guppylang.definition.value import CompiledCallableDef, CompiledValueDef +from guppylang.ipython_inspect import find_ipython_def, is_running_ipython +from guppylang.span import Span +from guppylang.tracing.state import get_tracing_globals, get_tracing_state +from guppylang.tracing.util import get_calling_frame, hide_trace +from guppylang.tys.ty import TupleType, Type + + +class GetAttrDunders(ABC): + @abstractmethod + def __getattr__(self, item): ... + + def __abs__(self, other): + return self.__getattr__("__abs__")(other) + + def __add__(self, other): + return self.__getattr__("__add__")(other) + + def __and__(self, other): + return self.__getattr__("__and__")(other) + + def __bool__(self): + return self.__getattr__("__bool__")() + + def __ceil__(self): + return self.__getattr__("__bool__")() + + def __divmod__(self, other): + return self.__getattr__("__divmod__")(other) + + def __eq__(self, other): + return self.__getattr__("__eq__")(other) + + def __float__(self): + return self.__getattr__("__bool__")() + + def __floor__(self): + return self.__getattr__("__floor__")() + + def __floordiv__(self, other): + return self.__getattr__("__floordiv__")(other) + + def __ge__(self, other): + return self.__getattr__("__ge__")(other) + + def __gt__(self, other): + return self.__getattr__("__gt__")(other) + + def __int__(self): + return self.__getattr__("__int__")() + + def __invert__(self): + return self.__getattr__("__invert__")() + + def __le__(self, other): + return self.__getattr__("__le__")(other) + + def __lshift__(self, other): + return self.__getattr__("__lshift__")(other) + + def __lt__(self, other): + return self.__getattr__("__lt__")(other) + + def __mod__(self, other): + return self.__getattr__("__mod__")(other) + + def __mul__(self, other): + return self.__getattr__("__mul__")(other) + + def __ne__(self, other): + return self.__getattr__("__ne__")(other) + + def __neg__(self): + return self.__getattr__("__neg__")() + + def __or__(self, other): + return self.__getattr__("__or__")(other) + + def __pos__(self): + return self.__getattr__("__pos__")() + + def __pow__(self, other): + return self.__getattr__("__pow__")(other) + + def __sub__(self, other): + return self.__getattr__("__sub__")(other) + + def __truediv__(self, other): + return self.__getattr__("__truediv__")(other) + + def __trunc__(self): + return self.__getattr__("__trunc__")() + + def __xor__(self, other): + return self.__getattr__("__xor__")(other) + + +class ObjectUse(NamedTuple): + """Records a use of a linear `GuppyObject`.""" + + module: str + lineno: int + called_func: CompiledCallableDef | None + + +ObjectId = int + +fresh_id = itertools.count() + + +class GuppyObject(GetAttrDunders): + """The runtime representation of abstract Guppy objects during tracing.""" + + _ty: Type + _wire: Wire + _used: ObjectUse | None + _id: ObjectId + + def __init__(self, ty: Type, wire: Wire, used: Span | None = None) -> None: + self._ty = ty + self._wire = wire + self._used = used + self._id = next(fresh_id) + state = get_tracing_state() + state.allocated_objs[self._id] = self + if ty.linear and not self._used: + state.unused_objs.add(self._id) + + @hide_trace + def __getattr__(self, name: str): + globals = get_tracing_globals() + func = globals.get_instance_func(self._ty, name) + if func is None: + raise AttributeError( + f"Expression of type `{self._ty}` has no attribute `{name}`" + ) + return lambda *xs: GuppyDefinition(func)(self, *xs) + + @hide_trace + def __bool__(self): + err = ( + "Branching on a dynamic value is not allowed during tracing. Try using " + "a regular guppy function" + ) + raise ValueError(err) + + @hide_trace + def __iter__(self): + state = get_tracing_state() + builder = state.dfg.builder + if isinstance(self._ty, TupleType): + unpack = builder.add_op(ops.UnpackTuple(), self._use_wire(None)) + return ( + GuppyObject(ty, wire) + for ty, wire in zip( + self._ty.element_types, unpack.outputs(), strict=False + ) + ) + raise TypeError(f"Expression of type `{self._ty}` is not iterable") + + def _use_wire(self, called_func: CompiledCallableDef | None) -> Wire: + # Panic if the value has already been used + if self._used: + use = self._used + err = ( + f"Value with linear type `{self._ty}` was already used\n\n" + f"Previous use occurred in {use.module}:{use.lineno}" + ) + if use.called_func: + err += f" as an argument to `{use.called_func.name}`" + raise ValueError(err) + # Otherwise, mark it as used + else: + frame = get_calling_frame() + assert frame is not None + if is_running_ipython(): + module_name = "" + if defn := find_ipython_def(frame.f_code.co_name): + module_name = f"<{defn.cell_name}>" + else: + module = inspect.getmodule(frame) + module_name = module.__file__ if module else "???" + self._used = ObjectUse(module_name, frame.f_lineno, called_func) + if self._ty.linear: + state = get_tracing_state() + state.unused_objs.remove(self._id) + return self._wire + + +@dataclass(frozen=True) +class GuppyDefinition: + """A top-level Guppy definition. + + This is the object that is returned to the users when they decorate a function or + class. In particular, this is the version of the definition that is available during + tracing. + """ + + wrapped: Definition + + @property + def id(self) -> DefId: + return self.wrapped.id + + @hide_trace + def __call__(self, *args: Any) -> Any: + from guppylang.tracing.function import trace_call + + state = get_tracing_state() + defn = state.globals.build_compiled_def(self.wrapped.id) + if isinstance(defn, CompiledTracedFunctionDef): + return defn.python_func(*args) + elif isinstance(defn, CompiledCallableDef): + return trace_call(defn, *args) + elif isinstance(defn, TypeDef): + globals = get_tracing_globals() + if defn.id in globals.impls and "__new__" in globals.impls[defn.id]: + constructor = globals.defs[globals.impls[defn.id]["__new__"]] + return GuppyDefinition(constructor)(*args) + err = f"{defn.description.capitalize()} `{defn.name}` is not callable" + raise TypeError(err) + + def __getitem__(self, item: Any) -> Any: + return self + + def to_guppy_object(self) -> GuppyObject: + state = get_tracing_state() + defn = state.globals.build_compiled_def(self.id) + if isinstance(defn, CompiledValueDef): + wire = defn.load(state.dfg, state.globals, state.node) + return GuppyObject(defn.ty, wire, None) + elif isinstance(defn, TypeDef): + globals = get_tracing_globals() + if defn.id in globals.impls and "__new__" in globals.impls[defn.id]: + constructor = globals.defs[globals.impls[defn.id]["__new__"]] + return GuppyDefinition(constructor).to_guppy_object() + err = f"{defn.description.capitalize()} `{defn.name}` is not a value" + raise TypeError(err) + + def compile(self) -> Any: + from guppylang.decorator import guppy + + assert isinstance(self.wrapped, RawFunctionDef) + return guppy.compile_function(self.wrapped) diff --git a/guppylang/tracing/state.py b/guppylang/tracing/state.py new file mode 100644 index 00000000..f27b5c6b --- /dev/null +++ b/guppylang/tracing/state.py @@ -0,0 +1,54 @@ +from contextlib import contextmanager +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from guppylang.ast_util import AstNode +from guppylang.checker.core import Globals +from guppylang.compiler.core import CompiledGlobals, DFContainer + +if TYPE_CHECKING: + from guppylang.tracing.object import GuppyObject, ObjectId + + +@dataclass +class TracingState: + globals: CompiledGlobals + dfg: DFContainer + node: AstNode + + allocated_objs: "dict[ObjectId, GuppyObject]" = field(default_factory=dict) + unused_objs: "set[ObjectId]" = field(default_factory=set) + + +_STATE: TracingState | None = None +_GLOBALS: Globals | None = None + + +def get_tracing_state() -> TracingState: + if _STATE is None: + raise RuntimeError("Guppy tracing mode is not active") + return _STATE + + +def get_tracing_globals() -> Globals: + if _GLOBALS is None: + raise RuntimeError("Guppy tracing mode is not active") + return _GLOBALS + + +@contextmanager +def set_tracing_state(state: TracingState) -> None: + global _STATE + old_state = _STATE + _STATE = state + yield + _STATE = old_state + + +@contextmanager +def set_tracing_globals(globals: Globals) -> None: + global _GLOBALS + old_globals = _GLOBALS + _GLOBALS = globals + yield + _GLOBALS = old_globals diff --git a/guppylang/tracing/unpacking.py b/guppylang/tracing/unpacking.py new file mode 100644 index 00000000..03649824 --- /dev/null +++ b/guppylang/tracing/unpacking.py @@ -0,0 +1,164 @@ +from typing import Any, TypeVar + +from hugr import ops +from hugr import tys as ht +from hugr.build.dfg import DfBase + +from guppylang.ast_util import AstNode +from guppylang.checker.errors.py_errors import IllegalPyExpressionError +from guppylang.checker.expr_checker import python_value_to_guppy_type +from guppylang.compiler.expr_compiler import python_value_to_hugr +from guppylang.error import GuppyError, InternalGuppyError +from guppylang.std._internal.compiler.array import array_new, unpack_array +from guppylang.std._internal.compiler.prelude import build_unwrap +from guppylang.tracing.object import GuppyDefinition, GuppyObject +from guppylang.tracing.state import get_tracing_globals, get_tracing_state +from guppylang.tys.builtin import ( + array_type, + get_array_length, + get_element_type, + is_array_type, +) +from guppylang.tys.const import ConstValue +from guppylang.tys.ty import NoneType, TupleType + +P = TypeVar("P", bound=ops.DfParentOp) + + +def unpack_guppy_object(obj: GuppyObject, builder: DfBase[P]) -> Any: + """Tries to turn as much of the structure of a GuppyObject into Python objects. + + For example, Guppy tuples are turned into Python tuples and Guppy arrays are turned + into Python lists. + """ + match obj._ty: + case NoneType(): + return None + case TupleType(element_types=tys): + unpack = builder.add_op(ops.UnpackTuple(), obj._use_wire(None)) + return tuple( + unpack_guppy_object(GuppyObject(ty, wire), builder) + for ty, wire in zip(tys, unpack.outputs(), strict=False) + ) + case ty if is_array_type(ty): + length = get_array_length(ty) + if isinstance(length, ConstValue): + if length.value == 0: + # Zero-length lists cannot be turned back ito Guppy objects since + # there is no way to infer the type. Therefore, we should leave + # them as Guppy objects here + return obj + elem_ty = get_element_type(ty) + opt_elems = unpack_array(builder, obj._use_wire(None)) + err = "Linear array element has already been used" + elems = [build_unwrap(builder, opt_elem, err) for opt_elem in opt_elems] + return [ + unpack_guppy_object(GuppyObject(elem_ty, wire), builder) + for wire in elems + ] + else: + # Cannot handle generic sizes + return obj + case _: + return obj + + +def repack_guppy_object(v: Any, builder: DfBase[P]) -> GuppyObject: + """Undoes the `unpack_guppy_object` operation.""" + match v: + case GuppyObject() as obj: + return obj + case None: + return GuppyObject(NoneType(), builder.add_op(ops.MakeTuple())) + case tuple(vs): + objs = [repack_guppy_object(v, builder) for v in vs] + return GuppyObject( + TupleType([obj._ty for obj in objs]), + builder.add_op(ops.MakeTuple(), *(obj._use_wire(None) for obj in objs)), + ) + case list(vs) if len(vs) > 0: + objs = [repack_guppy_object(v, builder) for v in vs] + elem_ty = objs[0]._ty + hugr_elem_ty = ht.Option(elem_ty.to_hugr()) + wires = [ + builder.add_op(ops.Tag(1, hugr_elem_ty), obj._use_wire(None)) + for obj in objs + ] + return GuppyObject( + array_type(elem_ty, len(vs)), + builder.add_op(array_new(hugr_elem_ty, len(vs)), *wires), + ) + case _: + raise InternalGuppyError( + "Can only repack values that were constructed via " + "`unpack_guppy_object`" + ) + + +def update_packed_value(v: Any, obj: "GuppyObject", builder: DfBase[P]) -> None: + """Given a Python value `v` and a `GuppyObject` `obj` that was constructed from `v` + using `guppy_object_from_py`, updates the wires of any `GuppyObjects` contained in + `v` to the new wires specified by `obj`. + + Also resets the used flag on any of those updated wires. + """ + match v: + case GuppyObject() as v_obj: + assert v_obj._ty == obj._ty + v_obj._wire = obj._use_wire(None) + if v_obj._ty.linear and v_obj._used: + state = get_tracing_state() + state.unused_objs.add(v_obj._id) + v_obj._used = None + case None: + assert isinstance(obj._ty, NoneType) + case tuple(vs): + assert isinstance(obj._ty, TupleType) + wires = builder.add_op(ops.UnpackTuple(), obj._use_wire(None)).outputs() + for v, ty, wire in zip(vs, obj._ty.element_types, wires, strict=True): + update_packed_value(v, GuppyObject(ty, wire), builder) + case list(vs) if len(vs) > 0: + assert is_array_type(obj._ty) + elem_ty = get_element_type(obj._ty) + opt_wires = unpack_array(builder, obj._use_wire(None)) + err = "Linear array element has already been used" + for v, opt_wire in zip(vs, opt_wires, strict=True): + wire = build_unwrap(builder, opt_wire, err) + update_packed_value(v, GuppyObject(elem_ty, wire), builder) + case _: + pass + + +def guppy_object_from_py(v: Any, builder: DfBase[P], node: AstNode) -> GuppyObject: + match v: + case GuppyObject() as obj: + return obj + case GuppyDefinition() as defn: + return defn.to_guppy_object() + case None: + return GuppyObject(NoneType(), builder.add_op(ops.MakeTuple())) + case tuple(vs): + objs = [guppy_object_from_py(v, builder, node) for v in vs] + return GuppyObject( + TupleType([obj._ty for obj in objs]), + builder.add_op(ops.MakeTuple(), *(obj._use_wire(None) for obj in objs)), + ) + case list(vs) if len(vs) > 0: + objs = [guppy_object_from_py(v, builder, node) for v in vs] + elem_ty = objs[0]._ty + hugr_elem_ty = ht.Option(elem_ty.to_hugr()) + wires = [ + builder.add_op(ops.Tag(1, hugr_elem_ty), obj._use_wire(None)) + for obj in objs + ] + return GuppyObject( + array_type(elem_ty, len(vs)), + builder.add_op(array_new(hugr_elem_ty, len(vs)), *wires), + ) + case v: + ty = python_value_to_guppy_type(v, node, get_tracing_globals()) + if ty is None: + raise GuppyError(IllegalPyExpressionError(node, type(v))) + hugr_val = python_value_to_hugr(v, ty) + assert hugr_val is not None + return GuppyObject(ty, builder.load(hugr_val)) diff --git a/guppylang/tracing/util.py b/guppylang/tracing/util.py new file mode 100644 index 00000000..7b65ecd8 --- /dev/null +++ b/guppylang/tracing/util.py @@ -0,0 +1,63 @@ +import functools +import inspect +import sys +from types import FrameType, TracebackType + +from guppylang.error import GuppyError, exception_hook + + +def hide_trace(f): + """Function decorator that hides compiler-internal frames from the traceback of any + exception thrown by the decorated function.""" + + @functools.wraps(f) + def wrapped(*args, **kwargs): + with exception_hook(tracing_except_hook): + return f(*args, **kwargs) + + return wrapped + + +def tracing_except_hook( + excty: type[BaseException], err: BaseException, traceback: TracebackType | None +): + """Except hook that removes all compiler-internal frames from the traceback.""" + if isinstance(err, GuppyError): + diagnostic = err.error + msg = diagnostic.rendered_title + if diagnostic.span_label: + msg += f": {diagnostic.rendered_span_label}" + if diagnostic.message: + msg += f"\n{diagnostic.rendered_message}" + err = RuntimeError(msg) + excty = RuntimeError + + traceback = remove_internal_frames(traceback) + try: + # Check if we're inside a jupyter notebook since it uses its own exception + # hook. If we're in a regular interpreter, this line will raise a `NameError` + ipython_shell = get_ipython() # type: ignore[name-defined] + ipython_shell.excepthook(excty, err, traceback) + except NameError: + sys.__excepthook__(excty, err, traceback) + + +def get_calling_frame() -> FrameType | None: + """Finds the first frame that called this function outside the compiler.""" + frame = inspect.currentframe() + while frame: + module = inspect.getmodule(frame) + if module is None or not module.__name__.startswith("guppylang."): + return frame + frame = frame.f_back + return None + + +def remove_internal_frames(tb: TracebackType | None) -> TracebackType | None: + if tb: + module = inspect.getmodule(tb.tb_frame) + if module is not None and module.__name__.startswith("guppylang."): + return remove_internal_frames(tb.tb_next) + if tb.tb_next: + tb.tb_next = remove_internal_frames(tb.tb_next) + return tb diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index 86630c28..c5450809 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -255,6 +255,14 @@ def get_element_type(ty: Type) -> Type: return arg.ty +def get_array_length(ty: Type) -> Const: + assert isinstance(ty, OpaqueType) + assert ty.defn == array_type_def + [_, length_arg] = ty.args + assert isinstance(length_arg, ConstArg) + return length_arg.const + + def get_iter_size(ty: Type) -> Const: assert isinstance(ty, OpaqueType) assert ty.defn == sized_iter_type_def diff --git a/tests/integration/test_examples.py b/tests/integration/test_examples.py index 18ca8efc..2076cf0d 100644 --- a/tests/integration/test_examples.py +++ b/tests/integration/test_examples.py @@ -8,6 +8,11 @@ def test_demo_notebook(nb_regression): nb_regression.check("examples/demo.ipynb") +def test_tracing_notebook(nb_regression): + nb_regression.diff_ignore += ("/metadata/language_info/version",) + nb_regression.check("examples/tracing.ipynb") + + def test_random_walk_qpe(validate): from examples.random_walk_qpe import hugr From 6bc58909f0d9ad650a578ede2ea9f3301fd233f4 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 6 Jan 2025 10:40:12 +0100 Subject: [PATCH 02/26] Fix mypy issues --- guppylang/checker/expr_checker.py | 4 +- guppylang/decorator.py | 4 +- guppylang/definition/traced.py | 1 + guppylang/module.py | 6 +- guppylang/std/_internal/compiler/array.py | 2 +- guppylang/std/_internal/compiler/prelude.py | 2 +- guppylang/tracing/function.py | 23 ++++--- guppylang/tracing/object.py | 73 ++++++++++----------- guppylang/tracing/state.py | 5 +- guppylang/tracing/unpacking.py | 2 +- guppylang/tracing/util.py | 11 +++- 11 files changed, 74 insertions(+), 59 deletions(-) diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index 9824e4c1..b9da89f6 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -1165,7 +1165,7 @@ def eval_py_expr(node: PyExpr, ctx: Context) -> Any: def python_value_to_guppy_type( - v: Any, node: ast.expr, globals: Globals, type_hint: Type | None = None + v: Any, node: ast.AST, globals: Globals, type_hint: Type | None = None ) -> Type | None: """Turns a primitive Python value into a Guppy type. @@ -1208,7 +1208,7 @@ def python_value_to_guppy_type( def _python_list_to_guppy_type( - vs: list[Any], node: ast.expr, globals: Globals, type_hint: Type | None + vs: list[Any], node: ast.AST, globals: Globals, type_hint: Type | None ) -> OpaqueType | None: """Turns a Python list into a Guppy type. diff --git a/guppylang/decorator.py b/guppylang/decorator.py index dd10f467..e5a91821 100644 --- a/guppylang/decorator.py +++ b/guppylang/decorator.py @@ -466,7 +466,9 @@ def compile_module(self, id: ModuleIdentifier | None = None) -> ModulePointer: raise MissingModuleError(err) return module.compile() - def compile_function(self, f_def: RawFunctionDef) -> FuncDefnPointer: + def compile_function( + self, f_def: RawFunctionDef | RawTracedFunctionDef + ) -> FuncDefnPointer: """Compiles a single function definition.""" module = f_def.id.module if not module: diff --git a/guppylang/definition/traced.py b/guppylang/definition/traced.py index 25556b2f..05bbc12d 100644 --- a/guppylang/definition/traced.py +++ b/guppylang/definition/traced.py @@ -58,6 +58,7 @@ def compile(self) -> FuncDefnPointer: class TracedFunctionDef(RawTracedFunctionDef, CallableDef, CompilableDef): python_func: PyFunc ty: FunctionType + defined_at: ast.FunctionDef def check_call( self, args: list[ast.expr], ty: Type, node: AstNode, ctx: Context diff --git a/guppylang/module.py b/guppylang/module.py index c73302b9..0aff893b 100644 --- a/guppylang/module.py +++ b/guppylang/module.py @@ -108,8 +108,8 @@ def __init__(self, name: str, import_builtins: bool = True): def load( self, - *args: Definition | GuppyModule | ModuleType, - **kwargs: Definition | GuppyModule | ModuleType, + *args: GuppyDefinition | GuppyModule | ModuleType, + **kwargs: GuppyDefinition | GuppyModule | ModuleType, ) -> None: """Imports another Guppy module or selected definitions from a module. @@ -125,7 +125,7 @@ def load( names: dict[str, DefId] = {} # Collect imports in reverse since we'll use it as a stack to push and pop from - imports: list[tuple[str, Definition | GuppyModule | ModuleType]] = [ + imports: list[tuple[str, GuppyDefinition | GuppyModule | ModuleType]] = [ *reversed(kwargs.items()), *(("", arg) for arg in reversed(args)), ] diff --git a/guppylang/std/_internal/compiler/array.py b/guppylang/std/_internal/compiler/array.py index ac9666e3..a7ea5853 100644 --- a/guppylang/std/_internal/compiler/array.py +++ b/guppylang/std/_internal/compiler/array.py @@ -151,7 +151,7 @@ def unpack_array(builder: DfBase[P], array: Wire) -> list[Wire]: err = "Internal error: array unpacking failed" match array_ty.args: case [ht.BoundedNatArg(length), ht.TypeTypeArg(elem_ty)]: - elems = [] + elems: list[Wire] = [] for i in range(length): res = builder.add_op( array_pop(elem_ty, length - i, from_left=True), array diff --git a/guppylang/std/_internal/compiler/prelude.py b/guppylang/std/_internal/compiler/prelude.py index 518ea706..1a979de1 100644 --- a/guppylang/std/_internal/compiler/prelude.py +++ b/guppylang/std/_internal/compiler/prelude.py @@ -124,7 +124,7 @@ def build_unwrap_left( def build_unwrap( - builder: DfBase[ops.DfParentOp], option: Wire, error_msg: str, error_signal: int = 1 + builder: DfBase[P], option: Wire, error_msg: str, error_signal: int = 1 ) -> Node: """Unwraps an `hugr.tys.Option` value, panicking with the given message if the result is an error. diff --git a/guppylang/tracing/function.py b/guppylang/tracing/function.py index a9ce4696..79db2b90 100644 --- a/guppylang/tracing/function.py +++ b/guppylang/tracing/function.py @@ -1,6 +1,6 @@ from collections.abc import Callable from dataclasses import dataclass -from typing import Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar from hugr import ops from hugr.build.dfg import DfBase @@ -32,6 +32,11 @@ from guppylang.tracing.util import tracing_except_hook from guppylang.tys.ty import FunctionType, InputFlags, type_to_row +if TYPE_CHECKING: + import ast + + from hugr import Wire + @dataclass(frozen=True) class TracingReturnLinearityViolationError(Error): @@ -68,13 +73,14 @@ def trace_function( # Check that the output type is correct if out_obj._ty != ty.output: - err = TypeMismatchError(node, ty.output, out_obj._ty, "return value") - raise GuppyError(err) + raise GuppyError( + TypeMismatchError(node, ty.output, out_obj._ty, "return value") + ) # Unpack regular returns out_tys = type_to_row(out_obj._ty) if len(out_tys) > 1: - regular_returns = list( + regular_returns: list[Wire] = list( builder.add_op(ops.UnpackTuple(), out_obj._use_wire(None)).outputs() ) elif len(out_tys) > 0: @@ -97,8 +103,8 @@ def trace_function( # Check that all allocated linear objects have been used if state.unused_objs: unused = state.allocated_objs[state.unused_objs.pop()] - err = f"Value with linear type `{unused._ty}` is leaked by this function" - raise GuppyError(TracingReturnLinearityViolationError(node, err)) from None + msg = f"Value with linear type `{unused._ty}` is leaked by this function" + raise GuppyError(TracingReturnLinearityViolationError(node, msg)) from None builder.set_outputs(*regular_returns, *inout_returns) @@ -106,6 +112,7 @@ def trace_function( def trace_call(func: CompiledCallableDef, *args: Any) -> Any: state = get_tracing_state() globals = get_tracing_globals() + assert func.defined_at is not None # Try to turn args into `GuppyObjects` args_objs = [ @@ -114,12 +121,12 @@ def trace_call(func: CompiledCallableDef, *args: Any) -> Any: # Create dummy variables and bind the objects to them arg_vars = [Variable(next(tmp_vars), obj._ty, None) for obj in args_objs] - locals = Locals({var.name: var.ty for var in arg_vars}) + locals = Locals({var.name: var for var in arg_vars}) for obj, var in zip(args_objs, arg_vars, strict=False): state.dfg[var] = obj._use_wire(func) # Check call - arg_exprs = [ + arg_exprs: list[ast.expr] = [ with_loc(func.defined_at, with_type(var.ty, PlaceNode(var))) for var in arg_vars ] call_node, ret_ty = func.synthesize_call( diff --git a/guppylang/tracing/object.py b/guppylang/tracing/object.py index 1b252e77..3712ea9d 100644 --- a/guppylang/tracing/object.py +++ b/guppylang/tracing/object.py @@ -12,7 +12,6 @@ from guppylang.definition.ty import TypeDef from guppylang.definition.value import CompiledCallableDef, CompiledValueDef from guppylang.ipython_inspect import find_ipython_def, is_running_ipython -from guppylang.span import Span from guppylang.tracing.state import get_tracing_globals, get_tracing_state from guppylang.tracing.util import get_calling_frame, hide_trace from guppylang.tys.ty import TupleType, Type @@ -20,90 +19,90 @@ class GetAttrDunders(ABC): @abstractmethod - def __getattr__(self, item): ... + def __getattr__(self, item: Any) -> Any: ... - def __abs__(self, other): + def __abs__(self, other: Any) -> Any: return self.__getattr__("__abs__")(other) - def __add__(self, other): + def __add__(self, other: Any) -> Any: return self.__getattr__("__add__")(other) - def __and__(self, other): + def __and__(self, other: Any) -> Any: return self.__getattr__("__and__")(other) - def __bool__(self): + def __bool__(self: Any) -> Any: return self.__getattr__("__bool__")() - def __ceil__(self): + def __ceil__(self: Any) -> Any: return self.__getattr__("__bool__")() - def __divmod__(self, other): + def __divmod__(self, other: Any) -> Any: return self.__getattr__("__divmod__")(other) - def __eq__(self, other): + def __eq__(self, other: object) -> Any: return self.__getattr__("__eq__")(other) - def __float__(self): + def __float__(self) -> Any: return self.__getattr__("__bool__")() - def __floor__(self): + def __floor__(self) -> Any: return self.__getattr__("__floor__")() - def __floordiv__(self, other): + def __floordiv__(self, other: Any) -> Any: return self.__getattr__("__floordiv__")(other) - def __ge__(self, other): + def __ge__(self, other: Any) -> Any: return self.__getattr__("__ge__")(other) - def __gt__(self, other): + def __gt__(self, other: Any) -> Any: return self.__getattr__("__gt__")(other) - def __int__(self): + def __int__(self) -> Any: return self.__getattr__("__int__")() - def __invert__(self): + def __invert__(self) -> Any: return self.__getattr__("__invert__")() - def __le__(self, other): + def __le__(self, other: Any) -> Any: return self.__getattr__("__le__")(other) - def __lshift__(self, other): + def __lshift__(self, other: Any) -> Any: return self.__getattr__("__lshift__")(other) - def __lt__(self, other): + def __lt__(self, other: Any) -> Any: return self.__getattr__("__lt__")(other) - def __mod__(self, other): + def __mod__(self, other: Any) -> Any: return self.__getattr__("__mod__")(other) - def __mul__(self, other): + def __mul__(self, other: Any) -> Any: return self.__getattr__("__mul__")(other) - def __ne__(self, other): + def __ne__(self, other: object) -> Any: return self.__getattr__("__ne__")(other) - def __neg__(self): + def __neg__(self) -> Any: return self.__getattr__("__neg__")() - def __or__(self, other): + def __or__(self, other: Any) -> Any: return self.__getattr__("__or__")(other) - def __pos__(self): + def __pos__(self) -> Any: return self.__getattr__("__pos__")() - def __pow__(self, other): + def __pow__(self, other: Any) -> Any: return self.__getattr__("__pow__")(other) - def __sub__(self, other): + def __sub__(self, other: Any) -> Any: return self.__getattr__("__sub__")(other) - def __truediv__(self, other): + def __truediv__(self, other: Any) -> Any: return self.__getattr__("__truediv__")(other) - def __trunc__(self): + def __trunc__(self) -> Any: return self.__getattr__("__trunc__")() - def __xor__(self, other): + def __xor__(self, other: Any) -> Any: return self.__getattr__("__xor__")(other) @@ -128,7 +127,7 @@ class GuppyObject(GetAttrDunders): _used: ObjectUse | None _id: ObjectId - def __init__(self, ty: Type, wire: Wire, used: Span | None = None) -> None: + def __init__(self, ty: Type, wire: Wire, used: ObjectUse | None = None) -> None: self._ty = ty self._wire = wire self._used = used @@ -139,17 +138,17 @@ def __init__(self, ty: Type, wire: Wire, used: Span | None = None) -> None: state.unused_objs.add(self._id) @hide_trace - def __getattr__(self, name: str): + def __getattr__(self, key: str) -> Any: # type: ignore[misc] globals = get_tracing_globals() - func = globals.get_instance_func(self._ty, name) + func = globals.get_instance_func(self._ty, key) if func is None: raise AttributeError( - f"Expression of type `{self._ty}` has no attribute `{name}`" + f"Expression of type `{self._ty}` has no attribute `{key}`" ) return lambda *xs: GuppyDefinition(func)(self, *xs) @hide_trace - def __bool__(self): + def __bool__(self) -> Any: err = ( "Branching on a dynamic value is not allowed during tracing. Try using " "a regular guppy function" @@ -157,7 +156,7 @@ def __bool__(self): raise ValueError(err) @hide_trace - def __iter__(self): + def __iter__(self) -> Any: state = get_tracing_state() builder = state.dfg.builder if isinstance(self._ty, TupleType): @@ -191,7 +190,7 @@ def _use_wire(self, called_func: CompiledCallableDef | None) -> Wire: module_name = f"<{defn.cell_name}>" else: module = inspect.getmodule(frame) - module_name = module.__file__ if module else "???" + module_name = module.__file__ if module and module.__file__ else "???" self._used = ObjectUse(module_name, frame.f_lineno, called_func) if self._ty.linear: state = get_tracing_state() diff --git a/guppylang/tracing/state.py b/guppylang/tracing/state.py index f27b5c6b..020919e0 100644 --- a/guppylang/tracing/state.py +++ b/guppylang/tracing/state.py @@ -1,3 +1,4 @@ +from collections.abc import Iterator from contextlib import contextmanager from dataclasses import dataclass, field from typing import TYPE_CHECKING @@ -37,7 +38,7 @@ def get_tracing_globals() -> Globals: @contextmanager -def set_tracing_state(state: TracingState) -> None: +def set_tracing_state(state: TracingState) -> Iterator[None]: global _STATE old_state = _STATE _STATE = state @@ -46,7 +47,7 @@ def set_tracing_state(state: TracingState) -> None: @contextmanager -def set_tracing_globals(globals: Globals) -> None: +def set_tracing_globals(globals: Globals) -> Iterator[None]: global _GLOBALS old_globals = _GLOBALS _GLOBALS = globals diff --git a/guppylang/tracing/unpacking.py b/guppylang/tracing/unpacking.py index 03649824..0bd6d219 100644 --- a/guppylang/tracing/unpacking.py +++ b/guppylang/tracing/unpacking.py @@ -123,7 +123,7 @@ def update_packed_value(v: Any, obj: "GuppyObject", builder: DfBase[P]) -> None: opt_wires = unpack_array(builder, obj._use_wire(None)) err = "Linear array element has already been used" for v, opt_wire in zip(vs, opt_wires, strict=True): - wire = build_unwrap(builder, opt_wire, err) + (wire,) = build_unwrap(builder, opt_wire, err).outputs() update_packed_value(v, GuppyObject(elem_ty, wire), builder) case _: pass diff --git a/guppylang/tracing/util.py b/guppylang/tracing/util.py index 7b65ecd8..749cd5cf 100644 --- a/guppylang/tracing/util.py +++ b/guppylang/tracing/util.py @@ -1,17 +1,22 @@ import functools import inspect import sys +from collections.abc import Callable from types import FrameType, TracebackType +from typing import ParamSpec, TypeVar from guppylang.error import GuppyError, exception_hook +P = ParamSpec("P") +T = TypeVar("T") -def hide_trace(f): + +def hide_trace(f: Callable[P, T]) -> Callable[P, T]: """Function decorator that hides compiler-internal frames from the traceback of any exception thrown by the decorated function.""" @functools.wraps(f) - def wrapped(*args, **kwargs): + def wrapped(*args: P.args, **kwargs: P.kwargs) -> T: with exception_hook(tracing_except_hook): return f(*args, **kwargs) @@ -20,7 +25,7 @@ def wrapped(*args, **kwargs): def tracing_except_hook( excty: type[BaseException], err: BaseException, traceback: TracebackType | None -): +) -> None: """Except hook that removes all compiler-internal frames from the traceback.""" if isinstance(err, GuppyError): diagnostic = err.error From dff8cd8ebc9e4bce9f39a24fa3743eb3b3c5311d Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 6 Jan 2025 12:17:34 +0100 Subject: [PATCH 03/26] Don't retrace functions on call --- guppylang/tracing/object.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/guppylang/tracing/object.py b/guppylang/tracing/object.py index 3712ea9d..e3ecbab2 100644 --- a/guppylang/tracing/object.py +++ b/guppylang/tracing/object.py @@ -219,9 +219,7 @@ def __call__(self, *args: Any) -> Any: state = get_tracing_state() defn = state.globals.build_compiled_def(self.wrapped.id) - if isinstance(defn, CompiledTracedFunctionDef): - return defn.python_func(*args) - elif isinstance(defn, CompiledCallableDef): + if isinstance(defn, CompiledCallableDef): return trace_call(defn, *args) elif isinstance(defn, TypeDef): globals = get_tracing_globals() From c64823d8012768a38475eea589048140805a7cec Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 6 Jan 2025 12:18:39 +0100 Subject: [PATCH 04/26] Fix demo notebook kernel --- examples/tracing.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/tracing.ipynb b/examples/tracing.ipynb index f7509b73..cef83980 100644 --- a/examples/tracing.ipynb +++ b/examples/tracing.ipynb @@ -506,9 +506,9 @@ ], "metadata": { "kernelspec": { - "display_name": "uv_guppy", + "display_name": "Python 3 (ipykernel)", "language": "python", - "name": "uv_guppy" + "name": "python3" }, "language_info": { "codemirror_mode": { From 34b0444348820a94f9ba28a71eabfed0c5977f2f Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 6 Jan 2025 12:23:15 +0100 Subject: [PATCH 05/26] Fix lint --- guppylang/tracing/object.py | 1 - 1 file changed, 1 deletion(-) diff --git a/guppylang/tracing/object.py b/guppylang/tracing/object.py index e3ecbab2..48089962 100644 --- a/guppylang/tracing/object.py +++ b/guppylang/tracing/object.py @@ -8,7 +8,6 @@ from guppylang.definition.common import DefId, Definition from guppylang.definition.function import RawFunctionDef -from guppylang.definition.traced import CompiledTracedFunctionDef from guppylang.definition.ty import TypeDef from guppylang.definition.value import CompiledCallableDef, CompiledValueDef from guppylang.ipython_inspect import find_ipython_def, is_running_ipython From 313c9cf145419867da8a03f4b7ab13e8c70ff5c9 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Mon, 6 Jan 2025 12:33:57 +0100 Subject: [PATCH 06/26] Rename traced to comptime --- examples/tracing.ipynb | 52 +++++++++++++++++++++--------------------- guppylang/decorator.py | 2 +- 2 files changed, 27 insertions(+), 27 deletions(-) diff --git a/examples/tracing.ipynb b/examples/tracing.ipynb index cef83980..8bd19ced 100644 --- a/examples/tracing.ipynb +++ b/examples/tracing.ipynb @@ -5,9 +5,9 @@ "id": "fccf55ad-3db2-40a4-89c7-8c5e08f3397c", "metadata": {}, "source": [ - "# Tracing Demo\n", + "# Comptime Demo\n", "\n", - "This notebook showcases the new experimental tracing mode in Guppy." + "This notebook showcases the new experimental comptime mode in Guppy." ] }, { @@ -35,7 +35,7 @@ "source": [ "## Intro\n", "\n", - "Traced functions are executed using the Python interpreter which drives the Hugr generation. Thus, traced functions can contain arbitrary Python code, but everything is evaluated at compile-time. The result is a flat Hugr program:" + "Comptime functions are executed using the Python interpreter which drives the Hugr generation. Thus, comptime functions can contain arbitrary Python code, but everything is evaluated at compile-time. The result is a flat Hugr program:" ] }, { @@ -48,7 +48,7 @@ "module = GuppyModule(\"test\")\n", "module.load_all(quantum)\n", "\n", - "@guppy.traced(module)\n", + "@guppy.comptime(module)\n", "def ladder() -> array[qubit, 5]:\n", " qs = [qubit() for _ in range(5)]\n", " for q in qs:\n", @@ -65,7 +65,7 @@ "id": "427aea92-dc18-4094-850e-1c5a6a4338ca", "metadata": {}, "source": [ - "Traced functions can be called from regular Guppy functions and vice versa." + "Comptime functions can be called from regular Guppy functions and vice versa." ] }, { @@ -90,11 +90,11 @@ "def regular1() -> tuple[qubit, qubit]:\n", " q1 = qubit()\n", " h(q1)\n", - " q2 = traced(q1)\n", + " q2 = comptime_func(q1)\n", " return q1, q2\n", "\n", - "@guppy.traced(module)\n", - "def traced(q: qubit) -> qubit:\n", + "@guppy.comptime(module)\n", + "def comptime_func(q: qubit) -> qubit:\n", " r = regular2()\n", " cx(q, r)\n", " return r\n", @@ -113,7 +113,7 @@ "id": "0465ad90-9fa9-4e07-97fe-3e37266429e1", "metadata": {}, "source": [ - "Traced functions can even call out to non-Guppy functions and pass qubits along as data:" + "Comptime functions can even call out to non-Guppy functions and pass qubits along as data:" ] }, { @@ -126,7 +126,7 @@ "module = GuppyModule(\"test\")\n", "module.load_all(quantum)\n", "\n", - "@guppy.traced(module)\n", + "@guppy.comptime(module)\n", "def foo() -> qubit:\n", " q = qubit()\n", " bar(q)\n", @@ -165,7 +165,7 @@ "module = GuppyModule(\"test\")\n", "module.load_all(quantum)\n", "\n", - "@guppy.traced(module)\n", + "@guppy.comptime(module)\n", "def foo(q: qubit, x: float) -> None:\n", " x = x * 2\n", " #print(x) # What is x?\n", @@ -205,7 +205,7 @@ "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", "Cell \u001b[0;32mIn[6], line 10\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m x \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m3\u001b[39m:\n\u001b[1;32m 8\u001b[0m rz(q, angle(x))\n\u001b[0;32m---> 10\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompile\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m;\n", - "Cell \u001b[0;32mIn[6], line 7\u001b[0m, in \u001b[0;36mfoo\u001b[0;34m(q, x)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;129m@guppy\u001b[39m\u001b[38;5;241m.\u001b[39mtraced(module)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfoo\u001b[39m(q: qubit, x: \u001b[38;5;28mfloat\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 6\u001b[0m x \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m2\u001b[39m\n\u001b[0;32m----> 7\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m x \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m3\u001b[39m:\n\u001b[1;32m 8\u001b[0m rz(q, angle(x))\n", + "Cell \u001b[0;32mIn[6], line 7\u001b[0m, in \u001b[0;36mfoo\u001b[0;34m(q, x)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;129m@guppy\u001b[39m\u001b[38;5;241m.\u001b[39mcomptime(module)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfoo\u001b[39m(q: qubit, x: \u001b[38;5;28mfloat\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 6\u001b[0m x \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m*\u001b[39m \u001b[38;5;241m2\u001b[39m\n\u001b[0;32m----> 7\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m x \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m3\u001b[39m:\n\u001b[1;32m 8\u001b[0m rz(q, angle(x))\n", "\u001b[0;31mValueError\u001b[0m: Branching on a dynamic value is not allowed during tracing. Try using a regular guppy function" ] } @@ -214,7 +214,7 @@ "module = GuppyModule(\"test\")\n", "module.load_all(quantum)\n", "\n", - "@guppy.traced(module)\n", + "@guppy.comptime(module)\n", "def foo(q: qubit, x: float) -> None:\n", " x = x * 2\n", " if x > 3:\n", @@ -254,7 +254,7 @@ "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", "Cell \u001b[0;32mIn[7], line 9\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m measure(q):\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mXXX\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 9\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompile\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m;\n", - "Cell \u001b[0;32mIn[7], line 6\u001b[0m, in \u001b[0;36mfoo\u001b[0;34m(q)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;129m@guppy\u001b[39m\u001b[38;5;241m.\u001b[39mtraced(module)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfoo\u001b[39m(q: qubit) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m----> 6\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m measure(q):\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mXXX\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "Cell \u001b[0;32mIn[7], line 6\u001b[0m, in \u001b[0;36mfoo\u001b[0;34m(q)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;129m@guppy\u001b[39m\u001b[38;5;241m.\u001b[39mcomptime(module)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfoo\u001b[39m(q: qubit) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m----> 6\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m measure(q):\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mXXX\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", "\u001b[0;31mValueError\u001b[0m: Branching on a dynamic value is not allowed during tracing. Try using a regular guppy function" ] } @@ -263,7 +263,7 @@ "module = GuppyModule(\"test\")\n", "module.load_all(quantum)\n", "\n", - "@guppy.traced(module)\n", + "@guppy.comptime(module)\n", "def foo(q: qubit) -> None:\n", " if measure(q):\n", " print(\"XXX\")\n", @@ -301,8 +301,8 @@ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", - "Cell \u001b[0;32mIn[8], line 8\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;129m@guppy\u001b[39m\u001b[38;5;241m.\u001b[39mtraced(module)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfoo\u001b[39m(n: \u001b[38;5;28mint\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 6\u001b[0m qs \u001b[38;5;241m=\u001b[39m [qubit() \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(n)]\n\u001b[0;32m----> 8\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompile\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m;\n", - "Cell \u001b[0;32mIn[8], line 6\u001b[0m, in \u001b[0;36mfoo\u001b[0;34m(n)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;129m@guppy\u001b[39m\u001b[38;5;241m.\u001b[39mtraced(module)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfoo\u001b[39m(n: \u001b[38;5;28mint\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m----> 6\u001b[0m qs \u001b[38;5;241m=\u001b[39m [qubit() \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28;43mrange\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mn\u001b[49m\u001b[43m)\u001b[49m]\n", + "Cell \u001b[0;32mIn[8], line 8\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;129m@guppy\u001b[39m\u001b[38;5;241m.\u001b[39mcomptime(module)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfoo\u001b[39m(n: \u001b[38;5;28mint\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 6\u001b[0m qs \u001b[38;5;241m=\u001b[39m [qubit() \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(n)]\n\u001b[0;32m----> 8\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompile\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m;\n", + "Cell \u001b[0;32mIn[8], line 6\u001b[0m, in \u001b[0;36mfoo\u001b[0;34m(n)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;129m@guppy\u001b[39m\u001b[38;5;241m.\u001b[39mcomptime(module)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfoo\u001b[39m(n: \u001b[38;5;28mint\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m----> 6\u001b[0m qs \u001b[38;5;241m=\u001b[39m [qubit() \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28;43mrange\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mn\u001b[49m\u001b[43m)\u001b[49m]\n", "\u001b[0;31mTypeError\u001b[0m: 'GuppyObject' object cannot be interpreted as an integer" ] } @@ -311,7 +311,7 @@ "module = GuppyModule(\"test\")\n", "module.load_all(quantum)\n", "\n", - "@guppy.traced(module)\n", + "@guppy.comptime(module)\n", "def foo(n: int) -> None:\n", " qs = [qubit() for _ in range(n)]\n", "\n", @@ -347,7 +347,7 @@ "@guppy.declare(module)\n", "def foo(qs: array[qubit, 10]) -> None: ...\n", "\n", - "@guppy.traced(module)\n", + "@guppy.comptime(module)\n", "def bar(qs: array[qubit, 10]) -> None:\n", " # Arrays are iterable in the Python context\n", " for q1, q2 in zip(qs[1:], qs[:-1]):\n", @@ -393,8 +393,8 @@ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", " \u001b[0;31m[... skipping hidden 1 frame]\u001b[0m\n", - "Cell \u001b[0;32mIn[10], line 8\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;129m@guppy\u001b[39m\u001b[38;5;241m.\u001b[39mtraced(module)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mbad\u001b[39m(q: qubit) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 6\u001b[0m cx(q, q)\n\u001b[0;32m----> 8\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompile\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", - "Cell \u001b[0;32mIn[10], line 6\u001b[0m, in \u001b[0;36mbad\u001b[0;34m(q)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;129m@guppy\u001b[39m\u001b[38;5;241m.\u001b[39mtraced(module)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mbad\u001b[39m(q: qubit) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m----> 6\u001b[0m \u001b[43mcx\u001b[49m\u001b[43m(\u001b[49m\u001b[43mq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mq\u001b[49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[10], line 8\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;129m@guppy\u001b[39m\u001b[38;5;241m.\u001b[39mcomptime(module)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mbad\u001b[39m(q: qubit) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 6\u001b[0m cx(q, q)\n\u001b[0;32m----> 8\u001b[0m \u001b[43mmodule\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompile\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "Cell \u001b[0;32mIn[10], line 6\u001b[0m, in \u001b[0;36mbad\u001b[0;34m(q)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;129m@guppy\u001b[39m\u001b[38;5;241m.\u001b[39mcomptime(module)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mbad\u001b[39m(q: qubit) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m----> 6\u001b[0m \u001b[43mcx\u001b[49m\u001b[43m(\u001b[49m\u001b[43mq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mq\u001b[49m\u001b[43m)\u001b[49m\n", "\u001b[0;31mValueError\u001b[0m: Value with linear type `qubit` was already used\n\nPrevious use occurred in :6 as an argument to `cx`" ] } @@ -403,7 +403,7 @@ "module = GuppyModule(\"test\")\n", "module.load_all(quantum)\n", "\n", - "@guppy.traced(module)\n", + "@guppy.comptime(module)\n", "def bad(q: qubit) -> None:\n", " cx(q, q)\n", "\n", @@ -431,7 +431,7 @@ "Error: Linearity violation in function return (at :5:0)\n", " | \n", "3 | \n", - "4 | @guppy.traced(module)\n", + "4 | @guppy.comptime(module)\n", "5 | def bad(q: qubit) -> qubit:\n", " | ^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", "6 | return q\n", @@ -450,7 +450,7 @@ "module = GuppyModule(\"test\")\n", "module.load_all(quantum)\n", "\n", - "@guppy.traced(module)\n", + "@guppy.comptime(module)\n", "def bad(q: qubit) -> qubit:\n", " return q\n", "\n", @@ -478,7 +478,7 @@ "Error: Linearity violation in function return (at :5:0)\n", " | \n", "3 | \n", - "4 | @guppy.traced(module)\n", + "4 | @guppy.comptime(module)\n", "5 | def bad(q: qubit) -> None:\n", " | ^^^^^^^^^^^^^^^^^^^^^^^^^^\n", " | ...\n", @@ -495,7 +495,7 @@ "module = GuppyModule(\"test\")\n", "module.load_all(quantum)\n", "\n", - "@guppy.traced(module)\n", + "@guppy.comptime(module)\n", "def bad(q: qubit) -> None:\n", " tmp = qubit()\n", " cx(tmp, q)\n", diff --git a/guppylang/decorator.py b/guppylang/decorator.py index e5a91821..3b79f3b8 100644 --- a/guppylang/decorator.py +++ b/guppylang/decorator.py @@ -168,7 +168,7 @@ def _get_python_caller(self, fn: PyFunc | None = None) -> ModuleIdentifier: ) @pretty_errors - def traced(self, arg: PyFunc | GuppyModule) -> FuncDecorator | GuppyDefinition: + def comptime(self, arg: PyFunc | GuppyModule) -> FuncDecorator | GuppyDefinition: def dec(f: Callable[..., Any], module: GuppyModule) -> GuppyDefinition: defn = RawTracedFunctionDef(DefId.fresh(module), f.__name__, None, f, {}) module.register_def(defn) From b52522e24b995391f5cf849e17ec5806473a5344 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 15 Jan 2025 09:24:09 +0000 Subject: [PATCH 07/26] Add missing dunder methods --- guppylang/tracing/object.py | 39 +++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/guppylang/tracing/object.py b/guppylang/tracing/object.py index 48089962..8f0ad50c 100644 --- a/guppylang/tracing/object.py +++ b/guppylang/tracing/object.py @@ -92,6 +92,45 @@ def __pos__(self) -> Any: def __pow__(self, other: Any) -> Any: return self.__getattr__("__pow__")(other) + def __radd__(self, other: Any) -> Any: + return self.__getattr__("__pow__")(other) + + def __rand__(self, other: Any) -> Any: + return self.__getattr__("__pow__")(other) + + def __rfloordiv__(self, other: Any) -> Any: + return self.__getattr__("__pow__")(other) + + def __rlshift__(self, other: Any) -> Any: + return self.__getattr__("__pow__")(other) + + def __rmod__(self, other: Any) -> Any: + return self.__getattr__("__pow__")(other) + + def __rmul__(self, other: Any) -> Any: + return self.__getattr__("__pow__")(other) + + def __ror__(self, other: Any) -> Any: + return self.__getattr__("__pow__")(other) + + def __rpow__(self, other: Any) -> Any: + return self.__getattr__("__pow__")(other) + + def __rrshift__(self, other: Any) -> Any: + return self.__getattr__("__pow__")(other) + + def __rshift__(self, other: Any) -> Any: + return self.__getattr__("__pow__")(other) + + def __rsub__(self, other: Any) -> Any: + return self.__getattr__("__pow__")(other) + + def __rtruediv__(self, other: Any) -> Any: + return self.__getattr__("__pow__")(other) + + def __rxor__(self, other: Any) -> Any: + return self.__getattr__("__pow__")(other) + def __sub__(self, other: Any) -> Any: return self.__getattr__("__sub__")(other) From f814a79d0e8a20cb3b8fe677a22f593ec5e1e17d Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 15 Jan 2025 11:02:06 +0000 Subject: [PATCH 08/26] Also try reverse dunder methods --- guppylang/tracing/object.py | 78 ++++++++++++++++++++++++++++++------- guppylang/tracing/util.py | 4 +- 2 files changed, 66 insertions(+), 16 deletions(-) diff --git a/guppylang/tracing/object.py b/guppylang/tracing/object.py index 8f0ad50c..1bca1483 100644 --- a/guppylang/tracing/object.py +++ b/guppylang/tracing/object.py @@ -1,8 +1,10 @@ +import functools import inspect import itertools from abc import ABC, abstractmethod +from collections.abc import Callable from dataclasses import dataclass -from typing import Any, NamedTuple +from typing import Any, NamedTuple, TypeAlias from hugr import Wire, ops @@ -10,22 +12,54 @@ from guppylang.definition.function import RawFunctionDef from guppylang.definition.ty import TypeDef from guppylang.definition.value import CompiledCallableDef, CompiledValueDef +from guppylang.error import GuppyError from guppylang.ipython_inspect import find_ipython_def, is_running_ipython from guppylang.tracing.state import get_tracing_globals, get_tracing_state from guppylang.tracing.util import get_calling_frame, hide_trace from guppylang.tys.ty import TupleType, Type +DunderMethod: TypeAlias = Callable[["GetAttrDunders", Any], Any] + + +def also_try_reversed(method_name: str) -> Callable[[DunderMethod], DunderMethod]: + """Decorator to delegate calls to dunder methods like `__add__` on `GuppyObject`s to + their reversed version `__radd__` if the original one doesn't type check. + """ + + def decorator(f: DunderMethod) -> DunderMethod: + @functools.wraps(f) + def wrapped(self: "GetAttrDunders", other: Any) -> Any: + try: + return f(self, other) + except GuppyError: + from guppylang.tracing.state import get_tracing_state + from guppylang.tracing.unpacking import guppy_object_from_py + + state = get_tracing_state() + obj = guppy_object_from_py(other, state.dfg.builder, state.node) + return obj.__getattr__(method_name)(self) + + return wrapped + + return decorator + class GetAttrDunders(ABC): + """Mixin class to allow `GuppyObject`s to be used in arithmetic expressions etc. + via providing the corresponding dunder methods delegating to the objects impls. + """ + @abstractmethod def __getattr__(self, item: Any) -> Any: ... def __abs__(self, other: Any) -> Any: return self.__getattr__("__abs__")(other) + @also_try_reversed("__radd__") def __add__(self, other: Any) -> Any: return self.__getattr__("__add__")(other) + @also_try_reversed("__rand__") def __and__(self, other: Any) -> Any: return self.__getattr__("__and__")(other) @@ -38,6 +72,7 @@ def __ceil__(self: Any) -> Any: def __divmod__(self, other: Any) -> Any: return self.__getattr__("__divmod__")(other) + @also_try_reversed("__eq__") def __eq__(self, other: object) -> Any: return self.__getattr__("__eq__")(other) @@ -47,12 +82,15 @@ def __float__(self) -> Any: def __floor__(self) -> Any: return self.__getattr__("__floor__")() + @also_try_reversed("__rfloordiv__") def __floordiv__(self, other: Any) -> Any: return self.__getattr__("__floordiv__")(other) + @also_try_reversed("__le__") def __ge__(self, other: Any) -> Any: return self.__getattr__("__ge__")(other) + @also_try_reversed("__lt__") def __gt__(self, other: Any) -> Any: return self.__getattr__("__gt__")(other) @@ -62,84 +100,96 @@ def __int__(self) -> Any: def __invert__(self) -> Any: return self.__getattr__("__invert__")() + @also_try_reversed("__ge__") def __le__(self, other: Any) -> Any: return self.__getattr__("__le__")(other) + @also_try_reversed("__rlshift__") def __lshift__(self, other: Any) -> Any: return self.__getattr__("__lshift__")(other) + @also_try_reversed("__gt__") def __lt__(self, other: Any) -> Any: return self.__getattr__("__lt__")(other) + @also_try_reversed("__rmod__") def __mod__(self, other: Any) -> Any: return self.__getattr__("__mod__")(other) + @also_try_reversed("__rmul__") def __mul__(self, other: Any) -> Any: return self.__getattr__("__mul__")(other) + @also_try_reversed("__ne__") def __ne__(self, other: object) -> Any: return self.__getattr__("__ne__")(other) def __neg__(self) -> Any: return self.__getattr__("__neg__")() + @also_try_reversed("__ror__") def __or__(self, other: Any) -> Any: return self.__getattr__("__or__")(other) def __pos__(self) -> Any: return self.__getattr__("__pos__")() + @also_try_reversed("__rpow__") def __pow__(self, other: Any) -> Any: return self.__getattr__("__pow__")(other) def __radd__(self, other: Any) -> Any: - return self.__getattr__("__pow__")(other) + return self.__getattr__("__radd__")(other) def __rand__(self, other: Any) -> Any: - return self.__getattr__("__pow__")(other) + return self.__getattr__("__rand__")(other) def __rfloordiv__(self, other: Any) -> Any: - return self.__getattr__("__pow__")(other) + return self.__getattr__("__rfloordiv__")(other) def __rlshift__(self, other: Any) -> Any: - return self.__getattr__("__pow__")(other) + return self.__getattr__("__rlshift__")(other) def __rmod__(self, other: Any) -> Any: - return self.__getattr__("__pow__")(other) + return self.__getattr__("__rmod__")(other) def __rmul__(self, other: Any) -> Any: - return self.__getattr__("__pow__")(other) + return self.__getattr__("__rmul__")(other) def __ror__(self, other: Any) -> Any: - return self.__getattr__("__pow__")(other) + return self.__getattr__("__ror__")(other) def __rpow__(self, other: Any) -> Any: - return self.__getattr__("__pow__")(other) + return self.__getattr__("__rpow__")(other) def __rrshift__(self, other: Any) -> Any: return self.__getattr__("__pow__")(other) + @also_try_reversed("__rrshift__") def __rshift__(self, other: Any) -> Any: - return self.__getattr__("__pow__")(other) + return self.__getattr__("__rshift__")(other) def __rsub__(self, other: Any) -> Any: - return self.__getattr__("__pow__")(other) + return self.__getattr__("__rsub__")(other) def __rtruediv__(self, other: Any) -> Any: - return self.__getattr__("__pow__")(other) + return self.__getattr__("__rtruediv__")(other) def __rxor__(self, other: Any) -> Any: - return self.__getattr__("__pow__")(other) + return self.__getattr__("__rxor__")(other) + @also_try_reversed("__rsub__") def __sub__(self, other: Any) -> Any: return self.__getattr__("__sub__")(other) + @also_try_reversed("__rtruediv__") def __truediv__(self, other: Any) -> Any: return self.__getattr__("__truediv__")(other) def __trunc__(self) -> Any: return self.__getattr__("__trunc__")() + @also_try_reversed("__rxor__") def __xor__(self, other: Any) -> Any: return self.__getattr__("__xor__")(other) @@ -209,7 +259,7 @@ def __iter__(self) -> Any: def _use_wire(self, called_func: CompiledCallableDef | None) -> Wire: # Panic if the value has already been used - if self._used: + if self._used and self._ty.linear: use = self._used err = ( f"Value with linear type `{self._ty}` was already used\n\n" diff --git a/guppylang/tracing/util.py b/guppylang/tracing/util.py index 749cd5cf..ed1e2d59 100644 --- a/guppylang/tracing/util.py +++ b/guppylang/tracing/util.py @@ -34,8 +34,8 @@ def tracing_except_hook( msg += f": {diagnostic.rendered_span_label}" if diagnostic.message: msg += f"\n{diagnostic.rendered_message}" - err = RuntimeError(msg) - excty = RuntimeError + err = TypeError(msg) + excty = TypeError traceback = remove_internal_frames(traceback) try: From 8810c6fba847cc0cdef4c4b39ee8d54d4403ce0a Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 15 Jan 2025 11:02:35 +0000 Subject: [PATCH 09/26] Remove unnecessary repack --- guppylang/tracing/function.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guppylang/tracing/function.py b/guppylang/tracing/function.py index 79db2b90..44049f06 100644 --- a/guppylang/tracing/function.py +++ b/guppylang/tracing/function.py @@ -64,7 +64,7 @@ def trace_function( py_out = python_func(*inputs) try: - out_obj = repack_guppy_object(py_out, builder) + out_obj = guppy_object_from_py(py_out, builder, node) except ValueError as err: # Linearity violation in the return statement raise GuppyError( From 13fa2092bce3079717bc5ddd7632cf884ac1b69a Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 15 Jan 2025 17:15:36 +0000 Subject: [PATCH 10/26] Fix dunders and rename --- guppylang/tracing/object.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/guppylang/tracing/object.py b/guppylang/tracing/object.py index 1bca1483..3c0cb5df 100644 --- a/guppylang/tracing/object.py +++ b/guppylang/tracing/object.py @@ -18,7 +18,7 @@ from guppylang.tracing.util import get_calling_frame, hide_trace from guppylang.tys.ty import TupleType, Type -DunderMethod: TypeAlias = Callable[["GetAttrDunders", Any], Any] +DunderMethod: TypeAlias = Callable[["DunderMixin", Any], Any] def also_try_reversed(method_name: str) -> Callable[[DunderMethod], DunderMethod]: @@ -28,7 +28,7 @@ def also_try_reversed(method_name: str) -> Callable[[DunderMethod], DunderMethod def decorator(f: DunderMethod) -> DunderMethod: @functools.wraps(f) - def wrapped(self: "GetAttrDunders", other: Any) -> Any: + def wrapped(self: "DunderMixin", other: Any) -> Any: try: return f(self, other) except GuppyError: @@ -44,7 +44,7 @@ def wrapped(self: "GetAttrDunders", other: Any) -> Any: return decorator -class GetAttrDunders(ABC): +class DunderMixin(ABC): """Mixin class to allow `GuppyObject`s to be used in arithmetic expressions etc. via providing the corresponding dunder methods delegating to the objects impls. """ @@ -52,8 +52,8 @@ class GetAttrDunders(ABC): @abstractmethod def __getattr__(self, item: Any) -> Any: ... - def __abs__(self, other: Any) -> Any: - return self.__getattr__("__abs__")(other) + def __abs__(self) -> Any: + return self.__getattr__("__abs__")() @also_try_reversed("__radd__") def __add__(self, other: Any) -> Any: @@ -67,7 +67,7 @@ def __bool__(self: Any) -> Any: return self.__getattr__("__bool__")() def __ceil__(self: Any) -> Any: - return self.__getattr__("__bool__")() + return self.__getattr__("__ceil__")() def __divmod__(self, other: Any) -> Any: return self.__getattr__("__divmod__")(other) @@ -77,7 +77,7 @@ def __eq__(self, other: object) -> Any: return self.__getattr__("__eq__")(other) def __float__(self) -> Any: - return self.__getattr__("__bool__")() + return self.__getattr__("__float__")() def __floor__(self) -> Any: return self.__getattr__("__floor__")() @@ -207,7 +207,7 @@ class ObjectUse(NamedTuple): fresh_id = itertools.count() -class GuppyObject(GetAttrDunders): +class GuppyObject(DunderMixin): """The runtime representation of abstract Guppy objects during tracing.""" _ty: Type From b52a7a904100b1203386575a03be0aa44da3a0d3 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Wed, 15 Jan 2025 17:27:21 +0000 Subject: [PATCH 11/26] Support builtins float, int, and len --- guppylang/tracing/builtins_mock.py | 62 ++++++++++++++++++++++++++++++ guppylang/tracing/function.py | 2 + 2 files changed, 64 insertions(+) create mode 100644 guppylang/tracing/builtins_mock.py diff --git a/guppylang/tracing/builtins_mock.py b/guppylang/tracing/builtins_mock.py new file mode 100644 index 00000000..9d6d4449 --- /dev/null +++ b/guppylang/tracing/builtins_mock.py @@ -0,0 +1,62 @@ +"""Mocks for a number of builtin functions to be used during tracing. + +For example, the builtin `int(x)` function tries calling `x.__int__()` and raises a +`TypeError` if this call doesn't return an `int`. During tracing however, we also want +to allow `int(x)` if `x` is an abstract `GuppyObject`. To allow this, we need to mock +the builtins to avoid raising the `TypeError` in that case. +""" + +import builtins +from collections.abc import Callable +from typing import Any + +from guppylang.tracing.object import GuppyObject + + +def _mock_meta(cls: type) -> type: + """Returns a metaclass that replicates the behaviour of the provided class. + + The only way to distinguishing a `_mock_meta(T)` from an actual `T` is via checking + reference equality using the `is` operator. + """ + + class MockMeta(type): + def __instancecheck__(self, instance: Any) -> bool: + return cls.__instancecheck__(instance) + + def __subclasscheck__(self, subclass: type) -> bool: + return cls.__subclasscheck__(subclass) + + def __eq__(self, other: object) -> Any: + return other == builtins.int + + def __ne__(self, other: object) -> Any: + return other != builtins.int + + MockMeta.__name__ = type.__name__ + MockMeta.__qualname__ = type.__qualname__ + return MockMeta + + +class float(metaclass=_mock_meta(builtins.float)): # type: ignore[misc] + def __new__(cls, x: Any = 0.0, /) -> Any: + if isinstance(x, GuppyObject): + return x.__float__() + return builtins.float(x) + + +class int(metaclass=_mock_meta(builtins.int)): # type: ignore[misc] + def __new__(cls, x: Any = 0, /, **kwargs: Any) -> Any: + if isinstance(x, GuppyObject): + return x.__int__(**kwargs) + return builtins.int(x, **kwargs) + + +def len(x: Any) -> Any: + if isinstance(x, GuppyObject): + return x.__len__() + return builtins.len(x) + + +def mock_builtins(f: Callable[..., Any]) -> None: + f.__globals__.update({"float": float, "int": int, "len": len}) diff --git a/guppylang/tracing/function.py b/guppylang/tracing/function.py index 44049f06..4a168666 100644 --- a/guppylang/tracing/function.py +++ b/guppylang/tracing/function.py @@ -15,6 +15,7 @@ from guppylang.diagnostic import Error from guppylang.error import GuppyError, exception_hook from guppylang.nodes import PlaceNode +from guppylang.tracing.builtins_mock import mock_builtins from guppylang.tracing.object import GuppyObject from guppylang.tracing.state import ( TracingState, @@ -61,6 +62,7 @@ def trace_function( ] with exception_hook(tracing_except_hook): + mock_builtins(python_func) py_out = python_func(*inputs) try: From 3dd1bd23eb06c6cda0850c1091877153dedcfce6 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 4 Feb 2025 10:39:02 +0000 Subject: [PATCH 12/26] Add basic tests --- tests/integration/tracing/test_basic.py | 69 +++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 tests/integration/tracing/test_basic.py diff --git a/tests/integration/tracing/test_basic.py b/tests/integration/tracing/test_basic.py new file mode 100644 index 00000000..7f23a317 --- /dev/null +++ b/tests/integration/tracing/test_basic.py @@ -0,0 +1,69 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +from hugr import ops +from hugr.std.int import IntVal + + +def test_flat(validate): + module = GuppyModule("module") + + @guppy.comptime(module) + def foo() -> int: + x = 0 + for i in range(10): + x += i + return x + + hugr = module.compile() + assert hugr.module.num_nodes() == 6 + [const] = [ + data.op for _, data in hugr.module.nodes() if isinstance(data.op, ops.Const) + ] + assert isinstance(const.val, IntVal) + assert const.val.v == sum(i for i in range(10)) + validate(hugr) + + +def test_inputs(validate): + module = GuppyModule("module") + + @guppy.comptime(module) + def foo(x: int, y: float) -> tuple[int, float]: + return x, y + + validate(module.compile()) + + +def test_recursion(validate): + module = GuppyModule("module") + + @guppy.comptime(module) + def foo(x: int) -> int: + # `foo` doesn't terminate but the compiler should! + return foo(x) + + validate(module.compile()) + + +def test_calls(validate): + module = GuppyModule("module") + + @guppy.comptime(module) + def comptime1(x: int) -> int: + return regular1(x) + + @guppy(module) + def regular1(x: int) -> int: + return comptime2(x) + + @guppy.comptime(module) + def comptime2(x: int) -> int: + return regular2(x) + + @guppy(module) + def regular2(x: int) -> int: + return comptime1(x) + + validate(module.compile()) + From a99b93b2749859343a315d2d5a90408ad8b03bc2 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 4 Feb 2025 10:39:20 +0000 Subject: [PATCH 13/26] Add arithmetic tests --- execute_llvm/src/lib.rs | 59 ++++++-- tests/integration/conftest.py | 12 +- tests/integration/tracing/test_arithmetic.py | 135 +++++++++++++++++++ 3 files changed, 190 insertions(+), 16 deletions(-) create mode 100644 tests/integration/tracing/test_arithmetic.py diff --git a/execute_llvm/src/lib.rs b/execute_llvm/src/lib.rs index 1c10e97a..47dafa00 100644 --- a/execute_llvm/src/lib.rs +++ b/execute_llvm/src/lib.rs @@ -5,6 +5,8 @@ use hugr::llvm::CodegenExtsBuilder; use hugr::package::Package; use hugr::Hugr; use hugr::{self, ops, std_extensions, HugrView}; +use inkwell::types::BasicType; +use inkwell::values::BasicMetadataValueEnum; use inkwell::{context::Context, module::Module, values::GenericValue}; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; @@ -79,9 +81,11 @@ fn compile_module<'a>( Ok(emitter.finish()) } -fn run_function( +fn run_function( hugr_json: &str, fn_name: &str, + args: &[T], + encode_arg: impl Fn(&Context, T) -> BasicMetadataValueEnum, parse_result: impl FnOnce(&Context, GenericValue) -> PyResult, ) -> PyResult { let mut hugr = parse_hugr(hugr_json)?; @@ -98,10 +102,29 @@ fn run_function( .get_function(&mangled_name) .ok_or(pyerr!("Couldn't find function {} in module", mangled_name))?; + // Build a new function that calls the target function with the provided arguments. + // Calling `ExecutionEngine::run_function` with arguments directly always segfaults for some + // reason... + let main = module.add_function( + "__main__", + fv.get_type().get_return_type().unwrap().fn_type(&[], false), + None, + ); + let bb = ctx.append_basic_block(main, ""); + let builder = ctx.create_builder(); + builder.position_at_end(bb); + let args: Vec<_> = args.iter().map(|a| encode_arg(&ctx, a.clone())).collect(); + let res = builder + .build_call(fv, &args, "") + .unwrap() + .try_as_basic_value() + .unwrap_left(); + builder.build_return(Some(&res)).unwrap(); + let ee = module .create_execution_engine() .map_err(|_| pyerr!("Failed to create execution engine"))?; - let llvm_result = unsafe { ee.run_function(fv, &[]) }; + let llvm_result = unsafe { ee.run_function(main, &[]) }; parse_result(&ctx, llvm_result) } @@ -124,19 +147,29 @@ mod execute_llvm { } #[pyfunction] - fn run_int_function(hugr_json: &str, fn_name: &str) -> PyResult { - run_function::(hugr_json, fn_name, |_, llvm_val| { - // GenericVal is 64 bits wide - let int_with_sign = llvm_val.as_int(true); - let signed_int = int_with_sign as i64; - Ok(signed_int) - }) + fn run_int_function(hugr_json: &str, fn_name: &str, args: Vec) -> PyResult { + run_function::( + hugr_json, + fn_name, + &args, + |ctx, i| ctx.i64_type().const_int(i as u64, true).into(), + |_, llvm_val| { + // GenericVal is 64 bits wide + let int_with_sign = llvm_val.as_int(true); + let signed_int = int_with_sign as i64; + Ok(signed_int) + }, + ) } #[pyfunction] - fn run_float_function(hugr_json: &str, fn_name: &str) -> PyResult { - run_function::(hugr_json, fn_name, |ctx, llvm_val| { - Ok(llvm_val.as_float(&ctx.f64_type())) - }) + fn run_float_function(hugr_json: &str, fn_name: &str, args: Vec) -> PyResult { + run_function::( + hugr_json, + fn_name, + &args, + |ctx, f| ctx.f64_type().const_float(f).into(), + |ctx, llvm_val| Ok(llvm_val.as_float(&ctx.f64_type())), + ) } } diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 314149a8..07dcc025 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -71,7 +71,7 @@ class LLVMException(Exception): def _run_fn(run_fn_name: str): - def f(module: ModulePointer, expected: Any, fn_name: str = "main"): + def f(module: ModulePointer, expected: Any, fn_name: str = "main", args: list[Any] | None = None): try: import execute_llvm @@ -80,7 +80,7 @@ def f(module: ModulePointer, expected: Any, fn_name: str = "main"): pytest.skip("Skipping llvm execution") package_json: str = module.package.to_json() - res = fn(package_json, fn_name) + res = fn(package_json, fn_name, args or []) if res != expected: raise LLVMException( f"Expected value ({expected}) doesn't match actual value ({res})" @@ -108,11 +108,17 @@ def run_approx( hugr: Package, expected: float, fn_name: str = "main", + args: list[Any] | None = None, *, rel: float | None = None, abs: float | None = None, nan_ok: bool = False, ): - return run_fn(hugr, pytest.approx(expected, rel=rel, abs=abs, nan_ok=nan_ok)) + return run_fn( + hugr, + pytest.approx(expected, rel=rel, abs=abs, nan_ok=nan_ok), + fn_name, + args, + ) return run_approx diff --git a/tests/integration/tracing/test_arithmetic.py b/tests/integration/tracing/test_arithmetic.py new file mode 100644 index 00000000..b408bc85 --- /dev/null +++ b/tests/integration/tracing/test_arithmetic.py @@ -0,0 +1,135 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.std.builtins import nat + + +def test_int(validate, run_int_fn): + module = GuppyModule("module") + + @guppy.comptime(module) + def pos(x: int) -> int: + return +x + + @guppy.comptime(module) + def neg(x: int) -> int: + return -x + + @guppy.comptime(module) + def add(x: int, y: int) -> int: + return 1 + (x + (y + 2)) + + @guppy.comptime(module) + def sub(x: int, y: int) -> int: + return 1 - (x - (y - 2)) + + @guppy.comptime(module) + def mul(x: int, y: int) -> int: + return 1 * (x * (y * 2)) + + @guppy.comptime(module) + def div(x: int, y: int) -> int: + return 100 // (x // (y // 2)) + + @guppy.comptime(module) + def mod(x: int, y: int) -> int: + return 15 % (x % (y % 10)) + + @guppy.comptime(module) + def pow(x: int, y: int) -> int: + return 4 ** (x ** (y ** 0)) + + @guppy(module) + def main() -> None: + """Dummy main function""" + + compiled = module.compile() + validate(compiled) + + run_int_fn(compiled, 10, "pos", [10]) + run_int_fn(compiled, -10, "neg", [10]) + run_int_fn(compiled, 2, "add", [3, -4]) + run_int_fn(compiled, -8, "sub", [3, -4]) + run_int_fn(compiled, -24, "mul", [3, -4]) + run_int_fn(compiled, 20, "div", [25, 10]) + run_int_fn(compiled, 7, "mod", [8, 9]) + run_int_fn(compiled, 16, "pow", [2, 100]) + + +def test_float(validate, run_float_fn_approx): + module = GuppyModule("module") + + @guppy.comptime(module) + def pos(x: float) -> float: + return +x + + @guppy.comptime(module) + def neg(x: float) -> float: + return -x + + @guppy.comptime(module) + def add(x: float, y: float) -> float: + return 1 + (x + (y + 2)) + + @guppy.comptime(module) + def sub(x: float, y: float) -> float: + return 1 - (x - (y - 2)) + + @guppy.comptime(module) + def mul(x: float, y: float) -> float: + return 1 * (x * (y * 2)) + + @guppy.comptime(module) + def div(x: float, y: float) -> float: + return 100 / (x / (y / 2)) + + # TODO: Requires lowering of `ffloor` op: https://github.com/CQCL/hugr/issues/1905 + # @guppy.comptime(module) + # def floordiv(x: float, y: float) -> float: + # return 100 // (x // (y // 2)) + + # TODO: Requires lowering of `fpow` op: https://github.com/CQCL/hugr/issues/1905 + # @guppy.comptime(module) + # def pow(x: float, y: float) -> float: + # return 4 ** (x ** (y ** 0.5)) + + @guppy(module) + def main() -> None: + """Dummy main function""" + + compiled = module.compile() + validate(compiled) + + run_float_fn_approx(compiled, 10.5, "pos", [10.5]) + run_float_fn_approx(compiled, -10.5, "neg", [10.5]) + run_float_fn_approx(compiled, 1.5, "add", [3, -4.5]) + run_float_fn_approx(compiled, -8.5, "sub", [3, -4.5]) + run_float_fn_approx(compiled, -27.0, "mul", [3, -4.5]) + run_float_fn_approx(compiled, 400.0, "div", [0.5, 4]) + + # TODO: Requires lowering of `ffloor` op: https://github.com/CQCL/hugr/issues/1905 + # run_float_fn_approx(compiled, ... "div", [...]) + + # TODO: Requires lowering of `fpow` op: https://github.com/CQCL/hugr/issues/1905 + # run_float_fn_approx(compiled, ..., "pow", [...]) + + +def test_mixed(validate, run_int_fn): + module = GuppyModule("module") + + @guppy.comptime(module) + def foo(x: int, y: float) -> int: + a = 1 + (x + 2) + b = 1 - (a - 2) + c = 2 * (b * 3) + x += 1 + x *= y + z = 1 + x + return int(z / 2) + + @guppy(module) + def main() -> int: + return foo(1, 2.0) + + compiled = module.compile() + validate(compiled) + run_int_fn(compiled, 42) From f8965c2243e0d33728d201784e3ec7a92966e3da Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 4 Feb 2025 10:44:07 +0000 Subject: [PATCH 14/26] Check that called function is available in the current module --- guppylang/tracing/object.py | 23 ++++++++++++++++++----- tests/integration/tracing/test_typing.py | 23 +++++++++++++++++++++++ 2 files changed, 41 insertions(+), 5 deletions(-) create mode 100644 tests/integration/tracing/test_typing.py diff --git a/guppylang/tracing/object.py b/guppylang/tracing/object.py index 3c0cb5df..416a7819 100644 --- a/guppylang/tracing/object.py +++ b/guppylang/tracing/object.py @@ -305,15 +305,28 @@ def id(self) -> DefId: def __call__(self, *args: Any) -> Any: from guppylang.tracing.function import trace_call + # Check that the functions is loaded in the current module + globals = get_tracing_globals() + if self.wrapped.id not in globals.defs: + assert self.wrapped.id.module is not None + err = ( + f"{self.wrapped.description.capitalize()} `{self.wrapped.name}` is not " + f"available in this module, consider importing it from " + f"`{self.wrapped.id.module.name}`" + ) + raise TypeError(err) + state = get_tracing_state() defn = state.globals.build_compiled_def(self.wrapped.id) if isinstance(defn, CompiledCallableDef): return trace_call(defn, *args) - elif isinstance(defn, TypeDef): - globals = get_tracing_globals() - if defn.id in globals.impls and "__new__" in globals.impls[defn.id]: - constructor = globals.defs[globals.impls[defn.id]["__new__"]] - return GuppyDefinition(constructor)(*args) + elif ( + isinstance(defn, TypeDef) + and defn.id in globals.impls + and "__new__" in globals.impls[defn.id] + ): + constructor = globals.defs[globals.impls[defn.id]["__new__"]] + return GuppyDefinition(constructor)(*args) err = f"{defn.description.capitalize()} `{defn.name}` is not callable" raise TypeError(err) diff --git a/tests/integration/tracing/test_typing.py b/tests/integration/tracing/test_typing.py new file mode 100644 index 00000000..2d91bbce --- /dev/null +++ b/tests/integration/tracing/test_typing.py @@ -0,0 +1,23 @@ +import pytest + +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +def test_wrong_module(): + module1 = GuppyModule("module1") + module2 = GuppyModule("module2") + + @guppy.declare(module1) + def foo() -> int: ... + + @guppy.comptime(module2) + def main() -> int: + return foo() + + err = ( + "Function `foo` is not available in this module, consider importing it from " + "`module1`" + ) + with pytest.raises(TypeError, match=err): + module2.compile() From 9b2d0ea2d5a68c0ba1c2319b572ed66e641f58be Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 4 Feb 2025 12:29:23 +0000 Subject: [PATCH 15/26] Support golden tests with tracebacks --- guppylang/error.py | 24 +----------------------- guppylang/tracing/function.py | 3 ++- guppylang/tracing/util.py | 30 ++++++++++++++++++++---------- tests/error/util.py | 15 +++++++++++++-- 4 files changed, 36 insertions(+), 36 deletions(-) diff --git a/guppylang/error.py b/guppylang/error.py index e60bdc2e..754af281 100644 --- a/guppylang/error.py +++ b/guppylang/error.py @@ -1,5 +1,4 @@ import functools -import os import sys from collections.abc import Callable, Iterator from contextlib import contextmanager @@ -7,8 +6,6 @@ from types import TracebackType from typing import TYPE_CHECKING, Any, TypeVar, cast -from guppylang.ipython_inspect import is_running_ipython - if TYPE_CHECKING: from guppylang.diagnostic import Error, Fatal @@ -101,25 +98,6 @@ def hook( @functools.wraps(f) def pretty_errors_wrapped(*args: Any, **kwargs: Any) -> Any: with exception_hook(hook): - try: - return f(*args, **kwargs) - except GuppyError as err: - # For normal usage, this `try` block is not necessary since the - # excepthook is automatically invoked when the exception (which is being - # reraised below) is not handled. However, when running tests, we have - # to manually invoke the hook to print the error message, since the - # tests always have to capture exceptions. The only exception are - # notebook tests which don't rely on the capsys fixture. - if _pytest_running() and not is_running_ipython(): - hook(type(err), err, err.__traceback__) - raise + return f(*args, **kwargs) return cast(FuncT, pretty_errors_wrapped) - - -def _pytest_running() -> bool: - """Checks if we are currently running pytest. - - See https://docs.pytest.org/en/latest/example/simple.html#pytest-current-test-environment-variable - """ - return "PYTEST_CURRENT_TEST" in os.environ diff --git a/guppylang/tracing/function.py b/guppylang/tracing/function.py index 4a168666..c56d9637 100644 --- a/guppylang/tracing/function.py +++ b/guppylang/tracing/function.py @@ -30,7 +30,7 @@ unpack_guppy_object, update_packed_value, ) -from guppylang.tracing.util import tracing_except_hook +from guppylang.tracing.util import capture_guppy_errors, tracing_except_hook from guppylang.tys.ty import FunctionType, InputFlags, type_to_row if TYPE_CHECKING: @@ -111,6 +111,7 @@ def trace_function( builder.set_outputs(*regular_returns, *inout_returns) +@capture_guppy_errors def trace_call(func: CompiledCallableDef, *args: Any) -> Any: state = get_tracing_state() globals = get_tracing_globals() diff --git a/guppylang/tracing/util.py b/guppylang/tracing/util.py index ed1e2d59..832d48a8 100644 --- a/guppylang/tracing/util.py +++ b/guppylang/tracing/util.py @@ -11,6 +11,26 @@ T = TypeVar("T") +def capture_guppy_errors(f: Callable[P, T]) -> Callable[P, T]: + """Context manager that captures Guppy errors and turns them into runtime + `TypeError`s.""" + + @functools.wraps(f) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> T: + try: + return f(*args, **kwargs) + except GuppyError as err: + diagnostic = err.error + msg = diagnostic.rendered_title + if diagnostic.span_label: + msg += f": {diagnostic.rendered_span_label}" + if diagnostic.message: + msg += f"\n{diagnostic.rendered_message}" + raise TypeError(msg) from None + + return wrapped + + def hide_trace(f: Callable[P, T]) -> Callable[P, T]: """Function decorator that hides compiler-internal frames from the traceback of any exception thrown by the decorated function.""" @@ -27,16 +47,6 @@ def tracing_except_hook( excty: type[BaseException], err: BaseException, traceback: TracebackType | None ) -> None: """Except hook that removes all compiler-internal frames from the traceback.""" - if isinstance(err, GuppyError): - diagnostic = err.error - msg = diagnostic.rendered_title - if diagnostic.span_label: - msg += f": {diagnostic.rendered_span_label}" - if diagnostic.message: - msg += f"\n{diagnostic.rendered_message}" - err = TypeError(msg) - excty = TypeError - traceback = remove_internal_frames(traceback) try: # Check if we're inside a jupyter notebook since it uses its own exception diff --git a/tests/error/util.py b/tests/error/util.py index a52b015c..e23f121e 100644 --- a/tests/error/util.py +++ b/tests/error/util.py @@ -1,10 +1,12 @@ import importlib.util +import inspect import pathlib +import sys + import pytest from hugr import tys from hugr.tys import TypeBound -from guppylang.error import GuppyError from guppylang.module import GuppyModule import guppylang.decorator as decorator @@ -13,9 +15,18 @@ def run_error_test(file, capsys, snapshot): file = pathlib.Path(file) - with pytest.raises(GuppyError): + with pytest.raises(Exception) as exc_info: importlib.import_module(f"tests.error.{file.parent.name}.{file.name}") + # Remove the importlib frames from the traceback by skipping beginning frames until + # we end up in the executed file + tb = exc_info.tb + while tb is not None and inspect.getfile(tb.tb_frame) != str(file): + tb = tb.tb_next + + # Invoke except hook to print the exception to stderr + sys.excepthook(exc_info.type, exc_info.value.with_traceback(tb), tb) + err = capsys.readouterr().err err = err.replace(str(file), "$FILE") From 0fdb1cf178b59da8e9a2840229bcc74f63ee69eb Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 4 Feb 2025 12:33:35 +0000 Subject: [PATCH 16/26] Add some golden tests --- tests/error/test_tracing_errors.py | 19 +++++++++++++++++++ tests/error/tracing_errors/__init__.py | 0 tests/error/tracing_errors/bad_arg.err | 6 ++++++ tests/error/tracing_errors/bad_arg.py | 16 ++++++++++++++++ tests/error/tracing_errors/bad_return1.err | 10 ++++++++++ tests/error/tracing_errors/bad_return1.py | 12 ++++++++++++ tests/error/tracing_errors/bad_return2.err | 11 +++++++++++ tests/error/tracing_errors/bad_return2.py | 15 +++++++++++++++ tests/error/tracing_errors/no_return.err | 10 ++++++++++ tests/error/tracing_errors/no_return.py | 12 ++++++++++++ .../error/tracing_errors/not_enough_args.err | 6 ++++++ tests/error/tracing_errors/not_enough_args.py | 16 ++++++++++++++++ tests/error/tracing_errors/too_many_args.err | 6 ++++++ tests/error/tracing_errors/too_many_args.py | 16 ++++++++++++++++ 14 files changed, 155 insertions(+) create mode 100644 tests/error/test_tracing_errors.py create mode 100644 tests/error/tracing_errors/__init__.py create mode 100644 tests/error/tracing_errors/bad_arg.err create mode 100644 tests/error/tracing_errors/bad_arg.py create mode 100644 tests/error/tracing_errors/bad_return1.err create mode 100644 tests/error/tracing_errors/bad_return1.py create mode 100644 tests/error/tracing_errors/bad_return2.err create mode 100644 tests/error/tracing_errors/bad_return2.py create mode 100644 tests/error/tracing_errors/no_return.err create mode 100644 tests/error/tracing_errors/no_return.py create mode 100644 tests/error/tracing_errors/not_enough_args.err create mode 100644 tests/error/tracing_errors/not_enough_args.py create mode 100644 tests/error/tracing_errors/too_many_args.err create mode 100644 tests/error/tracing_errors/too_many_args.py diff --git a/tests/error/test_tracing_errors.py b/tests/error/test_tracing_errors.py new file mode 100644 index 00000000..40bb1f51 --- /dev/null +++ b/tests/error/test_tracing_errors.py @@ -0,0 +1,19 @@ +import pathlib +import pytest + +from tests.error.util import run_error_test + +path = pathlib.Path(__file__).parent.resolve() / "tracing_errors" +files = [ + x + for x in path.iterdir() + if x.is_file() and x.suffix == ".py" and x.name != "__init__.py" +] + +# Turn paths into strings, otherwise pytest doesn't display the names +files = [str(f) for f in files] + + +@pytest.mark.parametrize("file", files) +def test_tracing_errors(file, capsys, snapshot): + run_error_test(file, capsys, snapshot) diff --git a/tests/error/tracing_errors/__init__.py b/tests/error/tracing_errors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/error/tracing_errors/bad_arg.err b/tests/error/tracing_errors/bad_arg.err new file mode 100644 index 00000000..7e69a224 --- /dev/null +++ b/tests/error/tracing_errors/bad_arg.err @@ -0,0 +1,6 @@ +Traceback (most recent call last): + File "$FILE", line 16, in + module.compile() + File "$FILE", line 13, in test + foo(1.0) +TypeError: Type mismatch: Expected argument of type `int`, got `float` diff --git a/tests/error/tracing_errors/bad_arg.py b/tests/error/tracing_errors/bad_arg.py new file mode 100644 index 00000000..59520e57 --- /dev/null +++ b/tests/error/tracing_errors/bad_arg.py @@ -0,0 +1,16 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +module = GuppyModule("test") + + +@guppy.declare(module) +def foo(x: int) -> None: ... + + +@guppy.comptime(module) +def test() -> None: + foo(1.0) + + +module.compile() diff --git a/tests/error/tracing_errors/bad_return1.err b/tests/error/tracing_errors/bad_return1.err new file mode 100644 index 00000000..046ecd11 --- /dev/null +++ b/tests/error/tracing_errors/bad_return1.err @@ -0,0 +1,10 @@ +Error: Type mismatch (at $FILE:8:0) + | +6 | +7 | @guppy.comptime(module) +8 | def test() -> int: + | ^^^^^^^^^^^^^^^^^^ +9 | return 1.0 + | ^^^^^^^^^^^^^^ Expected return value of type `int`, got `float` + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/tracing_errors/bad_return1.py b/tests/error/tracing_errors/bad_return1.py new file mode 100644 index 00000000..8f9c9710 --- /dev/null +++ b/tests/error/tracing_errors/bad_return1.py @@ -0,0 +1,12 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +module = GuppyModule("test") + + +@guppy.comptime(module) +def test() -> int: + return 1.0 + + +module.compile() diff --git a/tests/error/tracing_errors/bad_return2.err b/tests/error/tracing_errors/bad_return2.err new file mode 100644 index 00000000..559d2d8f --- /dev/null +++ b/tests/error/tracing_errors/bad_return2.err @@ -0,0 +1,11 @@ +Error: Type mismatch (at $FILE:8:0) + | + 6 | + 7 | @guppy.comptime(module) + 8 | def test() -> int: + | ^^^^^^^^^^^^^^^^^^ + | ... +12 | return 1 + | ^^^^^^^^^^^^^^^^ Expected return value of type `int`, got `float` + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/tracing_errors/bad_return2.py b/tests/error/tracing_errors/bad_return2.py new file mode 100644 index 00000000..c15e7518 --- /dev/null +++ b/tests/error/tracing_errors/bad_return2.py @@ -0,0 +1,15 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +module = GuppyModule("test") + + +@guppy.comptime(module) +def test() -> int: + if True: + return 1.0 + else: + return 1 + + +module.compile() diff --git a/tests/error/tracing_errors/no_return.err b/tests/error/tracing_errors/no_return.err new file mode 100644 index 00000000..1a58a382 --- /dev/null +++ b/tests/error/tracing_errors/no_return.err @@ -0,0 +1,10 @@ +Error: Type mismatch (at $FILE:8:0) + | +6 | +7 | @guppy.comptime(module) +8 | def test() -> int: + | ^^^^^^^^^^^^^^^^^^ +9 | pass + | ^^^^^^^^ Expected return value of type `int`, got `None` + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/tracing_errors/no_return.py b/tests/error/tracing_errors/no_return.py new file mode 100644 index 00000000..d3861577 --- /dev/null +++ b/tests/error/tracing_errors/no_return.py @@ -0,0 +1,12 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +module = GuppyModule("test") + + +@guppy.comptime(module) +def test() -> int: + pass + + +module.compile() diff --git a/tests/error/tracing_errors/not_enough_args.err b/tests/error/tracing_errors/not_enough_args.err new file mode 100644 index 00000000..782eee2b --- /dev/null +++ b/tests/error/tracing_errors/not_enough_args.err @@ -0,0 +1,6 @@ +Traceback (most recent call last): + File "$FILE", line 16, in + module.compile() + File "$FILE", line 13, in test + foo(1) +TypeError: Not enough arguments: Expected 2, got 1 diff --git a/tests/error/tracing_errors/not_enough_args.py b/tests/error/tracing_errors/not_enough_args.py new file mode 100644 index 00000000..7cb9c856 --- /dev/null +++ b/tests/error/tracing_errors/not_enough_args.py @@ -0,0 +1,16 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +module = GuppyModule("test") + + +@guppy.declare(module) +def foo(x: int, y: int) -> None: ... + + +@guppy.comptime(module) +def test() -> None: + foo(1) + + +module.compile() diff --git a/tests/error/tracing_errors/too_many_args.err b/tests/error/tracing_errors/too_many_args.err new file mode 100644 index 00000000..4ba1fb7c --- /dev/null +++ b/tests/error/tracing_errors/too_many_args.err @@ -0,0 +1,6 @@ +Traceback (most recent call last): + File "$FILE", line 16, in + module.compile() + File "$FILE", line 13, in test + foo(1, 2, 3, 4) +TypeError: Too many arguments: Expected 2, got 4 diff --git a/tests/error/tracing_errors/too_many_args.py b/tests/error/tracing_errors/too_many_args.py new file mode 100644 index 00000000..c02276e9 --- /dev/null +++ b/tests/error/tracing_errors/too_many_args.py @@ -0,0 +1,16 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +module = GuppyModule("test") + + +@guppy.declare(module) +def foo(x: int, y: int) -> None: ... + + +@guppy.comptime(module) +def test() -> None: + foo(1, 2, 3, 4) + + +module.compile() From 7a7b01495ba48115cca0d7b4163ba8b7ad1d701f Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 4 Feb 2025 13:09:10 +0000 Subject: [PATCH 17/26] Remove bad test --- tests/integration/tracing/test_arithmetic.py | 22 -------------------- 1 file changed, 22 deletions(-) diff --git a/tests/integration/tracing/test_arithmetic.py b/tests/integration/tracing/test_arithmetic.py index b408bc85..456b2385 100644 --- a/tests/integration/tracing/test_arithmetic.py +++ b/tests/integration/tracing/test_arithmetic.py @@ -111,25 +111,3 @@ def main() -> None: # TODO: Requires lowering of `fpow` op: https://github.com/CQCL/hugr/issues/1905 # run_float_fn_approx(compiled, ..., "pow", [...]) - - -def test_mixed(validate, run_int_fn): - module = GuppyModule("module") - - @guppy.comptime(module) - def foo(x: int, y: float) -> int: - a = 1 + (x + 2) - b = 1 - (a - 2) - c = 2 * (b * 3) - x += 1 - x *= y - z = 1 + x - return int(z / 2) - - @guppy(module) - def main() -> int: - return foo(1, 2.0) - - compiled = module.compile() - validate(compiled) - run_int_fn(compiled, 42) From 3701b7e491decc0690e93ed42b8f2789ae588096 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 4 Feb 2025 13:09:35 +0000 Subject: [PATCH 18/26] Add __init__.py --- tests/integration/tracing/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/integration/tracing/__init__.py diff --git a/tests/integration/tracing/__init__.py b/tests/integration/tracing/__init__.py new file mode 100644 index 00000000..e69de29b From af05e9865dfa571a3f41faf3db058909d4b75708 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 4 Feb 2025 13:12:55 +0000 Subject: [PATCH 19/26] Improve dunder checks and errors --- guppylang/tracing/object.py | 148 ++++++++++++++----- tests/error/tracing_errors/bad_binary1.err | 7 + tests/error/tracing_errors/bad_binary1.py | 12 ++ tests/error/tracing_errors/bad_binary2.err | 7 + tests/error/tracing_errors/bad_binary2.py | 12 ++ tests/error/tracing_errors/bad_unary.err | 7 + tests/error/tracing_errors/bad_unary.py | 12 ++ tests/integration/tracing/test_arithmetic.py | 30 ++++ 8 files changed, 198 insertions(+), 37 deletions(-) create mode 100644 tests/error/tracing_errors/bad_binary1.err create mode 100644 tests/error/tracing_errors/bad_binary1.py create mode 100644 tests/error/tracing_errors/bad_binary2.err create mode 100644 tests/error/tracing_errors/bad_binary2.py create mode 100644 tests/error/tracing_errors/bad_unary.err create mode 100644 tests/error/tracing_errors/bad_unary.py diff --git a/guppylang/tracing/object.py b/guppylang/tracing/object.py index 416a7819..f2004b44 100644 --- a/guppylang/tracing/object.py +++ b/guppylang/tracing/object.py @@ -3,45 +3,104 @@ import itertools from abc import ABC, abstractmethod from collections.abc import Callable +from contextlib import suppress from dataclasses import dataclass from typing import Any, NamedTuple, TypeAlias from hugr import Wire, ops +from guppylang.checker.errors.type_errors import BinaryOperatorNotDefinedError, \ + UnaryOperatorNotDefinedError +import guppylang.checker.expr_checker as expr_checker from guppylang.definition.common import DefId, Definition from guppylang.definition.function import RawFunctionDef from guppylang.definition.ty import TypeDef from guppylang.definition.value import CompiledCallableDef, CompiledValueDef -from guppylang.error import GuppyError +from guppylang.error import GuppyError, GuppyTypeError from guppylang.ipython_inspect import find_ipython_def, is_running_ipython from guppylang.tracing.state import get_tracing_globals, get_tracing_state -from guppylang.tracing.util import get_calling_frame, hide_trace +from guppylang.tracing.util import get_calling_frame, hide_trace, capture_guppy_errors from guppylang.tys.ty import TupleType, Type -DunderMethod: TypeAlias = Callable[["DunderMixin", Any], Any] +# Mapping from unary dunder method to display name of the operation +unary_table = { + method: display_name for method, display_name in expr_checker.unary_table.values() +} +# Mapping from binary dunder method to reversed method and display name of the operation +binary_table = { + method: (reverse_method, display_name) + for method, reverse_method, display_name in expr_checker.binary_table.values() +} -def also_try_reversed(method_name: str) -> Callable[[DunderMethod], DunderMethod]: - """Decorator to delegate calls to dunder methods like `__add__` on `GuppyObject`s to - their reversed version `__radd__` if the original one doesn't type check. +# Mapping from reverse binary dunder method to original method and display name of the +# operation +reverse_binary_table = { + reverse_method: (method, display_name) + for method, reverse_method, display_name in expr_checker.binary_table.values() +} + +UnaryDunderMethod: TypeAlias = Callable[["DunderMixin"], Any] +BinaryDunderMethod: TypeAlias = Callable[["DunderMixin", Any], Any] + + +def unary_operation(f: UnaryDunderMethod) -> UnaryDunderMethod: + """Decorator for methods corresponding to unary operations like `__neg__` etc. + + Emits a user error if the binary operation is not defined for the given type. """ - def decorator(f: DunderMethod) -> DunderMethod: - @functools.wraps(f) - def wrapped(self: "DunderMixin", other: Any) -> Any: - try: - return f(self, other) - except GuppyError: - from guppylang.tracing.state import get_tracing_state - from guppylang.tracing.unpacking import guppy_object_from_py + @functools.wraps(f) + @capture_guppy_errors + def wrapped(self: "DunderMixin") -> Any: + with suppress(Exception): + return f(self) - state = get_tracing_state() - obj = guppy_object_from_py(other, state.dfg.builder, state.node) - return obj.__getattr__(method_name)(self) + from guppylang.tracing.state import get_tracing_state + + state = get_tracing_state() + assert isinstance(self, GuppyObject) + raise GuppyTypeError( + UnaryOperatorNotDefinedError(state.node, self._ty, unary_table[f.__name__]) + ) - return wrapped + return wrapped + + +def binary_operation(f: BinaryDunderMethod) -> BinaryDunderMethod: + """Decorator for methods corresponding to binary operations like `__add__` etc. + + Delegate calls to their reversed versions `__radd__` etc. if the original one + doesn't type check. Otherwise, emits an error informing the user that the binary + operation is not defined for those types. + """ + + @functools.wraps(f) + @capture_guppy_errors + def wrapped(self: "DunderMixin", other: Any) -> Any: + with suppress(Exception): + return f(self, other) + with suppress(Exception): + from guppylang.tracing.state import get_tracing_state + from guppylang.tracing.unpacking import guppy_object_from_py + + state = get_tracing_state() + obj = guppy_object_from_py(other, state.dfg.builder, state.node) + + if f.__name__ in binary_table: + reverse_method, display_name = binary_table[f.__name__] + left_ty, right_ty = self._ty, obj._ty + else: + reverse_method, display_name = reverse_binary_table[f.__name__] + left_ty, right_ty = obj._ty, self._ty + return obj.__getattr__(reverse_method)(self) + + assert isinstance(self, GuppyObject) + raise GuppyTypeError( + BinaryOperatorNotDefinedError(state.node, left_ty, right_ty, display_name) + ) - return decorator + return wrapped class DunderMixin(ABC): @@ -55,11 +114,11 @@ def __getattr__(self, item: Any) -> Any: ... def __abs__(self) -> Any: return self.__getattr__("__abs__")() - @also_try_reversed("__radd__") + @binary_operation def __add__(self, other: Any) -> Any: return self.__getattr__("__add__")(other) - @also_try_reversed("__rand__") + @binary_operation def __and__(self, other: Any) -> Any: return self.__getattr__("__and__")(other) @@ -72,7 +131,7 @@ def __ceil__(self: Any) -> Any: def __divmod__(self, other: Any) -> Any: return self.__getattr__("__divmod__")(other) - @also_try_reversed("__eq__") + @binary_operation def __eq__(self, other: object) -> Any: return self.__getattr__("__eq__")(other) @@ -82,114 +141,129 @@ def __float__(self) -> Any: def __floor__(self) -> Any: return self.__getattr__("__floor__")() - @also_try_reversed("__rfloordiv__") + @binary_operation def __floordiv__(self, other: Any) -> Any: return self.__getattr__("__floordiv__")(other) - @also_try_reversed("__le__") + @binary_operation def __ge__(self, other: Any) -> Any: return self.__getattr__("__ge__")(other) - @also_try_reversed("__lt__") + @binary_operation def __gt__(self, other: Any) -> Any: return self.__getattr__("__gt__")(other) def __int__(self) -> Any: return self.__getattr__("__int__")() + @unary_operation def __invert__(self) -> Any: return self.__getattr__("__invert__")() - @also_try_reversed("__ge__") + @binary_operation def __le__(self, other: Any) -> Any: return self.__getattr__("__le__")(other) - @also_try_reversed("__rlshift__") + @binary_operation def __lshift__(self, other: Any) -> Any: return self.__getattr__("__lshift__")(other) - @also_try_reversed("__gt__") + @binary_operation def __lt__(self, other: Any) -> Any: return self.__getattr__("__lt__")(other) - @also_try_reversed("__rmod__") + @binary_operation def __mod__(self, other: Any) -> Any: return self.__getattr__("__mod__")(other) - @also_try_reversed("__rmul__") + @binary_operation def __mul__(self, other: Any) -> Any: return self.__getattr__("__mul__")(other) - @also_try_reversed("__ne__") + @binary_operation def __ne__(self, other: object) -> Any: return self.__getattr__("__ne__")(other) + @unary_operation def __neg__(self) -> Any: return self.__getattr__("__neg__")() - @also_try_reversed("__ror__") + @binary_operation def __or__(self, other: Any) -> Any: return self.__getattr__("__or__")(other) + @unary_operation def __pos__(self) -> Any: return self.__getattr__("__pos__")() - @also_try_reversed("__rpow__") + @binary_operation def __pow__(self, other: Any) -> Any: return self.__getattr__("__pow__")(other) + @binary_operation def __radd__(self, other: Any) -> Any: return self.__getattr__("__radd__")(other) + @binary_operation def __rand__(self, other: Any) -> Any: return self.__getattr__("__rand__")(other) + @binary_operation def __rfloordiv__(self, other: Any) -> Any: return self.__getattr__("__rfloordiv__")(other) + @binary_operation def __rlshift__(self, other: Any) -> Any: return self.__getattr__("__rlshift__")(other) + @binary_operation def __rmod__(self, other: Any) -> Any: return self.__getattr__("__rmod__")(other) + @binary_operation def __rmul__(self, other: Any) -> Any: return self.__getattr__("__rmul__")(other) + @binary_operation def __ror__(self, other: Any) -> Any: return self.__getattr__("__ror__")(other) + @binary_operation def __rpow__(self, other: Any) -> Any: return self.__getattr__("__rpow__")(other) + @binary_operation def __rrshift__(self, other: Any) -> Any: return self.__getattr__("__pow__")(other) - @also_try_reversed("__rrshift__") + @binary_operation def __rshift__(self, other: Any) -> Any: return self.__getattr__("__rshift__")(other) + @binary_operation def __rsub__(self, other: Any) -> Any: return self.__getattr__("__rsub__")(other) + @binary_operation def __rtruediv__(self, other: Any) -> Any: return self.__getattr__("__rtruediv__")(other) + @binary_operation def __rxor__(self, other: Any) -> Any: return self.__getattr__("__rxor__")(other) - @also_try_reversed("__rsub__") + @binary_operation def __sub__(self, other: Any) -> Any: return self.__getattr__("__sub__")(other) - @also_try_reversed("__rtruediv__") + @binary_operation def __truediv__(self, other: Any) -> Any: return self.__getattr__("__truediv__")(other) def __trunc__(self) -> Any: return self.__getattr__("__trunc__")() - @also_try_reversed("__rxor__") + @binary_operation def __xor__(self, other: Any) -> Any: return self.__getattr__("__xor__")(other) diff --git a/tests/error/tracing_errors/bad_binary1.err b/tests/error/tracing_errors/bad_binary1.err new file mode 100644 index 00000000..6380cd3b --- /dev/null +++ b/tests/error/tracing_errors/bad_binary1.err @@ -0,0 +1,7 @@ +Traceback (most recent call last): + File "$FILE", line 12, in + module.compile() + File "$FILE", line 9, in test + return x + (2, 3) + ~~^~~~~~~~ +TypeError: Operator not defined: Binary operator `+` not defined for `int` and `(int, int)` diff --git a/tests/error/tracing_errors/bad_binary1.py b/tests/error/tracing_errors/bad_binary1.py new file mode 100644 index 00000000..aa202525 --- /dev/null +++ b/tests/error/tracing_errors/bad_binary1.py @@ -0,0 +1,12 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +module = GuppyModule("test") + + +@guppy.comptime(module) +def test(x: int) -> int: + return x + (2, 3) + + +module.compile() diff --git a/tests/error/tracing_errors/bad_binary2.err b/tests/error/tracing_errors/bad_binary2.err new file mode 100644 index 00000000..2bc428b9 --- /dev/null +++ b/tests/error/tracing_errors/bad_binary2.err @@ -0,0 +1,7 @@ +Traceback (most recent call last): + File "$FILE", line 12, in + module.compile() + File "$FILE", line 9, in test + return (1, 2) + x + ~~~~~~~^~~ +TypeError: Operator not defined: Binary operator `+` not defined for `(int, int)` and `int` diff --git a/tests/error/tracing_errors/bad_binary2.py b/tests/error/tracing_errors/bad_binary2.py new file mode 100644 index 00000000..1c2e2ae0 --- /dev/null +++ b/tests/error/tracing_errors/bad_binary2.py @@ -0,0 +1,12 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +module = GuppyModule("test") + + +@guppy.comptime(module) +def test(x: int) -> int: + return (1, 2) + x + + +module.compile() diff --git a/tests/error/tracing_errors/bad_unary.err b/tests/error/tracing_errors/bad_unary.err new file mode 100644 index 00000000..c5768c4e --- /dev/null +++ b/tests/error/tracing_errors/bad_unary.err @@ -0,0 +1,7 @@ +Traceback (most recent call last): + File "$FILE", line 12, in + module.compile() + File "$FILE", line 9, in test + return ~x + ^^ +TypeError: Operator not defined: Unary operator `~` not defined for `float` diff --git a/tests/error/tracing_errors/bad_unary.py b/tests/error/tracing_errors/bad_unary.py new file mode 100644 index 00000000..51e95410 --- /dev/null +++ b/tests/error/tracing_errors/bad_unary.py @@ -0,0 +1,12 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +module = GuppyModule("test") + + +@guppy.comptime(module) +def test(x: float) -> float: + return ~x + + +module.compile() diff --git a/tests/integration/tracing/test_arithmetic.py b/tests/integration/tracing/test_arithmetic.py index 456b2385..6e5c314f 100644 --- a/tests/integration/tracing/test_arithmetic.py +++ b/tests/integration/tracing/test_arithmetic.py @@ -111,3 +111,33 @@ def main() -> None: # TODO: Requires lowering of `fpow` op: https://github.com/CQCL/hugr/issues/1905 # run_float_fn_approx(compiled, ..., "pow", [...]) + + +def test_dunder_coercions(validate): + module = GuppyModule("module") + + @guppy.comptime(module) + def test1(x: int) -> float: + return 1.0 + x + + @guppy.comptime(module) + def test2(x: int) -> float: + return x + 1.0 + + @guppy.comptime(module) + def test3(x: float) -> float: + return 1 + x + + @guppy.comptime(module) + def test4(x: float) -> float: + return x + 1 + + @guppy.comptime(module) + def test4(x: int, y: float) -> float: + return x + y + + @guppy.comptime(module) + def test5(x: float, y: int) -> float: + return x + y + + validate(module.compile()) From 674cee1764593837906ea111cf23315ca147d33d Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 4 Feb 2025 13:17:51 +0000 Subject: [PATCH 20/26] Turn wrong module test into error test --- tests/error/tracing_errors/wrong_module.err | 7 +++++++ tests/error/tracing_errors/wrong_module.py | 17 +++++++++++++++ tests/integration/tracing/test_typing.py | 23 --------------------- 3 files changed, 24 insertions(+), 23 deletions(-) create mode 100644 tests/error/tracing_errors/wrong_module.err create mode 100644 tests/error/tracing_errors/wrong_module.py delete mode 100644 tests/integration/tracing/test_typing.py diff --git a/tests/error/tracing_errors/wrong_module.err b/tests/error/tracing_errors/wrong_module.err new file mode 100644 index 00000000..508ac3ef --- /dev/null +++ b/tests/error/tracing_errors/wrong_module.err @@ -0,0 +1,7 @@ +Traceback (most recent call last): + File "$FILE", line 17, in + module2.compile() + File "$FILE", line 14, in main + return foo() + ^^^^^ +TypeError: Function `foo` is not available in this module, consider importing it from `module1` diff --git a/tests/error/tracing_errors/wrong_module.py b/tests/error/tracing_errors/wrong_module.py new file mode 100644 index 00000000..62012a13 --- /dev/null +++ b/tests/error/tracing_errors/wrong_module.py @@ -0,0 +1,17 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +module1 = GuppyModule("module1") +module2 = GuppyModule("module2") + + +@guppy.declare(module1) +def foo() -> int: ... + + +@guppy.comptime(module2) +def main() -> int: + return foo() + + +module2.compile() diff --git a/tests/integration/tracing/test_typing.py b/tests/integration/tracing/test_typing.py deleted file mode 100644 index 2d91bbce..00000000 --- a/tests/integration/tracing/test_typing.py +++ /dev/null @@ -1,23 +0,0 @@ -import pytest - -from guppylang.decorator import guppy -from guppylang.module import GuppyModule - - -def test_wrong_module(): - module1 = GuppyModule("module1") - module2 = GuppyModule("module2") - - @guppy.declare(module1) - def foo() -> int: ... - - @guppy.comptime(module2) - def main() -> int: - return foo() - - err = ( - "Function `foo` is not available in this module, consider importing it from " - "`module1`" - ) - with pytest.raises(TypeError, match=err): - module2.compile() From b0533224c03c5b3f8e2e982452ddf89fd80ba2dc Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 4 Feb 2025 13:20:17 +0000 Subject: [PATCH 21/26] Ruff --- guppylang/tracing/object.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/guppylang/tracing/object.py b/guppylang/tracing/object.py index f2004b44..4e922370 100644 --- a/guppylang/tracing/object.py +++ b/guppylang/tracing/object.py @@ -9,23 +9,23 @@ from hugr import Wire, ops -from guppylang.checker.errors.type_errors import BinaryOperatorNotDefinedError, \ - UnaryOperatorNotDefinedError import guppylang.checker.expr_checker as expr_checker +from guppylang.checker.errors.type_errors import ( + BinaryOperatorNotDefinedError, + UnaryOperatorNotDefinedError, +) from guppylang.definition.common import DefId, Definition from guppylang.definition.function import RawFunctionDef from guppylang.definition.ty import TypeDef from guppylang.definition.value import CompiledCallableDef, CompiledValueDef -from guppylang.error import GuppyError, GuppyTypeError +from guppylang.error import GuppyTypeError from guppylang.ipython_inspect import find_ipython_def, is_running_ipython from guppylang.tracing.state import get_tracing_globals, get_tracing_state -from guppylang.tracing.util import get_calling_frame, hide_trace, capture_guppy_errors +from guppylang.tracing.util import capture_guppy_errors, get_calling_frame, hide_trace from guppylang.tys.ty import TupleType, Type # Mapping from unary dunder method to display name of the operation -unary_table = { - method: display_name for method, display_name in expr_checker.unary_table.values() -} +unary_table = dict(expr_checker.unary_table.values()) # Mapping from binary dunder method to reversed method and display name of the operation binary_table = { From 13905621deac65ef99a493e059ea09cf3d23725e Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 4 Feb 2025 13:52:47 +0000 Subject: [PATCH 22/26] Check for callable --- guppylang/tracing/object.py | 18 ++++++++++++++++-- tests/error/tracing_errors/higher_order.err | 7 +++++++ tests/error/tracing_errors/higher_order.py | 14 ++++++++++++++ tests/error/tracing_errors/not_callable1.err | 6 ++++++ tests/error/tracing_errors/not_callable1.py | 12 ++++++++++++ tests/error/tracing_errors/not_callable2.err | 6 ++++++ tests/error/tracing_errors/not_callable2.py | 14 ++++++++++++++ 7 files changed, 75 insertions(+), 2 deletions(-) create mode 100644 tests/error/tracing_errors/higher_order.err create mode 100644 tests/error/tracing_errors/higher_order.py create mode 100644 tests/error/tracing_errors/not_callable1.err create mode 100644 tests/error/tracing_errors/not_callable1.py create mode 100644 tests/error/tracing_errors/not_callable2.err create mode 100644 tests/error/tracing_errors/not_callable2.py diff --git a/guppylang/tracing/object.py b/guppylang/tracing/object.py index 4e922370..da7924a0 100644 --- a/guppylang/tracing/object.py +++ b/guppylang/tracing/object.py @@ -10,6 +10,7 @@ from hugr import Wire, ops import guppylang.checker.expr_checker as expr_checker +from guppylang.checker.errors.generic import UnsupportedError from guppylang.checker.errors.type_errors import ( BinaryOperatorNotDefinedError, UnaryOperatorNotDefinedError, @@ -18,11 +19,11 @@ from guppylang.definition.function import RawFunctionDef from guppylang.definition.ty import TypeDef from guppylang.definition.value import CompiledCallableDef, CompiledValueDef -from guppylang.error import GuppyTypeError +from guppylang.error import GuppyError, GuppyTypeError from guppylang.ipython_inspect import find_ipython_def, is_running_ipython from guppylang.tracing.state import get_tracing_globals, get_tracing_state from guppylang.tracing.util import capture_guppy_errors, get_calling_frame, hide_trace -from guppylang.tys.ty import TupleType, Type +from guppylang.tys.ty import FunctionType, TupleType, Type # Mapping from unary dunder method to display name of the operation unary_table = dict(expr_checker.unary_table.values()) @@ -317,6 +318,19 @@ def __bool__(self) -> Any: ) raise ValueError(err) + @hide_trace + @capture_guppy_errors + def __call__(self, *args): + if not isinstance(self._ty, FunctionType): + err = f"Value of type `{self._ty}` is not callable" + raise TypeError(err) + + # TODO: Support higher-order functions + state = get_tracing_state() + raise GuppyError( + UnsupportedError(state.node, "Higher-order comptime functions") + ) + @hide_trace def __iter__(self) -> Any: state = get_tracing_state() diff --git a/tests/error/tracing_errors/higher_order.err b/tests/error/tracing_errors/higher_order.err new file mode 100644 index 00000000..afacb9c2 --- /dev/null +++ b/tests/error/tracing_errors/higher_order.err @@ -0,0 +1,7 @@ +Traceback (most recent call last): + File "$FILE", line 14, in + module.compile() + File "$FILE", line 11, in test + return f() + ^^^ +TypeError: Unsupported: Higher-order comptime functions are not supported diff --git a/tests/error/tracing_errors/higher_order.py b/tests/error/tracing_errors/higher_order.py new file mode 100644 index 00000000..00797f7b --- /dev/null +++ b/tests/error/tracing_errors/higher_order.py @@ -0,0 +1,14 @@ +from collections.abc import Callable + +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +module = GuppyModule("test") + + +@guppy.comptime(module) +def test(f: Callable[[], int]) -> int: + return f() + + +module.compile() diff --git a/tests/error/tracing_errors/not_callable1.err b/tests/error/tracing_errors/not_callable1.err new file mode 100644 index 00000000..34e50c8b --- /dev/null +++ b/tests/error/tracing_errors/not_callable1.err @@ -0,0 +1,6 @@ +Traceback (most recent call last): + File "$FILE", line 12, in + module.compile() + File "$FILE", line 9, in test + x(1) +TypeError: Value of type `float` is not callable diff --git a/tests/error/tracing_errors/not_callable1.py b/tests/error/tracing_errors/not_callable1.py new file mode 100644 index 00000000..278d9dce --- /dev/null +++ b/tests/error/tracing_errors/not_callable1.py @@ -0,0 +1,12 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +module = GuppyModule("test") + + +@guppy.comptime(module) +def test(x: float) -> None: + x(1) + + +module.compile() diff --git a/tests/error/tracing_errors/not_callable2.err b/tests/error/tracing_errors/not_callable2.err new file mode 100644 index 00000000..e0789fe5 --- /dev/null +++ b/tests/error/tracing_errors/not_callable2.err @@ -0,0 +1,6 @@ +Traceback (most recent call last): + File "$FILE", line 14, in + module.compile() + File "$FILE", line 11, in test + x(1) +TypeError: Extern `x` is not callable diff --git a/tests/error/tracing_errors/not_callable2.py b/tests/error/tracing_errors/not_callable2.py new file mode 100644 index 00000000..61ee9f39 --- /dev/null +++ b/tests/error/tracing_errors/not_callable2.py @@ -0,0 +1,14 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +module = GuppyModule("test") + +x = guppy.extern("x", "float", module=module) + + +@guppy.comptime(module) +def test() -> None: + x(1) + + +module.compile() From a43a1380912c1d565cf2884ef40063218795bd66 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 4 Feb 2025 14:18:57 +0000 Subject: [PATCH 23/26] Make golden tests version sensitive --- tests/error/test_tracing_errors.py | 2 +- .../{bad_arg.err => bad_arg@python310.err} | 0 tests/error/tracing_errors/bad_arg@python313.err | 8 ++++++++ tests/error/tracing_errors/bad_binary1@python310.err | 6 ++++++ .../{bad_binary1.err => bad_binary1@python313.err} | 1 + tests/error/tracing_errors/bad_binary2@python310.err | 6 ++++++ .../{bad_binary2.err => bad_binary2@python313.err} | 1 + .../{bad_return1.err => bad_return1@python310.err} | 0 tests/error/tracing_errors/bad_return1@python313.err | 10 ++++++++++ .../{bad_return2.err => bad_return2@python310.err} | 0 tests/error/tracing_errors/bad_return2@python313.err | 11 +++++++++++ .../{bad_unary.err => bad_unary@python310.err} | 1 - tests/error/tracing_errors/bad_unary@python313.err | 8 ++++++++ .../{higher_order.err => higher_order@python310.err} | 1 - .../error/tracing_errors/higher_order@python313.err | 7 +++++++ .../{no_return.err => no_return@python310.err} | 0 tests/error/tracing_errors/no_return@python313.err | 10 ++++++++++ ...not_callable1.err => not_callable1@python310.err} | 0 .../error/tracing_errors/not_callable1@python313.err | 8 ++++++++ ...not_callable2.err => not_callable2@python310.err} | 0 .../error/tracing_errors/not_callable2@python313.err | 8 ++++++++ ...enough_args.err => not_enough_args@python310.err} | 0 .../tracing_errors/not_enough_args@python313.err | 8 ++++++++ ...too_many_args.err => too_many_args@python310.err} | 0 .../error/tracing_errors/too_many_args@python313.err | 8 ++++++++ .../{wrong_module.err => wrong_module@python310.err} | 1 - .../error/tracing_errors/wrong_module@python313.err | 7 +++++++ tests/error/util.py | 12 ++++++++++-- 28 files changed, 118 insertions(+), 6 deletions(-) rename tests/error/tracing_errors/{bad_arg.err => bad_arg@python310.err} (100%) create mode 100644 tests/error/tracing_errors/bad_arg@python313.err create mode 100644 tests/error/tracing_errors/bad_binary1@python310.err rename tests/error/tracing_errors/{bad_binary1.err => bad_binary1@python313.err} (92%) create mode 100644 tests/error/tracing_errors/bad_binary2@python310.err rename tests/error/tracing_errors/{bad_binary2.err => bad_binary2@python313.err} (92%) rename tests/error/tracing_errors/{bad_return1.err => bad_return1@python310.err} (100%) create mode 100644 tests/error/tracing_errors/bad_return1@python313.err rename tests/error/tracing_errors/{bad_return2.err => bad_return2@python310.err} (100%) create mode 100644 tests/error/tracing_errors/bad_return2@python313.err rename tests/error/tracing_errors/{bad_unary.err => bad_unary@python310.err} (93%) create mode 100644 tests/error/tracing_errors/bad_unary@python313.err rename tests/error/tracing_errors/{higher_order.err => higher_order@python310.err} (93%) create mode 100644 tests/error/tracing_errors/higher_order@python313.err rename tests/error/tracing_errors/{no_return.err => no_return@python310.err} (100%) create mode 100644 tests/error/tracing_errors/no_return@python313.err rename tests/error/tracing_errors/{not_callable1.err => not_callable1@python310.err} (100%) create mode 100644 tests/error/tracing_errors/not_callable1@python313.err rename tests/error/tracing_errors/{not_callable2.err => not_callable2@python310.err} (100%) create mode 100644 tests/error/tracing_errors/not_callable2@python313.err rename tests/error/tracing_errors/{not_enough_args.err => not_enough_args@python310.err} (100%) create mode 100644 tests/error/tracing_errors/not_enough_args@python313.err rename tests/error/tracing_errors/{too_many_args.err => too_many_args@python310.err} (100%) create mode 100644 tests/error/tracing_errors/too_many_args@python313.err rename tests/error/tracing_errors/{wrong_module.err => wrong_module@python310.err} (93%) create mode 100644 tests/error/tracing_errors/wrong_module@python313.err diff --git a/tests/error/test_tracing_errors.py b/tests/error/test_tracing_errors.py index 40bb1f51..d0119e37 100644 --- a/tests/error/test_tracing_errors.py +++ b/tests/error/test_tracing_errors.py @@ -16,4 +16,4 @@ @pytest.mark.parametrize("file", files) def test_tracing_errors(file, capsys, snapshot): - run_error_test(file, capsys, snapshot) + run_error_test(file, capsys, snapshot, version_sensitive=True) diff --git a/tests/error/tracing_errors/bad_arg.err b/tests/error/tracing_errors/bad_arg@python310.err similarity index 100% rename from tests/error/tracing_errors/bad_arg.err rename to tests/error/tracing_errors/bad_arg@python310.err diff --git a/tests/error/tracing_errors/bad_arg@python313.err b/tests/error/tracing_errors/bad_arg@python313.err new file mode 100644 index 00000000..119decf3 --- /dev/null +++ b/tests/error/tracing_errors/bad_arg@python313.err @@ -0,0 +1,8 @@ +Traceback (most recent call last): + File "$FILE", line 16, in + module.compile() + ~~~~~~~~~~~~~~^^ + File "$FILE", line 13, in test + foo(1.0) + ~~~^^^^^ +TypeError: Type mismatch: Expected argument of type `int`, got `float` diff --git a/tests/error/tracing_errors/bad_binary1@python310.err b/tests/error/tracing_errors/bad_binary1@python310.err new file mode 100644 index 00000000..38b50421 --- /dev/null +++ b/tests/error/tracing_errors/bad_binary1@python310.err @@ -0,0 +1,6 @@ +Traceback (most recent call last): + File "$FILE", line 12, in + module.compile() + File "$FILE", line 9, in test + return x + (2, 3) +TypeError: Operator not defined: Binary operator `+` not defined for `int` and `(int, int)` diff --git a/tests/error/tracing_errors/bad_binary1.err b/tests/error/tracing_errors/bad_binary1@python313.err similarity index 92% rename from tests/error/tracing_errors/bad_binary1.err rename to tests/error/tracing_errors/bad_binary1@python313.err index 6380cd3b..336f016e 100644 --- a/tests/error/tracing_errors/bad_binary1.err +++ b/tests/error/tracing_errors/bad_binary1@python313.err @@ -1,6 +1,7 @@ Traceback (most recent call last): File "$FILE", line 12, in module.compile() + ~~~~~~~~~~~~~~^^ File "$FILE", line 9, in test return x + (2, 3) ~~^~~~~~~~ diff --git a/tests/error/tracing_errors/bad_binary2@python310.err b/tests/error/tracing_errors/bad_binary2@python310.err new file mode 100644 index 00000000..0f552bef --- /dev/null +++ b/tests/error/tracing_errors/bad_binary2@python310.err @@ -0,0 +1,6 @@ +Traceback (most recent call last): + File "$FILE", line 12, in + module.compile() + File "$FILE", line 9, in test + return (1, 2) + x +TypeError: Operator not defined: Binary operator `+` not defined for `(int, int)` and `int` diff --git a/tests/error/tracing_errors/bad_binary2.err b/tests/error/tracing_errors/bad_binary2@python313.err similarity index 92% rename from tests/error/tracing_errors/bad_binary2.err rename to tests/error/tracing_errors/bad_binary2@python313.err index 2bc428b9..6435b30d 100644 --- a/tests/error/tracing_errors/bad_binary2.err +++ b/tests/error/tracing_errors/bad_binary2@python313.err @@ -1,6 +1,7 @@ Traceback (most recent call last): File "$FILE", line 12, in module.compile() + ~~~~~~~~~~~~~~^^ File "$FILE", line 9, in test return (1, 2) + x ~~~~~~~^~~ diff --git a/tests/error/tracing_errors/bad_return1.err b/tests/error/tracing_errors/bad_return1@python310.err similarity index 100% rename from tests/error/tracing_errors/bad_return1.err rename to tests/error/tracing_errors/bad_return1@python310.err diff --git a/tests/error/tracing_errors/bad_return1@python313.err b/tests/error/tracing_errors/bad_return1@python313.err new file mode 100644 index 00000000..046ecd11 --- /dev/null +++ b/tests/error/tracing_errors/bad_return1@python313.err @@ -0,0 +1,10 @@ +Error: Type mismatch (at $FILE:8:0) + | +6 | +7 | @guppy.comptime(module) +8 | def test() -> int: + | ^^^^^^^^^^^^^^^^^^ +9 | return 1.0 + | ^^^^^^^^^^^^^^ Expected return value of type `int`, got `float` + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/tracing_errors/bad_return2.err b/tests/error/tracing_errors/bad_return2@python310.err similarity index 100% rename from tests/error/tracing_errors/bad_return2.err rename to tests/error/tracing_errors/bad_return2@python310.err diff --git a/tests/error/tracing_errors/bad_return2@python313.err b/tests/error/tracing_errors/bad_return2@python313.err new file mode 100644 index 00000000..559d2d8f --- /dev/null +++ b/tests/error/tracing_errors/bad_return2@python313.err @@ -0,0 +1,11 @@ +Error: Type mismatch (at $FILE:8:0) + | + 6 | + 7 | @guppy.comptime(module) + 8 | def test() -> int: + | ^^^^^^^^^^^^^^^^^^ + | ... +12 | return 1 + | ^^^^^^^^^^^^^^^^ Expected return value of type `int`, got `float` + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/tracing_errors/bad_unary.err b/tests/error/tracing_errors/bad_unary@python310.err similarity index 93% rename from tests/error/tracing_errors/bad_unary.err rename to tests/error/tracing_errors/bad_unary@python310.err index c5768c4e..d55f1bcc 100644 --- a/tests/error/tracing_errors/bad_unary.err +++ b/tests/error/tracing_errors/bad_unary@python310.err @@ -3,5 +3,4 @@ Traceback (most recent call last): module.compile() File "$FILE", line 9, in test return ~x - ^^ TypeError: Operator not defined: Unary operator `~` not defined for `float` diff --git a/tests/error/tracing_errors/bad_unary@python313.err b/tests/error/tracing_errors/bad_unary@python313.err new file mode 100644 index 00000000..35d69fc4 --- /dev/null +++ b/tests/error/tracing_errors/bad_unary@python313.err @@ -0,0 +1,8 @@ +Traceback (most recent call last): + File "$FILE", line 12, in + module.compile() + ~~~~~~~~~~~~~~^^ + File "$FILE", line 9, in test + return ~x + ^^ +TypeError: Operator not defined: Unary operator `~` not defined for `float` diff --git a/tests/error/tracing_errors/higher_order.err b/tests/error/tracing_errors/higher_order@python310.err similarity index 93% rename from tests/error/tracing_errors/higher_order.err rename to tests/error/tracing_errors/higher_order@python310.err index afacb9c2..2a754daf 100644 --- a/tests/error/tracing_errors/higher_order.err +++ b/tests/error/tracing_errors/higher_order@python310.err @@ -3,5 +3,4 @@ Traceback (most recent call last): module.compile() File "$FILE", line 11, in test return f() - ^^^ TypeError: Unsupported: Higher-order comptime functions are not supported diff --git a/tests/error/tracing_errors/higher_order@python313.err b/tests/error/tracing_errors/higher_order@python313.err new file mode 100644 index 00000000..e953dd1b --- /dev/null +++ b/tests/error/tracing_errors/higher_order@python313.err @@ -0,0 +1,7 @@ +Traceback (most recent call last): + File "$FILE", line 14, in + module.compile() + ~~~~~~~~~~~~~~^^ + File "$FILE", line 11, in test + return f() +TypeError: Unsupported: Higher-order comptime functions are not supported diff --git a/tests/error/tracing_errors/no_return.err b/tests/error/tracing_errors/no_return@python310.err similarity index 100% rename from tests/error/tracing_errors/no_return.err rename to tests/error/tracing_errors/no_return@python310.err diff --git a/tests/error/tracing_errors/no_return@python313.err b/tests/error/tracing_errors/no_return@python313.err new file mode 100644 index 00000000..1a58a382 --- /dev/null +++ b/tests/error/tracing_errors/no_return@python313.err @@ -0,0 +1,10 @@ +Error: Type mismatch (at $FILE:8:0) + | +6 | +7 | @guppy.comptime(module) +8 | def test() -> int: + | ^^^^^^^^^^^^^^^^^^ +9 | pass + | ^^^^^^^^ Expected return value of type `int`, got `None` + +Guppy compilation failed due to 1 previous error diff --git a/tests/error/tracing_errors/not_callable1.err b/tests/error/tracing_errors/not_callable1@python310.err similarity index 100% rename from tests/error/tracing_errors/not_callable1.err rename to tests/error/tracing_errors/not_callable1@python310.err diff --git a/tests/error/tracing_errors/not_callable1@python313.err b/tests/error/tracing_errors/not_callable1@python313.err new file mode 100644 index 00000000..df9b184d --- /dev/null +++ b/tests/error/tracing_errors/not_callable1@python313.err @@ -0,0 +1,8 @@ +Traceback (most recent call last): + File "$FILE", line 12, in + module.compile() + ~~~~~~~~~~~~~~^^ + File "$FILE", line 9, in test + x(1) + ~^^^ +TypeError: Value of type `float` is not callable diff --git a/tests/error/tracing_errors/not_callable2.err b/tests/error/tracing_errors/not_callable2@python310.err similarity index 100% rename from tests/error/tracing_errors/not_callable2.err rename to tests/error/tracing_errors/not_callable2@python310.err diff --git a/tests/error/tracing_errors/not_callable2@python313.err b/tests/error/tracing_errors/not_callable2@python313.err new file mode 100644 index 00000000..00de69ac --- /dev/null +++ b/tests/error/tracing_errors/not_callable2@python313.err @@ -0,0 +1,8 @@ +Traceback (most recent call last): + File "$FILE", line 14, in + module.compile() + ~~~~~~~~~~~~~~^^ + File "$FILE", line 11, in test + x(1) + ~^^^ +TypeError: Extern `x` is not callable diff --git a/tests/error/tracing_errors/not_enough_args.err b/tests/error/tracing_errors/not_enough_args@python310.err similarity index 100% rename from tests/error/tracing_errors/not_enough_args.err rename to tests/error/tracing_errors/not_enough_args@python310.err diff --git a/tests/error/tracing_errors/not_enough_args@python313.err b/tests/error/tracing_errors/not_enough_args@python313.err new file mode 100644 index 00000000..63b377a2 --- /dev/null +++ b/tests/error/tracing_errors/not_enough_args@python313.err @@ -0,0 +1,8 @@ +Traceback (most recent call last): + File "$FILE", line 16, in + module.compile() + ~~~~~~~~~~~~~~^^ + File "$FILE", line 13, in test + foo(1) + ~~~^^^ +TypeError: Not enough arguments: Expected 2, got 1 diff --git a/tests/error/tracing_errors/too_many_args.err b/tests/error/tracing_errors/too_many_args@python310.err similarity index 100% rename from tests/error/tracing_errors/too_many_args.err rename to tests/error/tracing_errors/too_many_args@python310.err diff --git a/tests/error/tracing_errors/too_many_args@python313.err b/tests/error/tracing_errors/too_many_args@python313.err new file mode 100644 index 00000000..51872bd4 --- /dev/null +++ b/tests/error/tracing_errors/too_many_args@python313.err @@ -0,0 +1,8 @@ +Traceback (most recent call last): + File "$FILE", line 16, in + module.compile() + ~~~~~~~~~~~~~~^^ + File "$FILE", line 13, in test + foo(1, 2, 3, 4) + ~~~^^^^^^^^^^^^ +TypeError: Too many arguments: Expected 2, got 4 diff --git a/tests/error/tracing_errors/wrong_module.err b/tests/error/tracing_errors/wrong_module@python310.err similarity index 93% rename from tests/error/tracing_errors/wrong_module.err rename to tests/error/tracing_errors/wrong_module@python310.err index 508ac3ef..7d036284 100644 --- a/tests/error/tracing_errors/wrong_module.err +++ b/tests/error/tracing_errors/wrong_module@python310.err @@ -3,5 +3,4 @@ Traceback (most recent call last): module2.compile() File "$FILE", line 14, in main return foo() - ^^^^^ TypeError: Function `foo` is not available in this module, consider importing it from `module1` diff --git a/tests/error/tracing_errors/wrong_module@python313.err b/tests/error/tracing_errors/wrong_module@python313.err new file mode 100644 index 00000000..055d360e --- /dev/null +++ b/tests/error/tracing_errors/wrong_module@python313.err @@ -0,0 +1,7 @@ +Traceback (most recent call last): + File "$FILE", line 17, in + module2.compile() + ~~~~~~~~~~~~~~~^^ + File "$FILE", line 14, in main + return foo() +TypeError: Function `foo` is not available in this module, consider importing it from `module1` diff --git a/tests/error/util.py b/tests/error/util.py index e23f121e..24490e3c 100644 --- a/tests/error/util.py +++ b/tests/error/util.py @@ -12,7 +12,7 @@ import guppylang.decorator as decorator -def run_error_test(file, capsys, snapshot): +def run_error_test(file, capsys, snapshot, version_sensitive=False): file = pathlib.Path(file) with pytest.raises(Exception) as exc_info: @@ -30,8 +30,16 @@ def run_error_test(file, capsys, snapshot): err = capsys.readouterr().err err = err.replace(str(file), "$FILE") + if version_sensitive: + major, minor, *_ = sys.version_info + golden_file = file.with_name(file.stem + f"@python{major}{minor}.err") + if not golden_file.exists() and not snapshot._snapshot_update: + pytest.skip(f"No golden test available for Python {major}.{minor}") + else: + golden_file = file.with_suffix(".err") + snapshot.snapshot_dir = str(file.parent) - snapshot.assert_match(err, file.with_suffix(".err").name) + snapshot.assert_match(err, golden_file.name) util = GuppyModule("test") From c3b9d2c63dacbf001cedb007e635788c82cfd1b0 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 4 Feb 2025 14:22:52 +0000 Subject: [PATCH 24/26] Fix mypy --- guppylang/tracing/object.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guppylang/tracing/object.py b/guppylang/tracing/object.py index da7924a0..dc6b76d0 100644 --- a/guppylang/tracing/object.py +++ b/guppylang/tracing/object.py @@ -320,7 +320,7 @@ def __bool__(self) -> Any: @hide_trace @capture_guppy_errors - def __call__(self, *args): + def __call__(self, *args: Any) -> Any: if not isinstance(self._ty, FunctionType): err = f"Value of type `{self._ty}` is not callable" raise TypeError(err) From 99b05b95854541f38abf01e854040e583a97a6a0 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Tue, 4 Feb 2025 15:03:52 +0000 Subject: [PATCH 25/26] Allow runtime array construction and mutation --- guppylang/std/builtins.py | 4 +- guppylang/tracing/function.py | 3 +- guppylang/tracing/unpacking.py | 60 ++++++------------- tests/integration/tracing/test_array.py | 76 +++++++++++++++++++++++++ 4 files changed, 95 insertions(+), 48 deletions(-) create mode 100644 tests/integration/tracing/test_array.py diff --git a/guppylang/std/builtins.py b/guppylang/std/builtins.py index d379a3a0..adbf69ad 100644 --- a/guppylang/std/builtins.py +++ b/guppylang/std/builtins.py @@ -83,11 +83,11 @@ class nat: _n = TypeVar("_n") -class array(Generic[_T, _n]): +class array(list[_T], Generic[_T, _n]): """Class to import in order to use arrays.""" def __init__(self, *args: Any): - pass + list.__init__(self, args) @guppy.extend_type(bool_type_def) diff --git a/guppylang/tracing/function.py b/guppylang/tracing/function.py index c56d9637..ae2335af 100644 --- a/guppylang/tracing/function.py +++ b/guppylang/tracing/function.py @@ -26,7 +26,6 @@ from guppylang.tracing.unpacking import ( P, guppy_object_from_py, - repack_guppy_object, unpack_guppy_object, update_packed_value, ) @@ -93,7 +92,7 @@ def trace_function( # Compute the inout extra outputs try: inout_returns = [ - repack_guppy_object(inout_obj, builder)._use_wire(None) + guppy_object_from_py(inout_obj, builder, state.node)._use_wire(None) for inout_obj, inp in zip(inputs, ty.inputs, strict=False) if InputFlags.Inout in inp.flags ] diff --git a/guppylang/tracing/unpacking.py b/guppylang/tracing/unpacking.py index 0bd6d219..df944940 100644 --- a/guppylang/tracing/unpacking.py +++ b/guppylang/tracing/unpacking.py @@ -63,21 +63,26 @@ def unpack_guppy_object(obj: GuppyObject, builder: DfBase[P]) -> Any: return obj -def repack_guppy_object(v: Any, builder: DfBase[P]) -> GuppyObject: - """Undoes the `unpack_guppy_object` operation.""" +def guppy_object_from_py(v: Any, builder: DfBase[P], node: AstNode) -> GuppyObject: + """Constructs a Guppy object from a Python value. + + Essentially undoes the `unpack_guppy_object` operation. + """ match v: case GuppyObject() as obj: return obj + case GuppyDefinition() as defn: + return defn.to_guppy_object() case None: return GuppyObject(NoneType(), builder.add_op(ops.MakeTuple())) case tuple(vs): - objs = [repack_guppy_object(v, builder) for v in vs] + objs = [guppy_object_from_py(v, builder, node) for v in vs] return GuppyObject( TupleType([obj._ty for obj in objs]), builder.add_op(ops.MakeTuple(), *(obj._use_wire(None) for obj in objs)), ) case list(vs) if len(vs) > 0: - objs = [repack_guppy_object(v, builder) for v in vs] + objs = [guppy_object_from_py(v, builder, node) for v in vs] elem_ty = objs[0]._ty hugr_elem_ty = ht.Option(elem_ty.to_hugr()) wires = [ @@ -88,11 +93,13 @@ def repack_guppy_object(v: Any, builder: DfBase[P]) -> GuppyObject: array_type(elem_ty, len(vs)), builder.add_op(array_new(hugr_elem_ty, len(vs)), *wires), ) - case _: - raise InternalGuppyError( - "Can only repack values that were constructed via " - "`unpack_guppy_object`" - ) + case v: + ty = python_value_to_guppy_type(v, node, get_tracing_globals()) + if ty is None: + raise GuppyError(IllegalPyExpressionError(node, type(v))) + hugr_val = python_value_to_hugr(v, ty) + assert hugr_val is not None + return GuppyObject(ty, builder.load(hugr_val)) def update_packed_value(v: Any, obj: "GuppyObject", builder: DfBase[P]) -> None: @@ -127,38 +134,3 @@ def update_packed_value(v: Any, obj: "GuppyObject", builder: DfBase[P]) -> None: update_packed_value(v, GuppyObject(elem_ty, wire), builder) case _: pass - - -def guppy_object_from_py(v: Any, builder: DfBase[P], node: AstNode) -> GuppyObject: - match v: - case GuppyObject() as obj: - return obj - case GuppyDefinition() as defn: - return defn.to_guppy_object() - case None: - return GuppyObject(NoneType(), builder.add_op(ops.MakeTuple())) - case tuple(vs): - objs = [guppy_object_from_py(v, builder, node) for v in vs] - return GuppyObject( - TupleType([obj._ty for obj in objs]), - builder.add_op(ops.MakeTuple(), *(obj._use_wire(None) for obj in objs)), - ) - case list(vs) if len(vs) > 0: - objs = [guppy_object_from_py(v, builder, node) for v in vs] - elem_ty = objs[0]._ty - hugr_elem_ty = ht.Option(elem_ty.to_hugr()) - wires = [ - builder.add_op(ops.Tag(1, hugr_elem_ty), obj._use_wire(None)) - for obj in objs - ] - return GuppyObject( - array_type(elem_ty, len(vs)), - builder.add_op(array_new(hugr_elem_ty, len(vs)), *wires), - ) - case v: - ty = python_value_to_guppy_type(v, node, get_tracing_globals()) - if ty is None: - raise GuppyError(IllegalPyExpressionError(node, type(v))) - hugr_val = python_value_to_hugr(v, ty) - assert hugr_val is not None - return GuppyObject(ty, builder.load(hugr_val)) diff --git a/tests/integration/tracing/test_array.py b/tests/integration/tracing/test_array.py new file mode 100644 index 00000000..d9aa9428 --- /dev/null +++ b/tests/integration/tracing/test_array.py @@ -0,0 +1,76 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from guppylang.std.builtins import array, owned + + +def test_turns_into_list(validate, run_int_fn): + module = GuppyModule("test") + + @guppy.comptime(module) + def test(xs: array[int, 10]) -> int: + assert isinstance(xs, list) + assert len(xs) == 10 + + return sum(xs) + + @guppy(module) + def main() -> int: + return test(array(i for i in range(10))) + + compiled = module.compile() + validate(compiled) + run_int_fn(compiled, sum(range(10))) + + +def test_accepts_list(validate, run_int_fn): + module = GuppyModule("test") + + @guppy(module) + def foo(xs: array[int, 10] @owned) -> int: + s = 0 + for x in xs: + s += x + return s + + @guppy.comptime(module) + def main() -> int: + return foo(list(range(10))) + + compiled = module.compile() + validate(compiled) + run_int_fn(compiled, sum(range(10))) + + +def test_create(validate, run_int_fn): + module = GuppyModule("test") + + @guppy.comptime(module) + def main() -> int: + xs = array(*range(10)) + assert isinstance(xs, list) + assert xs == list(range(10)) + return xs[-1] + + compiled = module.compile() + validate(compiled) + run_int_fn(compiled, 9) + + +def test_mutate(validate, run_int_fn): + module = GuppyModule("test") + + @guppy.comptime(module) + def test(xs: array[int, 10]) -> None: + ys = xs + ys[0] = 100 + + @guppy(module) + def main() -> int: + xs = array(i for i in range(10)) + test(xs) + return xs[0] + + compiled = module.compile() + validate(compiled) + run_int_fn(compiled, 100) + From b2bd5a27361d5892c1b4abbb6dce11f15826c764 Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Fri, 7 Feb 2025 09:32:46 +0000 Subject: [PATCH 26/26] Support structs --- guppylang/tracing/function.py | 3 +- guppylang/tracing/object.py | 39 +++++- guppylang/tracing/unpacking.py | 28 +++- .../tracing_errors/struct_get_bad_field.py | 17 +++ .../struct_get_bad_field@python310.err | 6 + .../struct_get_bad_field@python313.err | 8 ++ .../tracing_errors/struct_set_bad_field.py | 18 +++ .../struct_set_bad_field@python310.err | 6 + .../struct_set_bad_field@python313.err | 8 ++ .../error/tracing_errors/struct_set_method.py | 22 +++ .../struct_set_method@python310.err | 6 + .../struct_set_method@python313.err | 8 ++ tests/integration/tracing/test_struct.py | 127 ++++++++++++++++++ 13 files changed, 290 insertions(+), 6 deletions(-) create mode 100644 tests/error/tracing_errors/struct_get_bad_field.py create mode 100644 tests/error/tracing_errors/struct_get_bad_field@python310.err create mode 100644 tests/error/tracing_errors/struct_get_bad_field@python313.err create mode 100644 tests/error/tracing_errors/struct_set_bad_field.py create mode 100644 tests/error/tracing_errors/struct_set_bad_field@python310.err create mode 100644 tests/error/tracing_errors/struct_set_bad_field@python313.err create mode 100644 tests/error/tracing_errors/struct_set_method.py create mode 100644 tests/error/tracing_errors/struct_set_method@python310.err create mode 100644 tests/error/tracing_errors/struct_set_method@python313.err create mode 100644 tests/integration/tracing/test_struct.py diff --git a/guppylang/tracing/function.py b/guppylang/tracing/function.py index ae2335af..60a3adf6 100644 --- a/guppylang/tracing/function.py +++ b/guppylang/tracing/function.py @@ -144,4 +144,5 @@ def trace_call(func: CompiledCallableDef, *args: Any) -> Any: inout_wire = state.dfg[var] update_packed_value(arg, GuppyObject(inp.ty, inout_wire), state.dfg.builder) - return GuppyObject(ret_ty, ret_wire) + ret_obj = GuppyObject(ret_ty, ret_wire) + return unpack_guppy_object(ret_obj, state.dfg.builder) diff --git a/guppylang/tracing/object.py b/guppylang/tracing/object.py index dc6b76d0..1c0d5380 100644 --- a/guppylang/tracing/object.py +++ b/guppylang/tracing/object.py @@ -2,7 +2,7 @@ import inspect import itertools from abc import ABC, abstractmethod -from collections.abc import Callable +from collections.abc import Callable, Sequence from contextlib import suppress from dataclasses import dataclass from typing import Any, NamedTuple, TypeAlias @@ -23,7 +23,7 @@ from guppylang.ipython_inspect import find_ipython_def, is_running_ipython from guppylang.tracing.state import get_tracing_globals, get_tracing_state from guppylang.tracing.util import capture_guppy_errors, get_calling_frame, hide_trace -from guppylang.tys.ty import FunctionType, TupleType, Type +from guppylang.tys.ty import FunctionType, StructType, TupleType, Type # Mapping from unary dunder method to display name of the operation unary_table = dict(expr_checker.unary_table.values()) @@ -374,6 +374,41 @@ def _use_wire(self, called_func: CompiledCallableDef | None) -> Wire: return self._wire +class GuppyStructObject: + """The runtime representation of Guppy struct objects during tracing.""" + + _ty: StructType + _field_values: dict[str, Any] + + def __init__(self, ty: StructType, field_values: Sequence[Any]) -> None: + field_values_dict = { + f.name: v for f, v in zip(ty.fields, field_values, strict=True) + } + object.__setattr__(self, "_ty", ty) + object.__setattr__(self, "_field_values", field_values_dict) + + @hide_trace + def __getattr__(self, key: str) -> Any: # type: ignore[misc] + # It could be an attribute + if key in self._field_values: + return self._field_values[key] + # Or a method + globals = get_tracing_globals() + func = globals.get_instance_func(self._ty, key) + if func is None: + err = f"Expression of type `{self._ty}` has no attribute `{key}`" + raise AttributeError(err) + return lambda *xs: GuppyDefinition(func)(self, *xs) + + @hide_trace + def __setattr__(self, key: str, value: Any) -> None: + if key in self._field_values: + self._field_values[key] = value + else: + err = f"Expression of type `{self._ty}` has no attribute `{key}`" + raise AttributeError(err) + + @dataclass(frozen=True) class GuppyDefinition: """A top-level Guppy definition. diff --git a/guppylang/tracing/unpacking.py b/guppylang/tracing/unpacking.py index df944940..14f79501 100644 --- a/guppylang/tracing/unpacking.py +++ b/guppylang/tracing/unpacking.py @@ -8,10 +8,10 @@ from guppylang.checker.errors.py_errors import IllegalPyExpressionError from guppylang.checker.expr_checker import python_value_to_guppy_type from guppylang.compiler.expr_compiler import python_value_to_hugr -from guppylang.error import GuppyError, InternalGuppyError +from guppylang.error import GuppyError from guppylang.std._internal.compiler.array import array_new, unpack_array from guppylang.std._internal.compiler.prelude import build_unwrap -from guppylang.tracing.object import GuppyDefinition, GuppyObject +from guppylang.tracing.object import GuppyDefinition, GuppyObject, GuppyStructObject from guppylang.tracing.state import get_tracing_globals, get_tracing_state from guppylang.tys.builtin import ( array_type, @@ -20,7 +20,7 @@ is_array_type, ) from guppylang.tys.const import ConstValue -from guppylang.tys.ty import NoneType, TupleType +from guppylang.tys.ty import NoneType, StructType, TupleType P = TypeVar("P", bound=ops.DfParentOp) @@ -40,6 +40,13 @@ def unpack_guppy_object(obj: GuppyObject, builder: DfBase[P]) -> Any: unpack_guppy_object(GuppyObject(ty, wire), builder) for ty, wire in zip(tys, unpack.outputs(), strict=False) ) + case StructType() as ty: + unpack = builder.add_op(ops.UnpackTuple(), obj._use_wire(None)) + field_values = [ + unpack_guppy_object(GuppyObject(field.ty, wire), builder) + for field, wire in zip(ty.fields, unpack.outputs(), strict=True) + ] + return GuppyStructObject(ty, field_values) case ty if is_array_type(ty): length = get_array_length(ty) if isinstance(length, ConstValue): @@ -81,6 +88,12 @@ def guppy_object_from_py(v: Any, builder: DfBase[P], node: AstNode) -> GuppyObje TupleType([obj._ty for obj in objs]), builder.add_op(ops.MakeTuple(), *(obj._use_wire(None) for obj in objs)), ) + case GuppyStructObject(_ty=struct_ty, _field_values=values): + wires = [ + guppy_object_from_py(values[f.name], builder, node)._use_wire(None) + for f in struct_ty.fields + ] + return GuppyObject(struct_ty, builder.add_op(ops.MakeTuple(), *wires)) case list(vs) if len(vs) > 0: objs = [guppy_object_from_py(v, builder, node) for v in vs] elem_ty = objs[0]._ty @@ -124,6 +137,15 @@ def update_packed_value(v: Any, obj: "GuppyObject", builder: DfBase[P]) -> None: wires = builder.add_op(ops.UnpackTuple(), obj._use_wire(None)).outputs() for v, ty, wire in zip(vs, obj._ty.element_types, wires, strict=True): update_packed_value(v, GuppyObject(ty, wire), builder) + case GuppyStructObject(_ty=ty, _field_values=values): + assert obj._ty == ty + wires = builder.add_op(ops.UnpackTuple(), obj._use_wire(None)).outputs() + for ( + field, + wire, + ) in zip(ty.fields, wires, strict=True): + v = values[field.name] + update_packed_value(v, GuppyObject(field.ty, wire), builder) case list(vs) if len(vs) > 0: assert is_array_type(obj._ty) elem_ty = get_element_type(obj._ty) diff --git a/tests/error/tracing_errors/struct_get_bad_field.py b/tests/error/tracing_errors/struct_get_bad_field.py new file mode 100644 index 00000000..d8123ed4 --- /dev/null +++ b/tests/error/tracing_errors/struct_get_bad_field.py @@ -0,0 +1,17 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +module = GuppyModule("test") + + +@guppy.struct(module) +class S: + x: int + + +@guppy.comptime(module) +def test(s: S) -> int: + return s.y + + +module.compile() diff --git a/tests/error/tracing_errors/struct_get_bad_field@python310.err b/tests/error/tracing_errors/struct_get_bad_field@python310.err new file mode 100644 index 00000000..576a42df --- /dev/null +++ b/tests/error/tracing_errors/struct_get_bad_field@python310.err @@ -0,0 +1,6 @@ +Traceback (most recent call last): + File "$FILE", line 17, in + module.compile() + File "$FILE", line 14, in test + return s.y +AttributeError: Expression of type `S` has no attribute `y` diff --git a/tests/error/tracing_errors/struct_get_bad_field@python313.err b/tests/error/tracing_errors/struct_get_bad_field@python313.err new file mode 100644 index 00000000..81124a34 --- /dev/null +++ b/tests/error/tracing_errors/struct_get_bad_field@python313.err @@ -0,0 +1,8 @@ +Traceback (most recent call last): + File "$FILE", line 17, in + module.compile() + ~~~~~~~~~~~~~~^^ + File "$FILE", line 14, in test + return s.y + ^^^ +AttributeError: Expression of type `S` has no attribute `y` diff --git a/tests/error/tracing_errors/struct_set_bad_field.py b/tests/error/tracing_errors/struct_set_bad_field.py new file mode 100644 index 00000000..8fbbeb67 --- /dev/null +++ b/tests/error/tracing_errors/struct_set_bad_field.py @@ -0,0 +1,18 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +module = GuppyModule("test") + + +@guppy.struct(module) +class S: + x: int + + +@guppy.comptime(module) +def test(x: int) -> None: + s = S(x) + s.y = x + + +module.compile() diff --git a/tests/error/tracing_errors/struct_set_bad_field@python310.err b/tests/error/tracing_errors/struct_set_bad_field@python310.err new file mode 100644 index 00000000..1a60d945 --- /dev/null +++ b/tests/error/tracing_errors/struct_set_bad_field@python310.err @@ -0,0 +1,6 @@ +Traceback (most recent call last): + File "$FILE", line 18, in + module.compile() + File "$FILE", line 15, in test + s.y = x +AttributeError: Expression of type `S` has no attribute `y` diff --git a/tests/error/tracing_errors/struct_set_bad_field@python313.err b/tests/error/tracing_errors/struct_set_bad_field@python313.err new file mode 100644 index 00000000..f078f14e --- /dev/null +++ b/tests/error/tracing_errors/struct_set_bad_field@python313.err @@ -0,0 +1,8 @@ +Traceback (most recent call last): + File "$FILE", line 18, in + module.compile() + ~~~~~~~~~~~~~~^^ + File "$FILE", line 15, in test + s.y = x + ^^^ +AttributeError: Expression of type `S` has no attribute `y` diff --git a/tests/error/tracing_errors/struct_set_method.py b/tests/error/tracing_errors/struct_set_method.py new file mode 100644 index 00000000..f77168ad --- /dev/null +++ b/tests/error/tracing_errors/struct_set_method.py @@ -0,0 +1,22 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + +module = GuppyModule("test") + + +@guppy.struct(module) +class S: + x: int + + @guppy(module) + def foo(self: "S") -> None: + pass + + +@guppy.comptime(module) +def test(x: int) -> None: + s = S(x) + s.foo = 0 + + +module.compile() diff --git a/tests/error/tracing_errors/struct_set_method@python310.err b/tests/error/tracing_errors/struct_set_method@python310.err new file mode 100644 index 00000000..6bb624b8 --- /dev/null +++ b/tests/error/tracing_errors/struct_set_method@python310.err @@ -0,0 +1,6 @@ +Traceback (most recent call last): + File "$FILE", line 22, in + module.compile() + File "$FILE", line 19, in test + s.foo = 0 +AttributeError: Expression of type `S` has no attribute `foo` diff --git a/tests/error/tracing_errors/struct_set_method@python313.err b/tests/error/tracing_errors/struct_set_method@python313.err new file mode 100644 index 00000000..cafe19ad --- /dev/null +++ b/tests/error/tracing_errors/struct_set_method@python313.err @@ -0,0 +1,8 @@ +Traceback (most recent call last): + File "$FILE", line 22, in + module.compile() + ~~~~~~~~~~~~~~^^ + File "$FILE", line 19, in test + s.foo = 0 + ^^^^^ +AttributeError: Expression of type `S` has no attribute `foo` diff --git a/tests/integration/tracing/test_struct.py b/tests/integration/tracing/test_struct.py new file mode 100644 index 00000000..91f07eea --- /dev/null +++ b/tests/integration/tracing/test_struct.py @@ -0,0 +1,127 @@ +from typing import Generic + +from guppylang.decorator import guppy +from guppylang.module import GuppyModule + + +def test_create(validate, run_int_fn): + module = GuppyModule("module") + + @guppy.struct(module) + class S: + x: int + y: int + + @guppy.comptime(module) + def main(x: int) -> int: + s = S(x, 2) + return s.x + s.y + + compiled = module.compile() + validate(compiled) + run_int_fn(compiled, 42, args=[40]) + + +def test_argument(validate, run_int_fn): + module = GuppyModule("module") + + @guppy.struct(module) + class S: + x: int + y: int + + @guppy.comptime(module) + def foo(s: S) -> int: + return s.x + s.y + + @guppy(module) + def main() -> int: + return foo(S(40, 2)) + + compiled = module.compile() + validate(compiled) + run_int_fn(compiled, 42) + + +def test_write(validate, run_int_fn): + module = GuppyModule("module") + + @guppy.struct(module) + class S: + x: int + y: int + + @guppy.comptime(module) + def main(y: int) -> int: + s = S(40, y) + t = s + t.y += 1 + return s.x + s.y + + compiled = module.compile() + validate(compiled) + run_int_fn(compiled, 42, args=[1]) + + +def test_method(validate, run_int_fn): + module = GuppyModule("module") + + @guppy.struct(module) + class S: + x: int + y: int + + @guppy.comptime(module) + def get_x(self: "S") -> int: + return self.x + + @guppy(module) + def get_y(self: "S") -> int: + return self.y + + @guppy.comptime(module) + def main(x: int, y: int) -> int: + s = S(x, y) + return s.get_x() + s.get_y() + + compiled = module.compile() + validate(compiled) + run_int_fn(compiled, 42, args=[40, 2]) + + +def test_generic_nested(validate, run_float_fn_approx): + module = GuppyModule("module") + S = guppy.type_var("S", module=module) + T = guppy.type_var("T", module=module) + + @guppy.struct(module) + class StructA(Generic[T]): + x: tuple[int, T] + + @guppy.struct(module) + class StructB(Generic[S, T]): + x: S + y: StructA[T] + + @guppy.comptime(module) + def foo(a: StructA[StructA[float]], b: StructB[bool, int]) -> float: + flat_a = a.x[0] + a.x[1].x[0] + a.x[1].x[1] + flat_b = b.y.x[0] + b.y.x[1] + return flat_a + flat_b + + @guppy.comptime(module) + def bar( + x1: int, x2: int, x3: float,x4: int, x5: int + ) -> tuple[StructA[StructA[float]], StructB[bool, int]]: + a = StructA((x1, StructA((x2, x3)))) + b = StructB(True, StructA((x4, x5))) + return a, b + + @guppy(module) + def main() -> float: + a, b = bar(1, 10, 100, 1000, 10000) + return foo(a, b) + + compiled = module.compile() + validate(compiled) + run_float_fn_approx(compiled, 11111)