Convert aten.cumsum to ttnn.moreh_cumsum (#370)
* Add torch.ops.aten.cumsum.default lowering to ttnn.moreh_cumsum

* Workaround for ttnn.moreh_cumsum

---------

Co-authored-by: Po-Sheng Chang <[email protected]>
Co-authored-by: Artem Yerofieiev <[email protected]>
3 people authored Nov 11, 2024
1 parent 4f5158b commit 55a6521
Showing 3 changed files with 83 additions and 6 deletions.
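For context, aten.cumsum computes a running sum along a single dimension, with negative dims counted from the end; the lowering in this commit must preserve those semantics. A minimal eager-mode illustration in plain PyTorch, independent of this repository:

import torch

# Running sum along the last dim (dim=-1) and along the first dim (dim=0).
x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(torch.ops.aten.cumsum.default(x, dim=-1))
# tensor([[ 1.,  3.,  6.],
#         [ 4.,  9., 15.]])
print(torch.ops.aten.cumsum.default(x, dim=0))
# tensor([[1., 2., 3.],
#         [5., 7., 9.]])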
57 changes: 57 additions & 0 deletions tests/lowering/misc/test_cumsum.py
@@ -0,0 +1,57 @@
import torch
import torch_ttnn
import pytest

from tests.utils import assert_with_pcc


class CumsumModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, dim):
        return torch.ops.aten.cumsum.default(input, dim=dim)


@pytest.mark.parametrize(
    "input_shapes, dim",
    [
        ((1, 32), -1),
        ((1, 45), -1),
        ((1, 59), 1),
        ((1, 5), -1),
        ((1, 60), 1),
        ((1, 10), 1),
        ((4, 32, 32), 0),
        ((1, 4, 32, 32), 1),
        ((4, 4, 32, 32), 0),
        ((1, 23, 40), 1),
        ((4, 32), 0),
        pytest.param(
            (1, 1, 32, 32),
            3,
            marks=pytest.mark.xfail(reason="inner-most 2 dims are not supported (#367)"),
        ),
        pytest.param(
            (1, 23, 40),
            2,
            marks=pytest.mark.xfail(reason="inner-most 2 dims are not supported (#367)"),
        ),
    ],
)
def test_cumsum(device, input_shapes, dim):
    m = CumsumModule()
    inputs = torch.rand(input_shapes, dtype=torch.bfloat16)
    result_before = m.forward(inputs, dim)

    option = torch_ttnn.TorchTtnnOption(device=device, gen_graphviz=False)
    # The compilation is lazy, so we need to run forward once to trigger it
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)

    result_after = m.forward(inputs, dim)
    option._out_fx_graphs[0].print_tabular()

    # Check that the graph has been rewritten and contains ttnn ops
    nodes = [node.target for node in option._out_fx_graphs[0].nodes]
    assert nodes.count(torch.ops.aten.cumsum.default) == 0
    assert_with_pcc(result_before, result_after, pcc=0.99)
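The test verifies numerical agreement with assert_with_pcc from tests.utils, whose implementation is outside this diff. A minimal sketch of such a PCC (Pearson correlation coefficient) check; the signature and epsilon handling here are assumptions, not repo code:

import torch

def assert_with_pcc(expected, actual, pcc=0.99):
    # Flatten both tensors and compute the Pearson correlation coefficient;
    # values near 1.0 mean the two outputs are numerically close.
    e = expected.flatten().to(torch.float32)
    a = actual.flatten().to(torch.float32)
    ec, ac = e - e.mean(), a - a.mean()
    computed = (ec @ ac) / (ec.norm() * ac.norm() + 1e-12)
    assert computed >= pcc, f"PCC {computed:.4f} below threshold {pcc}"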
9 changes: 3 additions & 6 deletions torch_ttnn/passes/lowering/add_data_move_pass.py
@@ -141,11 +141,6 @@ def is_function_call(node) -> bool:
 )
 
 
-def can_be_tilized(node):
-    size = node.meta["val"].size()
-    return len(size) >= 2 and size[-1] % 32 == 0 and size[-2] % 32 == 0
-
-
 # For operations limitations
 # See https://github.com/tenstorrent-metal/tt-metal/blob/main/ttnn/README.md?plain=1#L19
 def is_tt_compute(node) -> bool:
@@ -178,6 +173,7 @@ def is_tt_compute(node) -> bool:
             ttnn.squeeze,
             ttnn.full,
             ttnn.as_tensor,
+            ttnn.moreh_cumsum,
         ]
     )

@@ -335,7 +331,8 @@ def try_add_layout_change_before_node(src_node, dst_idx, dst_node, device) -> to
     need_from_device = False
     need_to_layout = False
     need_to_device = False
-    if dst_node.target in TTNN_LAYOUT_CHANGE_OPS and dst_idx == 0 and is_tt(src_node) and not can_be_tilized(dst_node):
+    # TODO(#372): #322 will enable tile layout for more layout change ops
+    if dst_node.target in TTNN_LAYOUT_CHANGE_OPS and dst_idx == 0 and is_tt(src_node):
         need_from_device = True
         need_to_layout = True

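The deleted can_be_tilized helper encoded ttnn's tile-layout constraint: data is packed into 32x32 tiles, so only tensors whose two inner-most dimensions are multiples of 32 tilize cleanly. With the helper gone, inputs to layout-change ops conservatively fall back to row-major on host, which is the workaround the TODO above refers to. The deleted check can be exercised standalone; a small sketch that takes a plain shape tuple instead of an fx node:

def can_be_tilized(size) -> bool:
    # ttnn's tile layout uses 32x32 tiles, so both inner-most dims
    # must be multiples of 32 to tilize without padding.
    return len(size) >= 2 and size[-1] % 32 == 0 and size[-2] % 32 == 0

print(can_be_tilized((4, 32, 32)))  # True: both inner dims are tile-aligned
print(can_be_tilized((1, 23, 40)))  # False: 23 and 40 are not multiples of 32
print(can_be_tilized((32,)))        # False: rank < 2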
23 changes: 23 additions & 0 deletions torch_ttnn/passes/lowering/to_tt_pass.py
@@ -801,6 +801,29 @@ def rewrite_node(node):
             else:
                 return None
 
+        if node.target == torch.ops.aten.cumsum.default:
+            tensor, dim = args
+            input_shape = tensor.meta["val"].size()
+            rank = len(input_shape)
+            if rank > 4:
+                return None
+            dim = (dim + rank) % rank
+            # Unsqueeze input tensor to 4D for cumsum
+            # TODO(#367): Special-case dims in the inner-most 2 dims: unsqueeze (x, y) to (x, y, 1, 1), as cumsum currently only supports the N and C dims
+            if (dim - rank) >= -2:
+                if rank <= 2:
+                    input_4d_shape = (1,) * (2 - rank) + (*input_shape, 1, 1)
+                elif rank == 3 and dim == 1:
+                    input_4d_shape = (*input_shape, 1)
+                else:
+                    return None
+            else:
+                input_4d_shape = (1,) * (4 - rank) + input_shape
+                dim += 4 - rank
+            input_4d = g.call_function(ttnn.reshape, (tensor, input_4d_shape))
+            output_4d = g.call_function(ttnn.moreh_cumsum, (input_4d, dim), kwargs)
+            return g.call_function(ttnn.reshape, (output_4d, input_shape))
+
     with g.inserting_before(node):
         new_node = rewrite_node(node)
         if new_node is not None:
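The shape arithmetic in the new cumsum branch can be checked in isolation. Below is a hypothetical helper, not part of the diff, that mirrors the branch on plain tuples and returns the 4D shape and adjusted dim that ttnn.moreh_cumsum would see, or None where the lowering bails out:

def cumsum_4d_shape(input_shape, dim):
    # Mirrors the lowering above: returns (shape_4d, dim_4d), or None if unsupported.
    rank = len(input_shape)
    if rank > 4:
        return None
    dim = (dim + rank) % rank  # normalize negative dims
    if (dim - rank) >= -2:
        # dim falls in the inner-most 2 dims, which moreh_cumsum cannot scan;
        # append trailing 1s so the scan dim becomes N or C instead.
        if rank <= 2:
            return (1,) * (2 - rank) + (*input_shape, 1, 1), dim
        elif rank == 3 and dim == 1:
            return (*input_shape, 1), dim
        return None
    # dim is already an outer dim; left-pad with 1s up to rank 4.
    return (1,) * (4 - rank) + tuple(input_shape), dim + (4 - rank)

print(cumsum_4d_shape((1, 32), -1))    # ((1, 32, 1, 1), 1)
print(cumsum_4d_shape((4, 32, 32), 0)) # ((1, 4, 32, 32), 1)
print(cumsum_4d_shape((1, 23, 40), 2)) # None: inner-most 2 dims unsupported (#367)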
