diff --git a/pybindings/quizx/__init__.py b/pybindings/quizx/__init__.py index 05786f5..eb38560 100644 --- a/pybindings/quizx/__init__.py +++ b/pybindings/quizx/__init__.py @@ -1,13 +1,7 @@ -from . import _quizx, simplify +from . import simplify from .graph import VecGraph -from .circuit import Circuit +from .circuit import Circuit, extract_circuit from .decompose import Decomposer from ._quizx import Scalar -__all__ = ["VecGraph", "Circuit", "simplify", "Decomposer", "Scalar"] - - -def extract_circuit(g): - c = Circuit() - c._c = _quizx.extract_circuit(g._g) - return c +__all__ = ["VecGraph", "Circuit", "simplify", "Decomposer", "Scalar", "extract_circuit"] diff --git a/pybindings/quizx/_quizx.pyi b/pybindings/quizx/_quizx.pyi index d2f0330..ef954ef 100644 --- a/pybindings/quizx/_quizx.pyi +++ b/pybindings/quizx/_quizx.pyi @@ -65,6 +65,9 @@ class VecGraph: def outputs(self) -> list[int]: ... def num_outputs(self) -> int: ... def set_outputs(self, outputs: list[int]) -> None: ... + def adjoint(self) -> None: ... + def plug(self, other: "VecGraph") -> None: ... + def clone(self) -> "VecGraph": ... @final class Circuit: @@ -95,10 +98,13 @@ class Decomposer: def empty() -> Decomposer: ... def __init__(self, g: VecGraph) -> None: ... def graphs(self) -> list[VecGraph]: ... + def done(self) -> list[VecGraph]: ... + def save(self, b: bool) -> None: ... def apply_optimizations(self, b: bool) -> None: ... def max_terms(self) -> int: ... def decomp_top(self) -> None: ... def decomp_all(self) -> None: ... + def decomp_parallel(self, depth: int) -> None: ... def decomp_until_depth(self, depth: int) -> None: ... def use_cats(self, b: bool) -> None: ... def get_nterms(self) -> int: ... diff --git a/pybindings/quizx/circuit.py b/pybindings/quizx/circuit.py index d532740..9f21c19 100644 --- a/pybindings/quizx/circuit.py +++ b/pybindings/quizx/circuit.py @@ -2,6 +2,12 @@ from .graph import VecGraph +def extract_circuit(g: VecGraph) -> "Circuit": + c = Circuit() + c._c = _quizx.extract_circuit(g.get_raw_graph()) + return c + + class Circuit: def __init__(self): self._c = None diff --git a/pybindings/quizx/decompose.py b/pybindings/quizx/decompose.py index d4642ec..7c96ea7 100644 --- a/pybindings/quizx/decompose.py +++ b/pybindings/quizx/decompose.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from . import _quizx from .graph import VecGraph @@ -15,8 +15,14 @@ def __init__(self, graph: Optional[VecGraph] = None): else: self._d = _quizx.Decomposer(graph.get_raw_graph()) - def graphs(self) -> List[VecGraph]: - return [VecGraph(g) for g in self._d.graphs()] + def graphs(self) -> list[VecGraph]: + return [VecGraph.from_raw_graph(g) for g in self._d.graphs()] + + def done(self) -> list[VecGraph]: + return [VecGraph.from_raw_graph(g) for g in self._d.done()] + + def save(self, b: bool): + self._d.save(b) def apply_optimizations(self, b: bool): self._d.apply_optimizations(b) @@ -30,6 +36,9 @@ def decomp_top(self): def decomp_all(self): self._d.decomp_all() + def decomp_parallel(self, depth: int = 4): + self._d.decomp_parallel(depth) + def decomp_until_depth(self, depth: int): self._d.decomp_until_depth(depth) diff --git a/pybindings/quizx/graph.py b/pybindings/quizx/graph.py index 253d1f5..1e4bd2d 100644 --- a/pybindings/quizx/graph.py +++ b/pybindings/quizx/graph.py @@ -16,7 +16,7 @@ from .scalar import from_pyzx_scalar, to_pyzx_scalar from fractions import Fraction -from typing import Tuple, Dict, Any, Optional +from typing import Tuple, Dict, Any from pyzx.graph.base import BaseGraph # type: ignore from pyzx.utils import VertexType, EdgeType # type: ignore from pyzx.graph.scalar import Scalar @@ -32,12 +32,9 @@ class VecGraph(BaseGraph[int, Tuple[int, int]]): # The documentation of what these methods do # can be found in base.BaseGraph - def __init__(self, rust_graph: Optional[_quizx.VecGraph] = None): - if rust_graph: - self._g = rust_graph - else: - self._g = _quizx.VecGraph() - BaseGraph.__init__(self) + def __init__(self) -> None: + self._g = _quizx.VecGraph() + super().__init__() self._vdata: Dict[int, Any] = dict() def get_raw_graph(self) -> _quizx.VecGraph: @@ -172,7 +169,7 @@ def vertices_in_range(self, start, end): for v in self.vertices(): if not start < v < end: continue - if all(start < v2 < end for v2 in self.graph[v]): + if all(start < v2 < end for v2 in self.neighbors(v)): yield v def edges(self): @@ -328,3 +325,15 @@ def scalar(self, s: Scalar): def is_ground(self, vertex): return False + + def adjoint(self): + self._g.adjoint() + + def plug(self, other: "VecGraph"): + if other._g is self._g: + self._g.plug(other._g.clone()) + else: + self._g.plug(other._g) + + def clone(self) -> "VecGraph": + return VecGraph.from_raw_graph(self._g.clone()) diff --git a/pybindings/src/lib.rs b/pybindings/src/lib.rs index 3380157..835084f 100644 --- a/pybindings/src/lib.rs +++ b/pybindings/src/lib.rs @@ -313,6 +313,18 @@ impl VecGraph { fn set_scalar(&mut self, scalar: Scalar) { *self.g.scalar_mut() = scalar.into(); } + + fn adjoint(&mut self) { + self.g.adjoint() + } + + fn plug(&mut self, other: &VecGraph) { + self.g.plug(&other.g); + } + + fn clone(&self) -> VecGraph { + VecGraph { g: self.g.clone() } + } } #[pyclass] @@ -354,6 +366,18 @@ impl Decomposer { Ok(gs) } + fn done(&self) -> PyResult> { + let mut gs = vec![]; + for g in &self.d.done { + gs.push(VecGraph { g: g.clone() }); + } + Ok(gs) + } + + fn save(&mut self, b: bool) { + self.d.save(b); + } + fn apply_optimizations(&mut self, b: bool) { if b { self.d.with_simp(quizx::decompose::SimpFunc::FullSimp); @@ -374,6 +398,9 @@ impl Decomposer { fn decomp_until_depth(&mut self, depth: usize) { self.d.decomp_until_depth(depth); } + fn decomp_parallel(&mut self, depth: usize) { + self.d = self.d.clone().decomp_parallel(depth); + } fn use_cats(&mut self, b: bool) { self.d.use_cats(b); } diff --git a/quizx/src/json/scalar.rs b/quizx/src/json/scalar.rs index 8a31ac8..762aa25 100644 --- a/quizx/src/json/scalar.rs +++ b/quizx/src/json/scalar.rs @@ -47,7 +47,7 @@ impl JsonScalar { // In the Clifford+T case where we have Scalar4, we can extract factors of sqrt(2) directly from the // coefficients. Since the coefficients are reduced, sqrt(2) is represented as - // [1, 0, +-1, 0], [0, 1, +-1, 0], where the +- lead to phase contributions already extracted in `phase` + // [1, 0, +-1, 0], [0, 1, 0, +-1], where the +- lead to phase contributions already extracted in `phase` let (power_sqrt2, floatfactor) = match coeffs.iter_coeffs().collect::>().as_slice() { [a, 0, b, 0] | [0, a, 0, b]