diff --git a/python/test_infra/ttir_builder.py b/python/test_infra/ttir_builder.py
index f58402d525..118fc746ed 100644
--- a/python/test_infra/ttir_builder.py
+++ b/python/test_infra/ttir_builder.py
@@ -11,6 +11,7 @@
 from ttmlir.dialects import ttir, tt, tensor
 from ttmlir.passes import GoldenTensor, DataType
 import torch
+import array

 # Alias for operands of ops which can be either BlockArguments, Values, or other
 # ops wrapped in OpView or Operation.
@@ -416,8 +417,13 @@ def organize_golden_args(inputs: List[Operand], output: OpView, output_shape: Op

         with self._ctx, self._loc:
             # Compute the golden
-            golden = Golden(
-                op_golden_function(*(organize_golden_args(inputs)), **golden_kwargs)
+            golden_output = op_golden_function(
+                *(organize_golden_args(inputs)), **golden_kwargs
+            )
+            golden = (
+                Golden(golden_output[0])
+                if not isinstance(golden_output, torch.Tensor)
+                else Golden(golden_output)
             )

             # Use the golden output to determine proper output shape unless otherwise specified
@@ -578,6 +584,25 @@ def maximum(self, in0: Operand, in1: Operand) -> OpView:
     def minimum(self, in0: Operand, in1: Operand) -> OpView:
         return self.eltwise_proxy(torch.minimum, ttir.MinimumOp, [in0, in1])

+    def power(self, in0: Operand, in1: Operand) -> OpView:
+        return self.eltwise_proxy(torch.pow, ttir.PowerOp, [in0, in1])
+
+    def sum(
+        self, in0: Operand, dim_arg: List[int] = [0], keep_dim: bool = True
+    ) -> OpView:
+
+        golden_kwargs = {"dim": dim_arg, "keepdim": keep_dim}
+        ttir_kwargs = {"dim_arg": dim_arg, "keep_dim": keep_dim}
+
+        return self.op_proxy(
+            torch.sum,
+            ttir.SumOp,
+            [in0],
+            golden_kwargs=golden_kwargs,
+            ttir_kwargs=ttir_kwargs,
+            organize_ttir_args=lambda i, o, _: (self._get_type(o), i[0], o),
+        )
+
     def mean(
         self, in0: Operand, dim_arg: List[int] = [0], keep_dim: bool = True
     ) -> OpView:
@@ -594,6 +619,34 @@ def mean(
             organize_ttir_args=lambda i, o, _: (self._get_type(o), i[0], o),
         )

+    def max(self, in0: Operand, dim_arg: int = 0, keep_dim: bool = True) -> OpView:
+
+        golden_kwargs = {"dim": dim_arg, "keepdim": keep_dim}
+        ttir_kwargs = {"dim_arg": [dim_arg], "keep_dim": keep_dim}
+
+        return self.op_proxy(
+            torch.max,
+            ttir.MaxOp,
+            [in0],
+            golden_kwargs=golden_kwargs,
+            ttir_kwargs=ttir_kwargs,
+            organize_ttir_args=lambda i, o, _: (self._get_type(o), i[0], o),
+        )
+
+    def min(self, in0: Operand, dim_arg: int = 0, keep_dim: bool = True) -> OpView:
+
+        golden_kwargs = {"dim": dim_arg, "keepdim": keep_dim}
+        ttir_kwargs = {"dim_arg": [dim_arg], "keep_dim": keep_dim}
+
+        return self.op_proxy(
+            torch.min,
+            ttir.MinOp,
+            [in0],
+            golden_kwargs=golden_kwargs,
+            ttir_kwargs=ttir_kwargs,
+            organize_ttir_args=lambda i, o, _: (self._get_type(o), i[0], o),
+        )
+
     def leaky_relu(self, in0: Operand, parameter: float = 0.01) -> OpView:
         # TODO: reconcile this naming mismatch
         ttir_kwargs = {"parameter": parameter}
diff --git a/test/python/golden/test_ttir_ops.py b/test/python/golden/test_ttir_ops.py
index 8e7bffb5b1..1d3dd96cfe 100644
--- a/test/python/golden/test_ttir_ops.py
+++ b/test/python/golden/test_ttir_ops.py
@@ -391,6 +391,17 @@ def test_minimum(in0: Operand, in1: Operand, builder: TTIRBuilder):
     return builder.minimum(in0, in1)


+@compile_to_flatbuffer(
+    [
+        (64, 64),
+        (64, 64),
+    ],
+    targets=["ttnn"],
+)
+def test_power(in0: Operand, in1: Operand, builder: TTIRBuilder):
+    return builder.power(in0, in1)
+
+
 @compile_to_flatbuffer(
     [
         (32, 64),
@@ -402,6 +413,16 @@ def test_matmul(in0: Operand, in1: Operand, builder: TTIRBuilder):
     return builder.matmul(in0, in1)


+@compile_to_flatbuffer(
+    [
+        (64, 64),
+    ],
+    targets=["ttnn"],
+)
+def test_sum(in0: Operand, builder: TTIRBuilder):
+    return builder.sum(in0)
+
+
 @compile_to_flatbuffer(
     [
         (128, 128),
@@ -412,6 +433,26 @@ def test_mean(in0: Operand, builder: TTIRBuilder):
     return builder.mean(in0)


+@compile_to_flatbuffer(
+    [
+        (64, 64),
+    ],
+    targets=["ttnn"],
+)
+def test_max(in0: Operand, builder: TTIRBuilder):
+    return builder.max(in0)
+
+
+@compile_to_flatbuffer(
+    [
+        (64, 64),
+    ],
+    targets=["ttnn"],
+)
+def test_min(in0: Operand, builder: TTIRBuilder):
+    return builder.min(in0)
+
+
 @compile_to_flatbuffer(
     [
         (32, 64),
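
Note on the Golden change in the first hunk: torch dim-reductions that also
produce indices, such as torch.max and torch.min with a dim argument, return
a (values, indices) namedtuple rather than a plain tensor, which is why the
golden result is wrapped as Golden(golden_output[0]) in that case while
tensor-returning ops like torch.sum are wrapped directly. A minimal
standalone sketch of the torch behavior this relies on (plain torch only;
no TTIRBuilder or tt-mlir imports assumed):

    import torch

    t = torch.arange(16.0).reshape(4, 4)

    # max over a dim returns torch.return_types.max, not a Tensor, so the
    # builder keeps only element [0] (the values tensor) as the golden.
    out = torch.max(t, dim=0, keepdim=True)
    assert not isinstance(out, torch.Tensor)
    values = out[0]  # values tensor, shape (1, 4)

    # sum returns a plain Tensor and is wrapped as-is.
    out2 = torch.sum(t, dim=[0], keepdim=True)
    assert isinstance(out2, torch.Tensor)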