Skip to content

Commit

Permalink
feat: Load pytket circuit as a function definition (#672)
Browse files Browse the repository at this point in the history
Closes #657, however the following limitations still exist which are
covered by further issues:
- #669: ~~Classical bits and measurements aren't currently supported as
this requires some checks related to ownership and return values.~~ They
are supported but without taking ownership.
- #670: It would be nice to not have to provide an explicit function
stub for the circuit and also have `load_pytket` method which infers it.
- #671: It would also be nice to be able to provide arrays of qubits in
the function stub.
  • Loading branch information
tatiana-s authored Dec 6, 2024
1 parent 50f71b9 commit b21b7e1
Show file tree
Hide file tree
Showing 14 changed files with 587 additions and 31 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pull-request.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ jobs:
run: uv sync --extra pytket

- name: Rerun `py(...)` expression tests and pytket lowering with tket2 installed
run: uv run pytest tests/integration/test_py.py tests/error/test_py_errors.py tests/integration/test_tket.py
run: uv run pytest tests/integration/test_py.py tests/error/test_py_errors.py tests/integration/test_tket.py tests/integration/test_pytket_circuits.py

test-coverage:
name: Check Python (3.13) with coverage
Expand Down
1 change: 1 addition & 0 deletions .typos.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ ine = "ine"
inot = "inot"
inout = "inout"
inouts = "inouts"
anc = "anc"
9 changes: 9 additions & 0 deletions guppylang/checker/errors/py_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,12 @@ class Tket2NotInstalled(Error):
@dataclass(frozen=True)
class InstallInstruction(Help):
message: ClassVar[str] = "Install tket2: `pip install tket2`"


@dataclass(frozen=True)
class PytketSignatureMismatch(Error):
title: ClassVar[str] = "Signature mismatch"
span_label: ClassVar[str] = (
"Signature `{name}` doesn't match provided pytket circuit"
)
name: str
24 changes: 24 additions & 0 deletions guppylang/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
RawFunctionDef,
)
from guppylang.definition.parameter import ConstVarDef, TypeVarDef
from guppylang.definition.pytket_circuits import RawPytketDef
from guppylang.definition.struct import RawStructDef
from guppylang.definition.ty import OpaqueTypeDef, TypeDef
from guppylang.error import MissingModuleError, pretty_errors
Expand Down Expand Up @@ -57,6 +58,7 @@
FuncDefDecorator = Decorator[PyFunc, RawFunctionDef]
FuncDeclDecorator = Decorator[PyFunc, RawFunctionDecl]
CustomFuncDecorator = Decorator[PyFunc, RawCustomFunctionDef]
PytketDecorator = Decorator[PyFunc, RawPytketDef]
ClassDecorator = Decorator[PyClass, PyClass]
OpaqueTypeDecorator = Decorator[PyClass, OpaqueTypeDef]
StructDecorator = Decorator[PyClass, RawStructDef]
Expand Down Expand Up @@ -468,6 +470,28 @@ def registered_modules(self) -> KeysView[ModuleIdentifier]:
"""Returns a list of all currently registered modules for local contexts."""
return self._modules.keys()

@pretty_errors
def pytket(
self, input_circuit: Any, module: GuppyModule | None = None
) -> PytketDecorator:
"""Adds a pytket circuit function definition with explicit signature."""
err_msg = "Only pytket circuits can be passed to guppy.pytket"
try:
import pytket

if not isinstance(input_circuit, pytket.circuit.Circuit):
raise TypeError(err_msg) from None

except ImportError:
raise TypeError(err_msg) from None

mod = module or self.get_module()

def func(f: PyFunc) -> RawPytketDef:
return mod.register_pytket_func(f, input_circuit)

return func


class _GuppyDummy:
"""A dummy class with the same interface as `@guppy` that is used during sphinx
Expand Down
27 changes: 11 additions & 16 deletions guppylang/definition/declaration.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import ClassVar

from hugr import Node, Wire
from hugr import tys as ht
from hugr.build import function as hf
from hugr.build.dfg import DefinitionBuilder, OpVar

Expand All @@ -13,14 +12,19 @@
from guppylang.checker.func_checker import check_signature
from guppylang.compiler.core import CompiledGlobals, DFContainer
from guppylang.definition.common import CompilableDef, ParsableDef
from guppylang.definition.function import PyFunc, parse_py_func
from guppylang.definition.function import (
PyFunc,
compile_call,
load_with_args,
parse_py_func,
)
from guppylang.definition.value import CallableDef, CallReturnWires, CompiledCallableDef
from guppylang.diagnostic import Error
from guppylang.error import GuppyError
from guppylang.nodes import GlobalCall
from guppylang.span import SourceMap
from guppylang.tys.subst import Inst, Subst
from guppylang.tys.ty import Type, type_to_row
from guppylang.tys.ty import Type


@dataclass(frozen=True)
Expand Down Expand Up @@ -121,9 +125,8 @@ def load_with_args(
node: AstNode,
) -> Wire:
"""Loads the function as a value into a local Hugr dataflow graph."""
func_ty: ht.FunctionType = self.ty.instantiate(type_args).to_hugr()
type_args: list[ht.TypeArg] = [arg.to_hugr() for arg in type_args]
return dfg.builder.load_function(self.declaration, func_ty, type_args)
# Use implementation from function definition.
return load_with_args(type_args, dfg, self.ty, self.declaration)

def compile_call(
self,
Expand All @@ -134,13 +137,5 @@ def compile_call(
node: AstNode,
) -> CallReturnWires:
"""Compiles a call to the function."""
func_ty: ht.FunctionType = self.ty.instantiate(type_args).to_hugr()
type_args: list[ht.TypeArg] = [arg.to_hugr() for arg in type_args]
num_returns = len(type_to_row(self.ty.output))
call = dfg.builder.call(
self.declaration, *args, instantiation=func_ty, type_args=type_args
)
return CallReturnWires(
regular_returns=list(call[:num_returns]),
inout_returns=list(call[num_returns:]),
)
# Use implementation from function definition.
return compile_call(args, type_args, dfg, self.ty, self.declaration)
46 changes: 33 additions & 13 deletions guppylang/definition/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import hugr.tys as ht
from hugr import Wire
from hugr.build.dfg import DefinitionBuilder, OpVar
from hugr.hugr.node_port import ToNode
from hugr.package import FuncDefnPointer

from guppylang.ast_util import AstNode, annotate_location, with_loc, with_type
Expand Down Expand Up @@ -199,9 +200,7 @@ def load_with_args(
node: AstNode,
) -> Wire:
"""Loads the function as a value into a local Hugr dataflow graph."""
func_ty: ht.FunctionType = self.ty.instantiate(type_args).to_hugr()
type_args: list[ht.TypeArg] = [arg.to_hugr() for arg in type_args]
return dfg.builder.load_function(self.func_def, func_ty, type_args)
return load_with_args(type_args, dfg, self.ty, self.func_def)

def compile_call(
self,
Expand All @@ -212,22 +211,43 @@ def compile_call(
node: AstNode,
) -> CallReturnWires:
"""Compiles a call to the function."""
func_ty: ht.FunctionType = self.ty.instantiate(type_args).to_hugr()
type_args: list[ht.TypeArg] = [arg.to_hugr() for arg in type_args]
num_returns = len(type_to_row(self.ty.output))
call = dfg.builder.call(
self.func_def, *args, instantiation=func_ty, type_args=type_args
)
return CallReturnWires(
regular_returns=list(call[:num_returns]),
inout_returns=list(call[num_returns:]),
)
return compile_call(args, type_args, dfg, self.ty, self.func_def)

def compile_inner(self, globals: CompiledGlobals) -> None:
"""Compiles the body of the function."""
compile_global_func_def(self, self.func_def, globals)


def load_with_args(
type_args: Inst,
dfg: DFContainer,
ty: FunctionType,
func: ToNode,
) -> Wire:
"""Loads the function as a value into a local Hugr dataflow graph."""
func_ty: ht.FunctionType = ty.instantiate(type_args).to_hugr()
type_args: list[ht.TypeArg] = [arg.to_hugr() for arg in type_args]
return dfg.builder.load_function(func, func_ty, type_args)


def compile_call(
args: list[Wire],
type_args: Inst,
dfg: DFContainer,
ty: FunctionType,
func: ToNode,
) -> CallReturnWires:
"""Compiles a call to the function."""
func_ty: ht.FunctionType = ty.instantiate(type_args).to_hugr()
type_args: list[ht.TypeArg] = [arg.to_hugr() for arg in type_args]
num_returns = len(type_to_row(ty.output))
call = dfg.builder.call(func, *args, instantiation=func_ty, type_args=type_args)
return CallReturnWires(
regular_returns=list(call[:num_returns]),
inout_returns=list(call[num_returns:]),
)


def parse_py_func(f: PyFunc, sources: SourceMap) -> tuple[ast.FunctionDef, str | None]:
source_lines, line_offset = inspect.getsourcelines(f)
source = "".join(source_lines) # Lines already have trailing \n's
Expand Down
Loading

0 comments on commit b21b7e1

Please sign in to comment.