Skip to content

Commit 707947f

Browse files
committed
[PyRTG] Support sequences
1 parent 9b105ae commit 707947f

File tree

12 files changed

+340
-5
lines changed

12 files changed

+340
-5
lines changed

frontends/PyRTG/src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ declare_mlir_python_sources(PyRTGSources
2020
pyrtg/integers.py
2121
pyrtg/labels.py
2222
pyrtg/rtg.py
23+
pyrtg/sequences.py
2324
pyrtg/sets.py
2425
pyrtg/support.py
2526
pyrtg/tests.py

frontends/PyRTG/src/pyrtg/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@
1111
from .sets import Set
1212
from .integers import Integer
1313
from .bags import Bag
14+
from .sequences import sequence

frontends/PyRTG/src/pyrtg/bags.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,11 @@ def _get_ssa_value(self) -> ir.Value:
134134

135135
def get_type(self) -> ir.Type:
136136
return self._value.type
137+
138+
def type(*args: ir.Type) -> ir.Type:
139+
"""
140+
Returns the bag type for the given element type.
141+
"""
142+
143+
assert len(args) == 1, "Bag type requires exactly one element type"
144+
return rtg.BagType.get(args[0])

frontends/PyRTG/src/pyrtg/core.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class CodeGenRoot:
1111
during codegen.
1212
"""
1313

14-
def codegen(self):
14+
def _codegen(self):
1515
assert False, "must be implemented by the subclass"
1616

1717

@@ -28,5 +28,8 @@ class Value:
2828
def get_type(self) -> ir.Type:
2929
assert False, "must be implemented by subclass"
3030

31+
def type(*args: ir.Type) -> ir.Type:
32+
assert False, "must be implemented by subclass"
33+
3134
def _get_ssa_value(self) -> ir.Value:
3235
assert False, "must be implemented by subclass"

frontends/PyRTG/src/pyrtg/integers.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,11 @@ def _get_ssa_value(self) -> ir.Value:
4949
self = index.ConstantOp(self._value)
5050

5151
return self._value
52+
53+
def type(*args: ir.Type) -> ir.Type:
54+
"""
55+
Returns the index type.
56+
"""
57+
58+
assert len(args) == 0, "Integer type does not take type arguments"
59+
return ir.IndexType.get()

frontends/PyRTG/src/pyrtg/labels.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,11 @@ def get_type(self) -> ir.Type:
6161

6262
def _get_ssa_value(self) -> ir.Value:
6363
return self._value
64+
65+
def type(*args: ir.Type) -> ir.Type:
66+
"""
67+
Returns the label type.
68+
"""
69+
70+
assert len(args) == 0, "Label type does not take type arguments"
71+
return rtg.LabelType.get()
Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from __future__ import annotations
6+
7+
from .core import CodeGenRoot, Value
8+
from .support import _FromCirctValue
9+
from .circt import ir, support
10+
from .rtg import rtg
11+
12+
13+
class SequenceDeclaration(CodeGenRoot):
14+
"""
15+
This class is responsible for managing and generating RTG sequences. It
16+
encapsulates the sequence function, its argument types, and the source
17+
location where it was defined.
18+
"""
19+
20+
def __init__(self, sequence_func, arg_types: list[ir.Type]):
21+
self.sequence_func = sequence_func
22+
self.arg_types = arg_types
23+
24+
@property
25+
def name(self) -> str:
26+
return self.sequence_func.__name__
27+
28+
def get(self) -> Sequence:
29+
"""
30+
Returns a sequence value referring to this sequence declaration. It can be
31+
used for substitution, randomization, or passed as a value to other
32+
functions.
33+
"""
34+
35+
return Sequence(self._get_ssa_value())
36+
37+
def substitute(self, *args: Value) -> Sequence:
38+
"""
39+
Creates a new sequence with the given arguments substituted.
40+
41+
Args:
42+
*args: Values to substitute for the sequence's parameters.
43+
"""
44+
45+
return self.get().substitute(*args)
46+
47+
def randomize(self, *args: Value) -> RandomizedSequence:
48+
"""
49+
Randomizes this sequence, i.e., replaces all randomization constructs with
50+
concrete values.
51+
52+
Args:
53+
*args: Values to substitute for the sequence's parameters.
54+
"""
55+
56+
return self.get().randomize(*args)
57+
58+
def __call__(self, *args: Value) -> None:
59+
"""
60+
Convenience method to substitute, randomize, and embed this sequence in one
61+
go.
62+
63+
Args:
64+
*args: Values to substitute for the sequence's parameters.
65+
"""
66+
67+
self.get()(*args)
68+
69+
def _codegen(self) -> None:
70+
seq = rtg.SequenceOp(self.name,
71+
ir.TypeAttr.get(rtg.SequenceType.get(self.arg_types)))
72+
block = ir.Block.create_at_start(seq.bodyRegion, self.arg_types)
73+
with ir.InsertionPoint(block):
74+
self.sequence_func(*[_FromCirctValue(arg) for arg in block.arguments])
75+
76+
def _get_ssa_value(self) -> ir.Value:
77+
return rtg.GetSequenceOp(rtg.SequenceType.get(self.arg_types),
78+
self.name)._get_ssa_value()
79+
80+
def get_type(self):
81+
return rtg.SequenceType.get(self.arg_types)
82+
83+
84+
def sequence(*args: ir.Type, **kwargs):
85+
"""
86+
Decorator for defining RTG sequence functions.
87+
88+
Args:
89+
*args: The types of the sequence's parameters.
90+
"""
91+
92+
def wrapper(func):
93+
return SequenceDeclaration(func, list(args))
94+
95+
return wrapper
96+
97+
98+
class Sequence(Value):
99+
"""
100+
Represents a sequence value that can be substituted and randomized (i.e., all
101+
randomization constructs are replaced with concrete values). Once it is
102+
randomized it can be embedded into a test or another sequence.
103+
"""
104+
105+
def __init__(self, value: ir.Value) -> Sequence:
106+
"""
107+
Intended for library internal usage only.
108+
"""
109+
110+
self._value = value
111+
112+
def substitute(self, *args: Value) -> Sequence:
113+
"""
114+
Creates a new sequence with the given arguments substituted.
115+
116+
Args:
117+
*args: Values to substitute for the sequence's parameters.
118+
"""
119+
120+
element_types = self.element_types
121+
assert len(args) > 0, "At least one argument must be provided"
122+
assert len(args) <= len(
123+
element_types
124+
), f"Expected at most {len(element_types)} arguments, got {len(args)}"
125+
for arg, expected_type in zip(args, element_types):
126+
assert arg.get_type(
127+
) == expected_type, f"Expected argument of type {expected_type}, got {arg.get_type()}"
128+
129+
return rtg.SubstituteSequenceOp(self, args)
130+
131+
def randomize(self, *args: Value) -> RandomizedSequence:
132+
"""
133+
Creates a randomized version (i.e., all randomization constructs are
134+
replaced with concrete values) of this sequence.
135+
136+
Args:
137+
*args: Values to substitute for the sequence's parameters.
138+
"""
139+
140+
value = self
141+
element_types = self.element_types
142+
if len(element_types) > 0:
143+
assert len(args) == len(
144+
element_types
145+
), f"Expected {len(element_types)} arguments, got {len(args)}"
146+
for arg, expected_type in zip(args, element_types):
147+
assert arg.get_type(
148+
) == expected_type, f"Expected argument of type {expected_type}, got {arg.get_type()}"
149+
150+
value = self.substitute(*args)
151+
152+
return rtg.RandomizeSequenceOp(value)
153+
154+
def __call__(self, *args: Value) -> None:
155+
"""
156+
Convenience method to substitute, randomize, and embed this sequence in one
157+
go.
158+
159+
Args:
160+
*args: Values to substitute for the sequence's parameters.
161+
"""
162+
163+
self.randomize(*args).embed()
164+
165+
def _get_ssa_value(self) -> ir.Value:
166+
return self._value
167+
168+
@property
169+
def element_types(self) -> list[ir.Type]:
170+
"""
171+
Returns the list of elements types for this sequence.
172+
"""
173+
174+
type = support.type_to_pytype(self.get_type())
175+
return [type.get_element(i) for i in range(type.num_elements)]
176+
177+
def get_type(self) -> ir.Type:
178+
return self._value.type
179+
180+
def type(*args: ir.Type) -> ir.Type:
181+
"""
182+
Returns the sequence type with the given argument types.
183+
"""
184+
185+
return rtg.SequenceType.get(list(args))
186+
187+
188+
class RandomizedSequence(Value):
189+
"""
190+
Represents a randomized sequence value where all randomization constructs have
191+
been replaced with concrete values. It can be embedded into a test or another
192+
sequence.
193+
"""
194+
195+
def __init__(self, value: ir.Value) -> RandomizedSequence:
196+
"""
197+
Intended for library internal usage only.
198+
"""
199+
200+
self._value = value
201+
202+
def embed(self) -> None:
203+
"""
204+
Embeds this randomized sequence at the current position in the test or
205+
sequence.
206+
"""
207+
208+
rtg.EmbedSequenceOp(self)
209+
210+
def __call__(self) -> None:
211+
"""
212+
Convenience method to embed this sequence. Takes no arguments since the
213+
sequence is already fully sustituted.
214+
215+
Args:
216+
*args: Must be empty, since randomized sequences cannot take arguments.
217+
"""
218+
219+
self.embed()
220+
221+
def _get_ssa_value(self) -> ir.Value:
222+
return self._value
223+
224+
def get_type(self) -> ir.Type:
225+
return self._value.type
226+
227+
def type(*args: ir.Type) -> ir.Type:
228+
"""
229+
Returns the randomized sequence type.
230+
"""
231+
232+
assert len(
233+
args) == 0, "RandomizedSequence type does not take type arguments"
234+
return rtg.RandomizedSequenceType.get()

frontends/PyRTG/src/pyrtg/sets.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,3 +103,11 @@ def _get_ssa_value(self) -> ir.Value:
103103

104104
def get_type(self) -> ir.Type:
105105
return self._value.type
106+
107+
def type(*args: ir.Type) -> ir.Type:
108+
"""
109+
Returns the set type for the given element type.
110+
"""
111+
112+
assert len(args) == 1, "Set type requires exactly one element type"
113+
return rtg.SetType.get(args[0])

frontends/PyRTG/src/pyrtg/support.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@ def _FromCirctValue(value: ir.Value) -> Value:
1818
if isinstance(type, rtg.BagType):
1919
from .bags import Bag
2020
return Bag(value)
21+
if isinstance(type, rtg.SequenceType):
22+
from .sequences import Sequence
23+
return Sequence(value)
24+
if isinstance(type, rtg.RandomizedSequenceType):
25+
from .sequences import RandomizedSequence
26+
return RandomizedSequence(value)
2127
if isinstance(type, ir.IndexType):
2228
from .integers import Integer
2329
return Integer(value)
@@ -44,7 +50,8 @@ def specialize_create(cls):
4450
def create(*args, **kwargs):
4551
# If any of the arguments are 'pyrtg.Value', we need to convert them.
4652
def to_circt(arg):
47-
if isinstance(arg, Value):
53+
from .sequences import SequenceDeclaration
54+
if isinstance(arg, (Value, SequenceDeclaration)):
4855
return arg._get_ssa_value()
4956
if isinstance(arg, (list, tuple)):
5057
return [to_circt(a) for a in arg]

frontends/PyRTG/src/pyrtg/tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(self, test_func):
2828
def name(self) -> str:
2929
return self.test_func.__name__
3030

31-
def codegen(self):
31+
def _codegen(self):
3232
test = rtg.TestOp(self.name, ir.TypeAttr.get(self.type))
3333
block = ir.Block.create_at_start(test.bodyRegion, [])
3434
with ir.InsertionPoint(block):

0 commit comments

Comments
 (0)