Skip to content

Commit 33d4271

Browse files
committed
[PyRTG] Support targets and test arguments
1 parent 707947f commit 33d4271

File tree

5 files changed

+165
-18
lines changed

5 files changed

+165
-18
lines changed

frontends/PyRTG/src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ declare_mlir_python_sources(PyRTGSources
2323
pyrtg/sequences.py
2424
pyrtg/sets.py
2525
pyrtg/support.py
26+
pyrtg/target.py
2627
pyrtg/tests.py
2728
rtgtool/rtgtool.py
2829
)

frontends/PyRTG/src/pyrtg/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@
1212
from .integers import Integer
1313
from .bags import Bag
1414
from .sequences import sequence
15+
from .target import target, entry

frontends/PyRTG/src/pyrtg/target.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from __future__ import annotations
6+
7+
from .core import CodeGenRoot, Value
8+
from .circt import ir
9+
from .rtg import rtg
10+
11+
12+
class Entry:
13+
"""
14+
Represents an RTG Target Entry. Stores the entry function and location.
15+
"""
16+
17+
def __init__(self, entry_func) -> Entry:
18+
self.entry_func = entry_func
19+
20+
@property
21+
def name(self) -> str:
22+
return self.entry_func.__name__
23+
24+
25+
def entry(func):
26+
"""
27+
Decorator for target entry functions. It computes one value returned from the
28+
target. The name of the function is used as the key in the target dictionary
29+
and the values returned from the target will be sorted by name.
30+
"""
31+
32+
return Entry(func)
33+
34+
35+
def target(cls):
36+
"""
37+
Represents an RTG Target. Constructs an instance of the decorated class which
38+
registers it as an RTG target.
39+
"""
40+
41+
def new_init(self):
42+
self._name = self.__class__.__name__
43+
self._dict = cls.__dict__
44+
45+
cls = type(cls.__name__, (Target,) + cls.__bases__, dict(cls.__dict__))
46+
cls.__init__ = new_init
47+
instance = cls()
48+
return instance
49+
50+
51+
class Target(CodeGenRoot):
52+
"""
53+
An RTG Target is a collection of entry functions that define the capabilities
54+
and characteristics of a specific test target. Each entry function computes
55+
and returns a value that represents a particular feature or property of the
56+
target.
57+
"""
58+
59+
def _codegen(self) -> None:
60+
entries = []
61+
names = []
62+
63+
# Collect entries from the class dictionary.
64+
for attr_name, attr in self.__class__.__dict__.items():
65+
if isinstance(attr, Entry):
66+
entries.append(attr)
67+
names.append(attr_name)
68+
69+
# Construct the target operation.
70+
target_op = rtg.TargetOp(self._name, ir.TypeAttr.get(rtg.DictType.get()))
71+
entry_block = ir.Block.create_at_start(target_op.bodyRegion, [])
72+
with ir.InsertionPoint(entry_block):
73+
results: list[Value] = []
74+
75+
for entry in entries:
76+
results.append(entry.entry_func())
77+
78+
rtg.YieldOp(results)
79+
80+
dict_entries = [(ir.StringAttr.get(name), val.get_type())
81+
for (name, val) in zip(names, results)]
82+
target_op.target = ir.TypeAttr.get(rtg.DictType.get(dict_entries))

frontends/PyRTG/src/pyrtg/tests.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,42 +2,45 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5-
import inspect
6-
75
from .circt import ir
86
from .core import CodeGenRoot
97
from .rtg import rtg
8+
from .support import _FromCirctValue
109

1110

1211
class Test(CodeGenRoot):
1312
"""
1413
Represents an RTG Test. Stores the test function and location.
1514
"""
1615

17-
type: ir.Type
18-
19-
def __init__(self, test_func):
16+
def __init__(self, test_func, args: list[tuple[str, ir.Type]]):
2017
self.test_func = test_func
21-
22-
sig = inspect.signature(test_func)
23-
assert len(sig.parameters) == 0, "test arguments not supported yet"
24-
25-
self.type = rtg.DictType.get([])
18+
self.arg_names = [name for name, _ in args]
19+
self.arg_types = [ty for _, ty in args]
2620

2721
@property
2822
def name(self) -> str:
2923
return self.test_func.__name__
3024

3125
def _codegen(self):
32-
test = rtg.TestOp(self.name, ir.TypeAttr.get(self.type))
33-
block = ir.Block.create_at_start(test.bodyRegion, [])
26+
test = rtg.TestOp(
27+
self.name,
28+
ir.TypeAttr.get(
29+
rtg.DictType.get([
30+
(ir.StringAttr.get(name), ty)
31+
for (name, ty) in zip(self.arg_names, self.arg_types)
32+
])))
33+
block = ir.Block.create_at_start(test.bodyRegion, self.arg_types)
3434
with ir.InsertionPoint(block):
35-
self.test_func(*block.arguments)
35+
self.test_func(*[_FromCirctValue(arg) for arg in block.arguments])
3636

3737

38-
def test(func):
38+
def test(*args, **kwargs):
3939
"""
4040
Decorator for RTG test functions.
4141
"""
4242

43-
return Test(func)
43+
def wrapper(func):
44+
return Test(func, list(args))
45+
46+
return wrapper

frontends/PyRTG/test/basic.py

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,42 @@
22
# RUN: %rtgtool% %s --seed=0 --output-format=elaborated | FileCheck %s --check-prefix=ELABORATED
33
# RUN: %rtgtool% %s --seed=0 -o %t --output-format=asm && FileCheck %s --input-file=%t --check-prefix=ASM
44

5-
from pyrtg import test, sequence, rtg, Label, Set, Integer, Bag
5+
from pyrtg import test, sequence, target, entry, rtg, Label, Set, Integer, Bag
6+
7+
# MLIR-LABEL: rtg.target @Tgt0 : !rtg.dict<entry0: !rtg.set<index>>
8+
# MLIR-NEXT: [[C0:%.+]] = index.constant 0
9+
# MLIR-NEXT: [[C1:%.+]] = index.constant 1
10+
# MLIR-NEXT: [[SET:%.+]] = rtg.set_create [[C0:%.+]], [[C1:%.+]] : index
11+
# MLIR-NEXT: rtg.yield [[SET]] : !rtg.set<index>
12+
# MLIR-NEXT: }
13+
14+
15+
@target
16+
class Tgt0:
17+
18+
@entry
19+
def entry0():
20+
return Set.create(Integer(0), Integer(1))
21+
22+
23+
# MLIR-LABEL: rtg.target @Tgt1 : !rtg.dict<entry0: index, entry1: !rtg.label>
24+
# MLIR-NEXT: [[C0:%.+]] = index.constant 0
25+
# MLIR-NEXT: [[LBL:%.+]] = rtg.label_decl "l0"
26+
# MLIR-NEXT: rtg.yield [[C0]], [[LBL]] : index, !rtg.label
27+
# MLIR-NEXT: }
28+
29+
30+
@target
31+
class Tgt1:
32+
33+
@entry
34+
def entry0():
35+
return Integer(0)
36+
37+
@entry
38+
def entry1():
39+
return Label.declare("l0")
40+
641

742
# MLIR-LABEL: rtg.sequence @seq0
843
# MLIR-SAME: ([[SET:%.+]]: !rtg.set<!rtg.label>)
@@ -37,11 +72,36 @@ def seq1():
3772
# ASM: End of test0
3873

3974

40-
@test
75+
@test()
4176
def test0():
4277
pass
4378

4479

80+
# MLIR-LABEL: rtg.test @test_args
81+
# MLIR-SAME: ([[SET:%.+]]: !rtg.set<index>)
82+
# MLIR-NEXT: [[RAND:%.+]] = rtg.set_select_random [[SET]] : !rtg.set<index>
83+
# MLIR-NEXT: rtg.label_decl "L_{{[{][{]0[}][}]}}", [[RAND]]
84+
# MLIR-NEXT: rtg.label local
85+
# MLIR-NEXT: }
86+
87+
# ELABORATED-LABEL: rtg.test @test_args_Tgt0
88+
# CHECK: rtg.label_decl "L_0"
89+
# CHECK-NEXT: rtg.label local
90+
# CHECK-NEXT: }
91+
92+
# ASM-LABEL: Begin of test_args
93+
# ASM-EMPTY:
94+
# ASM-NEXT: L_0:
95+
# ASM-EMPTY:
96+
# ASM: End of test_args
97+
98+
99+
@test(("entry0", Set.type(Integer.type())))
100+
def test_args(set: Set):
101+
i = set.get_random()
102+
Label.declare(r"L_{{0}}", i).place()
103+
104+
45105
# MLIR-LABEL: rtg.test @test_labels
46106
# MLIR-NEXT: index.constant 5
47107
# MLIR-NEXT: index.constant 3
@@ -154,7 +214,7 @@ def test0():
154214
# ASM: End of test_labels
155215

156216

157-
@test
217+
@test()
158218
def test_labels():
159219
l0 = Label.declare("l0")
160220
l1 = Label.declare_unique("l1")

0 commit comments

Comments
 (0)