Implement conversion for aten.expand by using ttnn.repeat #21

Merged · 2 commits · Jul 3, 2024
Changes from 1 commit
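For context: this PR lowers aten.expand to ttnn.repeat, passing the per-dimension repetition counts as a ttnn.Shape (see the graph assertions in the tests below). The conversion pass itself is not part of this commit's diff, so the snippet below is only a standalone sketch of the underlying arithmetic in plain PyTorch; compute_repeat_multipliers is a hypothetical name, not a helper from this repository.

# Standalone sketch only: shows why a (1, 4) -> (4, 4) expand can be expressed
# as a repeat with multipliers (4, 1). compute_repeat_multipliers is a
# hypothetical helper name, not code from this repository.
import torch


def compute_repeat_multipliers(input_shape, target_shape):
    # expand only broadcasts size-1 dimensions, so the multiplier is the target
    # size where the input size is 1, and 1 where the sizes already match.
    assert len(input_shape) == len(target_shape)
    return [t if s == 1 else 1 for s, t in zip(input_shape, target_shape)]


x = torch.rand(1, 4, dtype=torch.bfloat16)
multipliers = compute_repeat_multipliers(x.shape, (4, 4))  # -> [4, 1]
assert torch.equal(x.expand(4, 4), x.repeat(*multipliers))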
183 changes: 183 additions & 0 deletions tests/test_expand.py
@@ -0,0 +1,183 @@
import torch
import torch_ttnn
import unittest
from torch_ttnn import ttnn
import tt_lib
from torch_ttnn.utils import (
    DummyTtnnRowMajorLayout,
    DummyTtnnTileLayout,
)


class ExpandModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, new_shape):
        return x.expand(new_shape)

    def input_shapes(self):
        return [(1, 4), (4, 4)]


class ExpandAfterOpModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, new_shape):
        a = torch.clone(x)
        return a.expand(new_shape)

    def input_shapes(self):
        return [(1, 4), (4, 4)]


class ExpandBeforeOpModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, new_shape):
        ex = x.expand(new_shape)
        return torch.add(ex, ex)

    def input_shapes(self):
        return [(1, 4), (4, 4)]


class ExpandBetweenOpsModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, new_shape):
        a = torch.clone(x)
        ex = a.expand(new_shape)
        return torch.add(ex, ex)

    def input_shapes(self):
        return [(1, 4), (4, 4)]


class TestModules(unittest.TestCase):
    def setUp(self):
        # Open device 0
        self.device: ttnn.Device = ttnn.open_device(device_id=0)
        # For AutoFormat
        tt_lib.device.SetDefaultDevice(self.device)

    def tearDown(self):
        # Close the device
        ttnn.close_device(self.device)

    def test_expand(self):
        m = ExpandModule()
        input_shapes = m.input_shapes()
        tensor = torch.rand(input_shapes[0], dtype=torch.bfloat16)
        new_shape = input_shapes[1]
        inputs = [tensor, new_shape]
        result_before = m.forward(*inputs)
        option = torch_ttnn.TorchTtnnOption(device=self.device)
        option.gen_graphviz = True
        # The compilation is lazy, so we need to run forward once to trigger the compilation
        m = torch.compile(m, backend=torch_ttnn.backend, options=option)
        result_after = m.forward(*inputs)
        option._out_fx_graphs[0].print_tabular()

        # Check that the graph has been rewritten and contains ttnn ops
        nodes = list(option._out_fx_graphs[0].nodes)
        self.assertTrue(nodes[4].target == ttnn.repeat)
        self.assertTrue(nodes[4].args[1].target == ttnn.Shape)
        self.assertTrue(nodes[5].target == ttnn.from_device)
        self.assertTrue(nodes[6].target == ttnn.to_layout)
        self.assertTrue(nodes[7].target == ttnn.to_torch)
        # Check inference result
        self.assertTrue(torch.allclose(result_before, result_after, rtol=0.2))

    def test_expand_after_op(self):
        m = ExpandAfterOpModule()
        input_shapes = m.input_shapes()
        tensor = torch.rand(input_shapes[0], dtype=torch.bfloat16)
        new_shape = input_shapes[1]
        inputs = [tensor, new_shape]
        result_before = m.forward(*inputs)
        option = torch_ttnn.TorchTtnnOption(device=self.device)
        option.gen_graphviz = True
        # The compilation is lazy, so we need to run forward once to trigger the compilation
        m = torch.compile(m, backend=torch_ttnn.backend, options=option)
        result_after = m.forward(*inputs)
        option._out_fx_graphs[0].print_tabular()

        # Check that the graph has been rewritten and contains ttnn ops
        nodes = list(option._out_fx_graphs[0].nodes)
        self.assertTrue(nodes[8].target == ttnn.repeat)
        self.assertTrue(nodes[8].args[0].target == ttnn.to_layout)
        self.assertTrue(nodes[8].args[0].args[0].target == ttnn.clone)
        self.assertTrue(
            type(nodes[8].args[0].args[1]) is type(DummyTtnnRowMajorLayout())
        )
        self.assertTrue(nodes[8].args[1].target == ttnn.Shape)
        self.assertTrue(nodes[9].target == ttnn.from_device)
        self.assertTrue(nodes[10].target == ttnn.to_layout)
        self.assertTrue(nodes[11].target == ttnn.to_torch)
        # Check inference result
        self.assertTrue(torch.allclose(result_before, result_after, rtol=0.2))

    def test_expand_before_op(self):
        m = ExpandBeforeOpModule()
        input_shapes = m.input_shapes()
        tensor = torch.rand(input_shapes[0], dtype=torch.bfloat16)
        new_shape = input_shapes[1]
        inputs = [tensor, new_shape]
        result_before = m.forward(*inputs)
        option = torch_ttnn.TorchTtnnOption(device=self.device)
        option.gen_graphviz = True
        # The compilation is lazy, so we need to run forward once to trigger the compilation
        m = torch.compile(m, backend=torch_ttnn.backend, options=option)
        result_after = m.forward(*inputs)
        option._out_fx_graphs[0].print_tabular()

        # Check that the graph has been rewritten and contains ttnn ops
        nodes = list(option._out_fx_graphs[0].nodes)
        self.assertTrue(nodes[4].target == ttnn.repeat)
        self.assertTrue(nodes[4].args[1].target == ttnn.Shape)
        self.assertTrue(nodes[5].target == ttnn.to_layout)
        self.assertTrue(nodes[5].args[0].target == ttnn.repeat)
        self.assertTrue(type(nodes[5].args[1]) is type(DummyTtnnTileLayout()))
        self.assertTrue(nodes[6].target == ttnn.add)
        self.assertTrue(nodes[7].target == ttnn.from_device)
        self.assertTrue(nodes[8].target == ttnn.to_layout)
        self.assertTrue(nodes[9].target == ttnn.to_torch)
        # Check inference result
        self.assertTrue(torch.allclose(result_before, result_after, rtol=0.2))

    def test_expand_between_ops(self):
        m = ExpandBetweenOpsModule()
        input_shapes = m.input_shapes()
        tensor = torch.rand(input_shapes[0], dtype=torch.bfloat16)
        new_shape = input_shapes[1]
        inputs = [tensor, new_shape]
        result_before = m.forward(*inputs)
        option = torch_ttnn.TorchTtnnOption(device=self.device)
        option.gen_graphviz = True
        # The compilation is lazy, so we need to run forward once to trigger the compilation
        m = torch.compile(m, backend=torch_ttnn.backend, options=option)
        result_after = m.forward(*inputs)
        option._out_fx_graphs[0].print_tabular()

        # Check that the graph has been rewritten and contains ttnn ops
        nodes = list(option._out_fx_graphs[0].nodes)
        self.assertTrue(nodes[8].target == ttnn.repeat)
        self.assertTrue(nodes[8].args[0].target == ttnn.to_layout)
        self.assertTrue(nodes[8].args[0].args[0].target == ttnn.clone)
        self.assertTrue(
            type(nodes[8].args[0].args[1]) is type(DummyTtnnRowMajorLayout())
        )
        self.assertTrue(nodes[8].args[1].target == ttnn.Shape)
        self.assertTrue(nodes[9].target == ttnn.to_layout)
        self.assertTrue(nodes[9].args[0].target == ttnn.repeat)
        self.assertTrue(type(nodes[9].args[1]) is type(DummyTtnnTileLayout()))
        self.assertTrue(nodes[10].target == ttnn.add)
        self.assertTrue(nodes[11].target == ttnn.from_device)
        self.assertTrue(nodes[12].target == ttnn.to_layout)
        self.assertTrue(nodes[13].target == ttnn.to_torch)
        # Check inference result
        self.assertTrue(torch.allclose(result_before, result_after, rtol=0.2))
51 changes: 51 additions & 0 deletions torch_ttnn/passes/add_data_move_pass.py
@@ -67,6 +67,7 @@ def is_tt_compute(node) -> bool:
            ttnn.cos,
            ttnn.sigmoid,
            ttnn.as_tensor,
            ttnn.repeat,
        ]
    )

@@ -81,6 +82,7 @@ def is_tt_data_move(node) -> bool:
        ttnn.to_torch,
        ttnn.to_layout,
        ttnn.MemoryConfig,
        ttnn.Shape,
    ]


@@ -173,6 +175,7 @@ def try_add_data_move_in(src_node, dst_idx, dst_node, device) -> torch.fx.node.N
            dst_node.target != ttnn.reshape
            and dst_node.target != ttnn.embedding
            and dst_node.target != ttnn.zeros_like
            and dst_node.target != ttnn.repeat
        ):
            new_nodes.append(
                g.call_function(ttnn.to_layout, (new_nodes[-1], DummyTtnnTileLayout()))
@@ -187,6 +190,48 @@ def try_add_data_move_in(src_node, dst_idx, dst_node, device) -> torch.fx.node.N
    return new_nodes[-1]


def try_add_layout_change_before_repeat(
    src_node, dst_idx, dst_node
) -> torch.fx.node.Node:
    # Here dst_node is ttnn.repeat and src_node is any tt node that ttnn.repeat uses
    if isinstance(src_node, (int, float, list, tuple)) or not isinstance(
        src_node, torch.fx.node.Node
    ):
        return None
    if not is_function_call(dst_node):
        return None
    if dst_node.target != ttnn.repeat or dst_idx != 0 or not is_tt(src_node):
        return None

    g = dst_node.graph
    with g.inserting_before(dst_node):
        to_layout = g.call_function(
            ttnn.to_layout, (src_node, DummyTtnnRowMajorLayout())
        )

    insert_node_between(src_node, dst_idx, dst_node, [to_layout])

    return to_layout


def try_add_layout_change_after_repeat(
    src_node, dst_idx, dst_node
) -> torch.fx.node.Node:
    # Here src_node is ttnn.repeat and dst_node is any tt_compute node that uses ttnn.repeat
    if not is_function_call(src_node):
        return None
    if src_node.target != ttnn.repeat or not is_tt_compute(dst_node):
        return None

    g = dst_node.graph
    with g.inserting_before(dst_node):
        to_layout = g.call_function(ttnn.to_layout, (dst_node, DummyTtnnTileLayout()))

    insert_node_between(src_node, dst_idx, dst_node, [to_layout])

    return to_layout


def try_add_data_move_out(src_node, dst_idx, dst_node) -> torch.fx.node.Node:
    if not should_add_data_move_out(src_node, dst_node):
        return None
@@ -278,6 +323,12 @@ def call(self, gm: torch.fx.GraphModule):
                elif to_device := try_add_data_move_in(arg, idx, node, device):
                    data_move_in_hash[arg] = to_device
                    i += 1
                elif to_layout := try_add_layout_change_before_repeat(arg, idx, node):
                    data_move_in_hash[arg] = to_layout
                    i += 1
                elif to_layout := try_add_layout_change_after_repeat(arg, idx, node):
                    data_move_in_hash[arg] = to_layout
                    i += 1

                if arg in data_move_out_hash and node.op == "output":
                    old_arg = node.args[0]
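The two helpers added above follow the same FX-graph surgery pattern used throughout this pass: create a node inside graph.inserting_before(...) and rewire the consumer onto it (here via insert_node_between, defined elsewhere in this file). The snippet below is a self-contained illustration of that pattern using only stock torch.fx; to_row_major is a placeholder stand-in for a layout change, not a ttnn call, and this is not the project's actual helper.

import torch
import torch.fx


def to_row_major(t: torch.Tensor) -> torch.Tensor:
    # Placeholder for a layout change such as ttnn.to_layout(..., ROW_MAJOR);
    # functionally a no-op here.
    return t.contiguous()


class Tiny(torch.nn.Module):
    def forward(self, x):
        a = torch.clone(x)
        return a.expand(4, 4)


gm = torch.fx.symbolic_trace(Tiny())
graph = gm.graph

# Locate the expand call and the node that feeds its first argument.
expand_node = next(
    n for n in graph.nodes if n.op == "call_method" and n.target == "expand"
)
src_node = expand_node.args[0]

# Insert a layout-change node between src_node and expand_node, mirroring
# what try_add_layout_change_before_repeat does for ttnn.repeat.
with graph.inserting_before(expand_node):
    layout_node = graph.call_function(to_row_major, (src_node,))
expand_node.replace_input_with(src_node, layout_node)

graph.lint()
gm.recompile()
gm.graph.print_tabular()  # clone -> to_row_major -> expand
print(gm(torch.rand(1, 4)).shape)  # torch.Size([4, 4])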