Added sum, min, max support and tests for TTIR Builder #2404

Open · wants to merge 6 commits into main
Changes from 4 commits
69 changes: 65 additions & 4 deletions python/test_infra/ttir_builder.py
@@ -416,17 +416,27 @@ def organize_golden_args(inputs: List[Operand], output: OpView, output_shape: Op

        with self._ctx, self._loc:
            # Compute the golden
-           golden = Golden(
-               op_golden_function(*(organize_golden_args(inputs)), **golden_kwargs)
-           )
+           if op_ttir_function in [ttir.MaxOp, ttir.MinOp]:
+               golden = Golden(
+                   op_golden_function(
+                       *(organize_golden_args(inputs)),
+                       golden_kwargs["dim"][0],
+                       golden_kwargs["keepdim"],
+                   )[0]
+               )
+           else:
+               golden = Golden(
+                   op_golden_function(*(organize_golden_args(inputs)), **golden_kwargs)
+               )
Contributor:

Out of curiosity, why isn't the kwarg expansion sufficient here? I assume it's some weird inconsistency in how torch.{min,max} are called, but I can't see it. No need to change this if it's the only way to make it work.

Author:

Yeah, I'm not thrilled with this solution; if you have a more elegant/scalable approach, let me know. torch.max and torch.min expect an int (or dimension name) as the second argument, whereas the other torch reduction functions accept a tuple of ints or names. They also return a named tuple, torch.return_types.max(values, indices) / torch.return_types.min(values, indices), instead of a plain tensor.
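
To make the mismatch concrete, here is a small standalone PyTorch snippet (an illustration for this thread, not part of the PR); the tensor shape and dim values are arbitrary:

    import torch

    x = torch.randn(64, 64)

    # torch.sum (like torch.mean) accepts a list/tuple of dims and returns a plain tensor.
    s = torch.sum(x, dim=[0], keepdim=True)   # shape (1, 64)

    # torch.max / torch.min take a single int dim and return a named tuple of
    # (values, indices) rather than a plain tensor.
    m = torch.max(x, dim=0, keepdim=True)     # torch.return_types.max
    values = m[0]                             # equivalently m.values, shape (1, 64)

This is what the special case above accounts for: it passes golden_kwargs["dim"][0] as a positional int and indexes the result with [0] to keep only the values tensor.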

Author:

Nevermind, see most recent commit
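
(That commit is not among the four shown in this diff. For context only, here is a hypothetical sketch of the kind of normalization that removes the special case; it assumes the same golden_kwargs shape ({"dim": ..., "keepdim": ...}) used by the other reductions and is not necessarily what the actual commit does:

    import torch

    # Hypothetical adapter: unwrap the single dim and keep only the values tensor,
    # so op_proxy can call it with **golden_kwargs like any other golden function.
    def max_golden(tensor, dim, keepdim):
        return torch.max(tensor, dim[0], keepdim=keepdim)[0]

Passing a wrapper like this as op_golden_function would let the generic **golden_kwargs path handle max/min without branching.)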

            # Use the golden output to determine proper output shape unless otherwise specified
            output_shape = golden.tensor.shape if not output_shape else output_shape
            # print(output_shape, type(output_shape))
            output = self.empty(output_shape)

            id = self.get_next_global_id()
            loc = get_loc_of_extra_file_callee(id=id)

            # print(organize_ttir_args(inputs, output, output_shape),loc,ttir_kwargs)
            op = op_ttir_function(
                *organize_ttir_args(inputs, output, output_shape),
                loc=loc,
@@ -578,6 +588,22 @@ def maximum(self, in0: Operand, in1: Operand) -> OpView:
    def minimum(self, in0: Operand, in1: Operand) -> OpView:
        return self.eltwise_proxy(torch.minimum, ttir.MinimumOp, [in0, in1])

    def sum(
        self, in0: Operand, dim_arg: List[int] = [0], keep_dim: bool = True
    ) -> OpView:

        golden_kwargs = {"dim": dim_arg, "keepdim": keep_dim}
        ttir_kwargs = {"dim_arg": dim_arg, "keep_dim": keep_dim}

        return self.op_proxy(
            torch.sum,
            ttir.SumOp,
            [in0],
            golden_kwargs=golden_kwargs,
            ttir_kwargs=ttir_kwargs,
            organize_ttir_args=lambda i, o, _: (self._get_type(o), i[0], o),
        )

    def mean(
        self, in0: Operand, dim_arg: List[int] = [0], keep_dim: bool = True
    ) -> OpView:
@@ -594,6 +620,41 @@ def mean(
            organize_ttir_args=lambda i, o, _: (self._get_type(o), i[0], o),
        )

    def max(
        self,
        in0: Operand,
        dim_arg: List[int] = [0],
        keep_dim: bool = True,
    ) -> OpView:

        golden_kwargs = {"dim": dim_arg, "keepdim": keep_dim}
        ttir_kwargs = {"dim_arg": dim_arg, "keep_dim": keep_dim}

        return self.op_proxy(
            torch.max,
            ttir.MaxOp,
            [in0],
            golden_kwargs=golden_kwargs,
            ttir_kwargs=ttir_kwargs,
            organize_ttir_args=lambda i, o, _: (self._get_type(o), i[0], o),
        )

    def min(
        self, in0: Operand, dim_arg: List[int] = [0], keep_dim: bool = True
    ) -> OpView:

        golden_kwargs = {"dim": dim_arg, "keepdim": keep_dim}
        ttir_kwargs = {"dim_arg": dim_arg, "keep_dim": keep_dim}

        return self.op_proxy(
            torch.min,
            ttir.MinOp,
            [in0],
            golden_kwargs=golden_kwargs,
            ttir_kwargs=ttir_kwargs,
            organize_ttir_args=lambda i, o, _: (self._get_type(o), i[0], o),
        )

    def leaky_relu(self, in0: Operand, parameter: float = 0.01) -> OpView:
        # TODO: reconcile this naming mismatch
        ttir_kwargs = {"parameter": parameter}
30 changes: 30 additions & 0 deletions test/python/golden/test_ttir_ops.py
@@ -402,6 +402,16 @@ def test_matmul(in0: Operand, in1: Operand, builder: TTIRBuilder):
    return builder.matmul(in0, in1)


@compile_to_flatbuffer(
    [
        (64, 64),
    ],
    targets=["ttnn"],
)
def test_sum(in0: Operand, builder: TTIRBuilder):
    return builder.sum(in0)


@compile_to_flatbuffer(
    [
        (128, 128),
@@ -412,6 +422,26 @@ def test_mean(in0: Operand, builder: TTIRBuilder):
    return builder.mean(in0)


@compile_to_flatbuffer(
    [
        (64, 64),
    ],
    targets=["ttnn"],
)
def test_max(in0: Operand, builder: TTIRBuilder):
    return builder.max(in0)


@compile_to_flatbuffer(
    [
        (64, 64),
    ],
    targets=["ttnn"],
)
def test_min(in0: Operand, builder: TTIRBuilder):
    return builder.min(in0)


@compile_to_flatbuffer(
    [
        (32, 64),