Skip to content

Commit

Permalink
feat: Support nat/intbool cast operations (#459)
Browse files Browse the repository at this point in the history
* Closes #281
* Closes #409
  • Loading branch information
aborgna-q authored Sep 9, 2024
1 parent b02e0d0 commit 3b778c3
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 29 deletions.
82 changes: 61 additions & 21 deletions guppylang/prelude/_internal/compiler/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from collections.abc import Sequence

import hugr
import hugr.std.int
from hugr import Wire, ops
from hugr import tys as ht
from hugr.std.float import FLOAT_T
Expand Down Expand Up @@ -45,45 +45,53 @@ def ine(width: int) -> ops.ExtOp:
return _instantiate_int_op("ine", width, [int_t(width), int_t(width)], [ht.Bool])


def iwiden_u(from_width: int, to_width: int) -> ops.ExtOp:
"""Returns an unsigned `std.arithmetic.int.widen_u` operation."""
return _instantiate_int_op(
"iwiden_u", [from_width, to_width], [int_t(from_width)], [int_t(to_width)]
)


def iwiden_s(from_width: int, to_width: int) -> ops.ExtOp:
"""Returns a signed `std.arithmetic.int.widen_s` operation."""
return _instantiate_int_op(
"iwiden_s", [from_width, to_width], [int_t(from_width)], [int_t(to_width)]
)


# ------------------------------------------------------
# --------- std.arithmetic.conversions ops -------------
# ------------------------------------------------------


def _instantiate_convert_op(
name: str,
inp: list[ht.Type],
out: list[ht.Type],
args: list[ht.TypeArg] | None = None,
) -> ops.ExtOp:
op_def = hugr.std.int.CONVERSIONS_EXTENSION.get_op(name)
return ops.ExtOp(op_def, ht.FunctionType(inp, out), args or [])


def convert_ifromusize() -> ops.ExtOp:
"""Returns a `std.arithmetic.conversions.ifromusize` operation."""
op_def = hugr.std.int.CONVERSIONS_EXTENSION.get_op("ifromusize")
return ops.ExtOp(
op_def,
ht.FunctionType([ht.USize()], [INT_T]),
)
return _instantiate_convert_op("ifromusize", [ht.USize()], [INT_T])


def convert_itousize() -> ops.ExtOp:
"""Returns a `std.arithmetic.conversions.itousize` operation."""
op_def = hugr.std.int.CONVERSIONS_EXTENSION.get_op("itousize")
return ops.ExtOp(
op_def,
ht.FunctionType([INT_T], [ht.USize()]),
)
return _instantiate_convert_op("itousize", [INT_T], [ht.USize()])


def convert_ifrombool() -> ops.ExtOp:
"""Returns a `std.arithmetic.conversions.ifrombool` operation."""
op_def = hugr.std.int.CONVERSIONS_EXTENSION.get_op("ifrombool")
return ops.ExtOp(
op_def,
ht.FunctionType([ht.Bool], [int_t(1)]),
)
return _instantiate_convert_op("ifrombool", [ht.Bool], [int_t(0)])


def convert_itobool() -> ops.ExtOp:
"""Returns a `std.arithmetic.conversions.itobool` operation."""
op_def = hugr.std.int.CONVERSIONS_EXTENSION.get_op("itobool")
return ops.ExtOp(
op_def,
ht.FunctionType([int_t(1)], [ht.Bool]),
)
return _instantiate_convert_op("itobool", [int_t(0)], [ht.Bool])


# ------------------------------------------------------
Expand Down Expand Up @@ -264,3 +272,35 @@ def compile(self, args: list[Wire]) -> list[Wire]:
ht.FunctionType([FLOAT_T] * len(args), [FLOAT_T]),
)
return list(self.builder.add(ops.MakeTuple()(div, mod)))


class IToBoolCompiler(CustomCallCompiler):
"""Compiler for the `Int` and `Nat` `.__bool__` methods.
Note that the native `std.arithmetic.conversions.itobool` hugr op
only supports 1 bit integers as input.
"""

def compile(self, args: list[Wire]) -> list[Wire]:
# Emit a comparison against zero
[num] = args
zero = self.builder.load(hugr.std.int.IntVal(0, width=6))
out = self.builder.add_op(ine(NumericType.INT_WIDTH), num, zero)
return [out]


class IFromBoolCompiler(CustomCallCompiler):
"""Compiler for the `Bool` `.__int__` and `.__nat__` methods.
Note that the native `std.arithmetic.conversions.ifrombool` hugr op
only produces 1 bit integers as output, so we have to widen the result.
"""

def compile(self, args: list[Wire]) -> list[Wire]:
# Emit an `ifrombool` followed by a widening cast
# We use `widen_u` independently of the target type, since we want the bit `1`
# to be expanded to `0x00000001` even for `nat` types
[boolean] = args
bit = self.builder.add_op(convert_ifrombool(), boolean)
num = self.builder.add_op(iwiden_u(0, NumericType.INT_WIDTH), bit)
return [num]
14 changes: 6 additions & 8 deletions guppylang/prelude/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
FloatDivmodCompiler,
FloatFloordivCompiler,
FloatModCompiler,
IFromBoolCompiler,
IntTruedivCompiler,
IToBoolCompiler,
NatTruedivCompiler,
)
from guppylang.prelude._internal.compiler.array import (
Expand Down Expand Up @@ -101,10 +103,10 @@ def __bool__(self: bool) -> bool: ...
@guppy.hugr_op(builtins, logic_op("Eq"))
def __eq__(self: bool, other: bool) -> bool: ...

@guppy.hugr_op(builtins, unsupported_op("ifrombool")) # TODO: Widen to INT_WIDTH
@guppy.custom(builtins, IFromBoolCompiler())
def __int__(self: bool) -> int: ...

@guppy.hugr_op(builtins, unsupported_op("ifrombool")) # TODO: Widen to INT_WIDTH
@guppy.custom(builtins, IFromBoolCompiler())
def __nat__(self: bool) -> nat: ...

@guppy.custom(builtins, checker=DunderChecker("__bool__"), higher_order_value=False)
Expand All @@ -128,9 +130,7 @@ def __add__(self: nat, other: nat) -> nat: ...
@guppy.hugr_op(builtins, int_op("iand"))
def __and__(self: nat, other: nat) -> nat: ...

@guppy.hugr_op(
builtins, unsupported_op("itobool")
) # TODO: itobool only supports single bit ints
@guppy.custom(builtins, IToBoolCompiler())
def __bool__(self: nat) -> bool: ...

@guppy.custom(builtins, NoopCompiler())
Expand Down Expand Up @@ -273,9 +273,7 @@ def __add__(self: int, other: int) -> int: ...
@guppy.hugr_op(builtins, int_op("iand"))
def __and__(self: int, other: int) -> int: ...

@guppy.hugr_op(
builtins, unsupported_op("itobool")
) # TODO: itobool only supports single bit ints
@guppy.custom(builtins, IToBoolCompiler())
def __bool__(self: int) -> bool: ...

@guppy.custom(builtins, NoopCompiler())
Expand Down

0 comments on commit 3b778c3

Please sign in to comment.