Skip to content

Commit

Permalink
Add conversion for aten.t.default (#23)
Browse files Browse the repository at this point in the history
* Add conversion for aten.t.default op along with tests

* Black code format

* Put aten.t tests under a specific dir
  • Loading branch information
mcw-zwakeelTT authored Jul 9, 2024
1 parent 3a5bfb0 commit 221c7b7
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 0 deletions.
122 changes: 122 additions & 0 deletions tests/lowering/tensor_manipulation/test_aten_t.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import torch
import torch_ttnn
import unittest
import ttnn
import tt_lib

from torch_ttnn.utils import check_with_pcc


class AtenTModule(torch.nn.Module):
    """Minimal module that applies torch.t to a 2D input."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # torch.t swaps the two axes of a 2D tensor.
        return torch.t(x)

    def input_shapes(self):
        # Single 2D test shape: a 1 x 32 row vector.
        return [(1, 32)]


class AtenT0DModule(torch.nn.Module):
    """Minimal module that applies torch.t to a 0D (scalar) input."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # torch.t is a no-op for tensors of rank < 2.
        return torch.t(x)

    def input_shapes(self):
        # An empty size makes torch.rand(...) produce a 0D (scalar) tensor.
        # The previous value (the int 5) made torch.rand build a 1D tensor
        # of 5 elements, so this "0D" module never actually exercised a
        # 0D input.
        return []


class AtenT1DModule(torch.nn.Module):
    """Minimal module that applies torch.t to a 1D input."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # torch.t is a no-op for tensors of rank < 2.
        return torch.t(x)

    def input_shapes(self):
        # (5,) — note the trailing comma: the original "(5)" is just the
        # int 5, not a tuple. torch.rand accepts either form, so the
        # calling test sees the same (5,) tensor shape.
        return [(5,)]


class TestModules(unittest.TestCase):
    """Lowering tests for aten.t.default on 0D, 1D and 2D inputs.

    For 2D inputs the op should lower to ttnn.permute; for rank < 2 the
    aten op is expected to remain in the graph untouched.
    """

    def setUp(self):
        # Open device 0
        self.device: ttnn.Device = ttnn.open_device(device_id=0)
        # For AutoFormat
        tt_lib.device.SetDefaultDevice(self.device)

    def tearDown(self):
        # Close the device
        ttnn.close_device(self.device)

    def _compile_and_run(self, m, input):
        """Compile *m* with the ttnn backend, run it once, and return
        ``(result_before, result_after, nodes)`` where *nodes* is the node
        list of the rewritten fx graph."""
        result_before = m.forward(input)
        option = torch_ttnn.TorchTtnnOption(device=self.device)
        option.gen_graphviz = True
        # The compilation is lazy, so we need to run forward once to
        # trigger the compilation.
        m = torch.compile(m, backend=torch_ttnn.backend, options=option)
        result_after = m.forward(input)
        option._out_fx_graphs[0].print_tabular()
        nodes = list(option._out_fx_graphs[0].nodes)
        return result_before, result_after, nodes

    def test_aten_t(self):
        m = AtenTModule()
        input_shapes = m.input_shapes()
        input = torch.rand(input_shapes[0], dtype=torch.bfloat16)
        result_before, result_after, nodes = self._compile_and_run(m, input)
        # Check the graph has been rewritten and contains ttnn ops
        self.assertTrue(nodes[4].target == ttnn.permute)
        self.assertTrue(nodes[4].args[0].target == ttnn.to_device)
        self.assertTrue(nodes[4].args[0].args[0].target == ttnn.to_layout)
        self.assertTrue(nodes[4].args[0].args[0].args[0].target == ttnn.from_torch)
        self.assertTrue(nodes[5].target == ttnn.from_device)
        self.assertTrue(nodes[6].target == ttnn.to_layout)
        self.assertTrue(nodes[7].target == ttnn.to_torch)
        # Check inference result
        self.assertTrue(check_with_pcc(result_before, result_after))

    def test_aten_t_0d(self):
        m = AtenT0DModule()
        input_shapes = m.input_shapes()
        input = torch.rand(input_shapes, dtype=torch.bfloat16)
        result_before, result_after, nodes = self._compile_and_run(m, input)
        # For rank < 2 the aten op is left in place — no ttnn conversion.
        self.assertTrue(nodes[1].target == torch.ops.aten.t.default)
        self.assertTrue(nodes[1].args[0].target == "arg0_1")
        self.assertTrue(nodes[1].args[0].op == "placeholder")
        self.assertTrue(nodes[2].target == "output")
        self.assertTrue(nodes[2].op == "output")
        self.assertTrue(check_with_pcc(result_before, result_after))

    def test_aten_t_1d(self):
        m = AtenT1DModule()
        input_shapes = m.input_shapes()
        input = torch.rand(input_shapes[0], dtype=torch.bfloat16)
        result_before, result_after, nodes = self._compile_and_run(m, input)
        # For rank < 2 the aten op is left in place — no ttnn conversion.
        self.assertTrue(nodes[1].target == torch.ops.aten.t.default)
        self.assertTrue(nodes[1].args[0].target == "arg0_1")
        self.assertTrue(nodes[1].args[0].op == "placeholder")
        self.assertTrue(nodes[2].target == "output")
        self.assertTrue(nodes[2].op == "output")
        self.assertTrue(check_with_pcc(result_before, result_after))


if __name__ == "__main__":
unittest.main()
13 changes: 13 additions & 0 deletions torch_ttnn/passes/lowering/to_tt_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,19 @@ def ReplaceMoreTtManually(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
new_node,
delete_user_cb=lambda node: node != new_node,
)
if node.target == torch.ops.aten.t.default:
permutation = list()
rank = len(node.meta["val"].size())
assert rank >= 0 and rank <= 2, "Input tensor can only be 0D, 1D or 2D"
if rank == 2:
permutation = [1, 0]
new_node = g.call_function(
ttnn.permute, args=(args[0], permutation)
)
node.replace_all_uses_with(
new_node,
delete_user_cb=lambda node: node != new_node,
)

gm = GraphCleanup(gm)
return gm
Expand Down

0 comments on commit 221c7b7

Please sign in to comment.