Skip to content

Commit 3b778c3

Browse files
authored
feat: Support nat/intbool cast operations (#459)
* Closes #281 * Closes #409
1 parent b02e0d0 commit 3b778c3

File tree

2 files changed

+67
-29
lines changed

2 files changed

+67
-29
lines changed

guppylang/prelude/_internal/compiler/arithmetic.py

Lines changed: 61 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from collections.abc import Sequence
44

5-
import hugr
5+
import hugr.std.int
66
from hugr import Wire, ops
77
from hugr import tys as ht
88
from hugr.std.float import FLOAT_T
@@ -45,45 +45,53 @@ def ine(width: int) -> ops.ExtOp:
4545
return _instantiate_int_op("ine", width, [int_t(width), int_t(width)], [ht.Bool])
4646

4747

48+
def iwiden_u(from_width: int, to_width: int) -> ops.ExtOp:
49+
"""Returns an unsigned `std.arithmetic.int.widen_u` operation."""
50+
return _instantiate_int_op(
51+
"iwiden_u", [from_width, to_width], [int_t(from_width)], [int_t(to_width)]
52+
)
53+
54+
55+
def iwiden_s(from_width: int, to_width: int) -> ops.ExtOp:
56+
"""Returns a signed `std.arithmetic.int.widen_s` operation."""
57+
return _instantiate_int_op(
58+
"iwiden_s", [from_width, to_width], [int_t(from_width)], [int_t(to_width)]
59+
)
60+
61+
4862
# ------------------------------------------------------
4963
# --------- std.arithmetic.conversions ops -------------
5064
# ------------------------------------------------------
5165

5266

67+
def _instantiate_convert_op(
68+
name: str,
69+
inp: list[ht.Type],
70+
out: list[ht.Type],
71+
args: list[ht.TypeArg] | None = None,
72+
) -> ops.ExtOp:
73+
op_def = hugr.std.int.CONVERSIONS_EXTENSION.get_op(name)
74+
return ops.ExtOp(op_def, ht.FunctionType(inp, out), args or [])
75+
76+
5377
def convert_ifromusize() -> ops.ExtOp:
5478
"""Returns a `std.arithmetic.conversions.ifromusize` operation."""
55-
op_def = hugr.std.int.CONVERSIONS_EXTENSION.get_op("ifromusize")
56-
return ops.ExtOp(
57-
op_def,
58-
ht.FunctionType([ht.USize()], [INT_T]),
59-
)
79+
return _instantiate_convert_op("ifromusize", [ht.USize()], [INT_T])
6080

6181

6282
def convert_itousize() -> ops.ExtOp:
6383
"""Returns a `std.arithmetic.conversions.itousize` operation."""
64-
op_def = hugr.std.int.CONVERSIONS_EXTENSION.get_op("itousize")
65-
return ops.ExtOp(
66-
op_def,
67-
ht.FunctionType([INT_T], [ht.USize()]),
68-
)
84+
return _instantiate_convert_op("itousize", [INT_T], [ht.USize()])
6985

7086

7187
def convert_ifrombool() -> ops.ExtOp:
7288
"""Returns a `std.arithmetic.conversions.ifrombool` operation."""
73-
op_def = hugr.std.int.CONVERSIONS_EXTENSION.get_op("ifrombool")
74-
return ops.ExtOp(
75-
op_def,
76-
ht.FunctionType([ht.Bool], [int_t(1)]),
77-
)
89+
return _instantiate_convert_op("ifrombool", [ht.Bool], [int_t(0)])
7890

7991

8092
def convert_itobool() -> ops.ExtOp:
8193
"""Returns a `std.arithmetic.conversions.itobool` operation."""
82-
op_def = hugr.std.int.CONVERSIONS_EXTENSION.get_op("itobool")
83-
return ops.ExtOp(
84-
op_def,
85-
ht.FunctionType([int_t(1)], [ht.Bool]),
86-
)
94+
return _instantiate_convert_op("itobool", [int_t(0)], [ht.Bool])
8795

8896

8997
# ------------------------------------------------------
@@ -264,3 +272,35 @@ def compile(self, args: list[Wire]) -> list[Wire]:
264272
ht.FunctionType([FLOAT_T] * len(args), [FLOAT_T]),
265273
)
266274
return list(self.builder.add(ops.MakeTuple()(div, mod)))
275+
276+
277+
class IToBoolCompiler(CustomCallCompiler):
278+
"""Compiler for the `Int` and `Nat` `.__bool__` methods.
279+
280+
Note that the native `std.arithmetic.conversions.itobool` hugr op
281+
only supports 1 bit integers as input.
282+
"""
283+
284+
def compile(self, args: list[Wire]) -> list[Wire]:
285+
# Emit a comparison against zero
286+
[num] = args
287+
zero = self.builder.load(hugr.std.int.IntVal(0, width=6))
288+
out = self.builder.add_op(ine(NumericType.INT_WIDTH), num, zero)
289+
return [out]
290+
291+
292+
class IFromBoolCompiler(CustomCallCompiler):
293+
"""Compiler for the `Bool` `.__int__` and `.__nat__` methods.
294+
295+
Note that the native `std.arithmetic.conversions.ifrombool` hugr op
296+
only produces 1 bit integers as output, so we have to widen the result.
297+
"""
298+
299+
def compile(self, args: list[Wire]) -> list[Wire]:
300+
# Emit an `ifrombool` followed by a widening cast
301+
# We use `widen_u` independently of the target type, since we want the bit `1`
302+
# to be expanded to `0x00000001` even for `nat` types
303+
[boolean] = args
304+
bit = self.builder.add_op(convert_ifrombool(), boolean)
305+
num = self.builder.add_op(iwiden_u(0, NumericType.INT_WIDTH), bit)
306+
return [num]

guppylang/prelude/builtins.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
FloatDivmodCompiler,
2727
FloatFloordivCompiler,
2828
FloatModCompiler,
29+
IFromBoolCompiler,
2930
IntTruedivCompiler,
31+
IToBoolCompiler,
3032
NatTruedivCompiler,
3133
)
3234
from guppylang.prelude._internal.compiler.array import (
@@ -101,10 +103,10 @@ def __bool__(self: bool) -> bool: ...
101103
@guppy.hugr_op(builtins, logic_op("Eq"))
102104
def __eq__(self: bool, other: bool) -> bool: ...
103105

104-
@guppy.hugr_op(builtins, unsupported_op("ifrombool")) # TODO: Widen to INT_WIDTH
106+
@guppy.custom(builtins, IFromBoolCompiler())
105107
def __int__(self: bool) -> int: ...
106108

107-
@guppy.hugr_op(builtins, unsupported_op("ifrombool")) # TODO: Widen to INT_WIDTH
109+
@guppy.custom(builtins, IFromBoolCompiler())
108110
def __nat__(self: bool) -> nat: ...
109111

110112
@guppy.custom(builtins, checker=DunderChecker("__bool__"), higher_order_value=False)
@@ -128,9 +130,7 @@ def __add__(self: nat, other: nat) -> nat: ...
128130
@guppy.hugr_op(builtins, int_op("iand"))
129131
def __and__(self: nat, other: nat) -> nat: ...
130132

131-
@guppy.hugr_op(
132-
builtins, unsupported_op("itobool")
133-
) # TODO: itobool only supports single bit ints
133+
@guppy.custom(builtins, IToBoolCompiler())
134134
def __bool__(self: nat) -> bool: ...
135135

136136
@guppy.custom(builtins, NoopCompiler())
@@ -273,9 +273,7 @@ def __add__(self: int, other: int) -> int: ...
273273
@guppy.hugr_op(builtins, int_op("iand"))
274274
def __and__(self: int, other: int) -> int: ...
275275

276-
@guppy.hugr_op(
277-
builtins, unsupported_op("itobool")
278-
) # TODO: itobool only supports single bit ints
276+
@guppy.custom(builtins, IToBoolCompiler())
279277
def __bool__(self: int) -> bool: ...
280278

281279
@guppy.custom(builtins, NoopCompiler())

0 commit comments

Comments
 (0)