Skip to content

Commit

Permalink
[PyRTG] Support targets and test arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
maerhart committed Feb 13, 2025
1 parent 707947f commit 33d4271
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 18 deletions.
1 change: 1 addition & 0 deletions frontends/PyRTG/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ declare_mlir_python_sources(PyRTGSources
pyrtg/sequences.py
pyrtg/sets.py
pyrtg/support.py
pyrtg/target.py
pyrtg/tests.py
rtgtool/rtgtool.py
)
Expand Down
1 change: 1 addition & 0 deletions frontends/PyRTG/src/pyrtg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
from .integers import Integer
from .bags import Bag
from .sequences import sequence
from .target import target, entry
82 changes: 82 additions & 0 deletions frontends/PyRTG/src/pyrtg/target.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from __future__ import annotations

from .core import CodeGenRoot, Value
from .circt import ir
from .rtg import rtg


class Entry:
"""
Represents an RTG Target Entry. Stores the entry function and location.
"""

def __init__(self, entry_func) -> Entry:
self.entry_func = entry_func

@property
def name(self) -> str:
return self.entry_func.__name__


def entry(func):
"""
Decorator for target entry functions. It computes one value returned from the
target. The name of the function is used as the key in the target dictionary
and the values returned from the target will be sorted by name.
"""

return Entry(func)


def target(cls):
"""
Represents an RTG Target. Constructs an instance of the decorated class which
registers it as an RTG target.
"""

def new_init(self):
self._name = self.__class__.__name__
self._dict = cls.__dict__

cls = type(cls.__name__, (Target,) + cls.__bases__, dict(cls.__dict__))
cls.__init__ = new_init
instance = cls()
return instance


class Target(CodeGenRoot):
"""
An RTG Target is a collection of entry functions that define the capabilities
and characteristics of a specific test target. Each entry function computes
and returns a value that represents a particular feature or property of the
target.
"""

def _codegen(self) -> None:
entries = []
names = []

# Collect entries from the class dictionary.
for attr_name, attr in self.__class__.__dict__.items():
if isinstance(attr, Entry):
entries.append(attr)
names.append(attr_name)

# Construct the target operation.
target_op = rtg.TargetOp(self._name, ir.TypeAttr.get(rtg.DictType.get()))
entry_block = ir.Block.create_at_start(target_op.bodyRegion, [])
with ir.InsertionPoint(entry_block):
results: list[Value] = []

for entry in entries:
results.append(entry.entry_func())

rtg.YieldOp(results)

dict_entries = [(ir.StringAttr.get(name), val.get_type())
for (name, val) in zip(names, results)]
target_op.target = ir.TypeAttr.get(rtg.DictType.get(dict_entries))
33 changes: 18 additions & 15 deletions frontends/PyRTG/src/pyrtg/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,45 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import inspect

from .circt import ir
from .core import CodeGenRoot
from .rtg import rtg
from .support import _FromCirctValue


class Test(CodeGenRoot):
"""
Represents an RTG Test. Stores the test function and location.
"""

type: ir.Type

def __init__(self, test_func):
def __init__(self, test_func, args: list[tuple[str, ir.Type]]):
self.test_func = test_func

sig = inspect.signature(test_func)
assert len(sig.parameters) == 0, "test arguments not supported yet"

self.type = rtg.DictType.get([])
self.arg_names = [name for name, _ in args]
self.arg_types = [ty for _, ty in args]

@property
def name(self) -> str:
return self.test_func.__name__

def _codegen(self):
test = rtg.TestOp(self.name, ir.TypeAttr.get(self.type))
block = ir.Block.create_at_start(test.bodyRegion, [])
test = rtg.TestOp(
self.name,
ir.TypeAttr.get(
rtg.DictType.get([
(ir.StringAttr.get(name), ty)
for (name, ty) in zip(self.arg_names, self.arg_types)
])))
block = ir.Block.create_at_start(test.bodyRegion, self.arg_types)
with ir.InsertionPoint(block):
self.test_func(*block.arguments)
self.test_func(*[_FromCirctValue(arg) for arg in block.arguments])


def test(func):
def test(*args, **kwargs):
"""
Decorator for RTG test functions.
"""

return Test(func)
def wrapper(func):
return Test(func, list(args))

return wrapper
66 changes: 63 additions & 3 deletions frontends/PyRTG/test/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,42 @@
# RUN: %rtgtool% %s --seed=0 --output-format=elaborated | FileCheck %s --check-prefix=ELABORATED
# RUN: %rtgtool% %s --seed=0 -o %t --output-format=asm && FileCheck %s --input-file=%t --check-prefix=ASM

from pyrtg import test, sequence, rtg, Label, Set, Integer, Bag
from pyrtg import test, sequence, target, entry, rtg, Label, Set, Integer, Bag

# MLIR-LABEL: rtg.target @Tgt0 : !rtg.dict<entry0: !rtg.set<index>>
# MLIR-NEXT: [[C0:%.+]] = index.constant 0
# MLIR-NEXT: [[C1:%.+]] = index.constant 1
# MLIR-NEXT: [[SET:%.+]] = rtg.set_create [[C0:%.+]], [[C1:%.+]] : index
# MLIR-NEXT: rtg.yield [[SET]] : !rtg.set<index>
# MLIR-NEXT: }


@target
class Tgt0:

@entry
def entry0():
return Set.create(Integer(0), Integer(1))


# MLIR-LABEL: rtg.target @Tgt1 : !rtg.dict<entry0: index, entry1: !rtg.label>
# MLIR-NEXT: [[C0:%.+]] = index.constant 0
# MLIR-NEXT: [[LBL:%.+]] = rtg.label_decl "l0"
# MLIR-NEXT: rtg.yield [[C0]], [[LBL]] : index, !rtg.label
# MLIR-NEXT: }


@target
class Tgt1:

@entry
def entry0():
return Integer(0)

@entry
def entry1():
return Label.declare("l0")


# MLIR-LABEL: rtg.sequence @seq0
# MLIR-SAME: ([[SET:%.+]]: !rtg.set<!rtg.label>)
Expand Down Expand Up @@ -37,11 +72,36 @@ def seq1():
# ASM: End of test0


@test
@test()
def test0():
pass


# MLIR-LABEL: rtg.test @test_args
# MLIR-SAME: ([[SET:%.+]]: !rtg.set<index>)
# MLIR-NEXT: [[RAND:%.+]] = rtg.set_select_random [[SET]] : !rtg.set<index>
# MLIR-NEXT: rtg.label_decl "L_{{[{][{]0[}][}]}}", [[RAND]]
# MLIR-NEXT: rtg.label local
# MLIR-NEXT: }

# ELABORATED-LABEL: rtg.test @test_args_Tgt0
# CHECK: rtg.label_decl "L_0"
# CHECK-NEXT: rtg.label local
# CHECK-NEXT: }

# ASM-LABEL: Begin of test_args
# ASM-EMPTY:
# ASM-NEXT: L_0:
# ASM-EMPTY:
# ASM: End of test_args


@test(("entry0", Set.type(Integer.type())))
def test_args(set: Set):
i = set.get_random()
Label.declare(r"L_{{0}}", i).place()


# MLIR-LABEL: rtg.test @test_labels
# MLIR-NEXT: index.constant 5
# MLIR-NEXT: index.constant 3
Expand Down Expand Up @@ -154,7 +214,7 @@ def test0():
# ASM: End of test_labels


@test
@test()
def test_labels():
l0 = Label.declare("l0")
l1 = Label.declare_unique("l1")
Expand Down

0 comments on commit 33d4271

Please sign in to comment.