From 6b02ab54c5e49c0805cdda17b69d9c176de21e8c Mon Sep 17 00:00:00 2001 From: Ege Beysel Date: Sat, 21 Dec 2024 19:43:58 +0100 Subject: [PATCH] [TKW] Add support for tkw.round_even Signed-off-by: Ege Beysel --- iree/turbine/kernel/ops/wave_ops.py | 5 +++++ iree/turbine/kernel/wave/codegen.py | 11 +++++++++++ lit_tests/kernel/wave/codegen.py | 12 ++++++++---- 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/iree/turbine/kernel/ops/wave_ops.py b/iree/turbine/kernel/ops/wave_ops.py index 2f5680a00..5196fdeba 100644 --- a/iree/turbine/kernel/ops/wave_ops.py +++ b/iree/turbine/kernel/ops/wave_ops.py @@ -156,6 +156,10 @@ def reshape(inputs: Sequence["Register"]) -> "Register": ... +def round_even(src: "Register") -> "Register": + ... + + def define_op(op_name: str) -> Callable[[T], T]: def decorator(cls: T) -> T: cls.tkw_op_name = op_name @@ -704,6 +708,7 @@ def infer_type(self): @define_interface_op("exp2") @define_interface_op("reciprocal") @define_interface_op("abs") +@define_interface_op("round_even") @define_py_op(operator.neg) @dataclass class UnaryPyOp(CustomOp, ABC): diff --git a/iree/turbine/kernel/wave/codegen.py b/iree/turbine/kernel/wave/codegen.py index 57084ca62..3aa2cd360 100644 --- a/iree/turbine/kernel/wave/codegen.py +++ b/iree/turbine/kernel/wave/codegen.py @@ -73,6 +73,7 @@ cast, permute, reshape, + round_even, ) from ..lang.wave_types import IndexMapping, IndexSymbol from ..compiler.base import CodegenError, ValidationError, NDEBUG @@ -1197,6 +1198,16 @@ def handle_abs(source: Value) -> OpResult: return abs +@handle_unary_op(round_even) +def handle_round_even(source: Value) -> OpResult: + element_type = get_type_or_element_type(source.type) + if _is_float_type(element_type): + round_even = math_d.roundeven(source) + else: + raise ValidationError(f"Found unhandled operand type for abs: {element_type}") + return round_even + + ############################################################################### # Control Flow ops ############################################################################### diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index b994a4224..b7115bab2 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -725,6 +725,7 @@ def test( res = tkw.reciprocal(res) res = tkw.abs(res) res_b = tkw.abs(b_reg) + res = tkw.round_even(res) tkw.write(res, a, elements_per_thread=4) tkw.write(res_b, b, elements_per_thread=4) @@ -740,12 +741,15 @@ def test( # CHECK: %[[EXP2:.+]] = math.exp2 %[[NEG]] # Testing reciprocal - # %[[ONES:.+]] = arith.constant dense<1.000000e+00> : vector<4xf16> - # %[[RECIPROCAL:.+]] = arith.divf %[[ONES]], %[[EXP2]] : vector<4xf16> + # CHECK: %[[ONES:.+]] = arith.constant dense<1.000000e+00> : vector<4xf16> + # CHECK: %[[RECIPROCAL:.+]] = arith.divf %[[ONES]], %[[EXP2]] : vector<4xf16> # Testing abs - # %[[ABSF:.+]] = math.absf %[[RECIPROCAL]] - # %[[ABSI:.+]] = math.absi + # CHECK: %[[ABSF:.+]] = math.absf %[[RECIPROCAL]] + # CHECK: %[[ABSI:.+]] = math.absi + + # Testing round_even + # CHECK: %[[ROUNDEVEN:.+]] = math.roundeven %[[ABSF]] @run_test