
Commit b88b88d

Separate tests into individual files and switch to Pytest (#49)
1 parent ffd89d1 commit b88b88d


53 files changed: +2542 lines added, -2808 lines removed
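
Since the point of this change is the switch to pytest, the relocated files can be collected and run with a standard pytest invocation. Below is a minimal runner sketch; the "tests" path and the flags are assumptions based on the file names in this commit, not part of the diff.

# Hypothetical runner sketch, not part of this commit: collect and run the
# relocated test files with pytest. The "tests" path is assumed from the
# file names shown in the diff below.
import sys
import pytest

if __name__ == "__main__":
    # -v reports each parametrized case; -rx summarizes expected failures,
    # such as the xfail-marked arange test in this commit.
    sys.exit(pytest.main(["tests", "-v", "-rx"]))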

tests/conftest.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
import pytest
import ttnn
import torch


@pytest.fixture(scope="session")
def device():
    device = ttnn.open_device(device_id=0)
    yield device
    ttnn.close_device(device)


@pytest.fixture(autouse=True)
def reset_torch_dynamo():
    # PyTorch caches models. Start a fresh compile for each parameter of the test case.
    torch._dynamo.reset()
    yield
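
For reference, a minimal sketch of how the tests in this commit consume these fixtures; the test function below is hypothetical and only illustrates the pattern used by the real tests that follow.

def test_uses_device_fixture(device):
    # Hypothetical example, not in this commit: pytest injects the session-scoped
    # `device` fixture from conftest.py by matching the parameter name, and the
    # autouse reset_torch_dynamo fixture resets torch._dynamo around every test.
    assert device is not None
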
Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
import torch
import torch_ttnn
import pytest
import ttnn


class ArangeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, end):
        # start = 0, step = 1
        return torch.arange(end)


class ArangeStartModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, start, end):
        # step = 1
        return torch.arange(start, end)


class ArangeStartStepModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, start, end, step):
        return torch.arange(start, end, step)


# NOTE(kevinwuTT) This test fails because ttnn.arange does not support a start value of 0.
@pytest.mark.xfail
@pytest.mark.parametrize(
    "input_shapes",
    [[100]],
)
def test_arange(device, input_shapes):
    m = ArangeModule()
    result_before = m.forward(*input_shapes).to(torch.bfloat16)
    option = torch_ttnn.TorchTtnnOption(device=device)
    option.gen_graphviz = True
    # The compilation is lazy, so we need to run forward once to trigger the compilation
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
    result_after = m.forward(*input_shapes)
    option._out_fx_graphs[0].print_tabular()

    # Check that the graph has been rewritten and contains ttnn ops
    nodes = list(option._out_fx_graphs[0].nodes)
    assert [node.target for node in nodes].count(ttnn.arange) == 1
    # Check inference result
    assert torch.allclose(result_before, result_after)


@pytest.mark.parametrize(
    "input_shapes",
    [[2, 100]],
)
def test_arange_start(device, input_shapes):
    m = ArangeStartModule()
    result_before = m.forward(*input_shapes).to(torch.bfloat16)
    option = torch_ttnn.TorchTtnnOption(device=device)
    option.gen_graphviz = True
    # The compilation is lazy, so we need to run forward once to trigger the compilation
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
    result_after = m.forward(*input_shapes)
    option._out_fx_graphs[0].print_tabular()

    # Check that the graph has been rewritten and contains ttnn ops
    nodes = list(option._out_fx_graphs[0].nodes)
    assert [node.target for node in nodes].count(ttnn.arange) == 1
    # Check inference result
    assert torch.allclose(result_before, result_after)


@pytest.mark.parametrize(
    "input_shapes",
    [[4, 100, 3]],
)
def test_arange_start_step(device, input_shapes):
    m = ArangeStartStepModule()
    result_before = m.forward(*input_shapes).to(torch.bfloat16)
    option = torch_ttnn.TorchTtnnOption(device=device)
    option.gen_graphviz = True
    # The compilation is lazy, so we need to run forward once to trigger the compilation
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
    result_after = m.forward(*input_shapes)
    option._out_fx_graphs[0].print_tabular()

    # Check that the graph has been rewritten and contains ttnn ops
    nodes = list(option._out_fx_graphs[0].nodes)
    assert [node.target for node in nodes].count(ttnn.arange) == 1
    # Check inference result
    assert torch.allclose(result_before, result_after)

tests/lowering/creation/test_clone.py

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
import torch
import torch_ttnn
import pytest
import ttnn

from tests.utils import check_with_pcc


class CloneFromNodeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        a = input + input
        return torch.clone(a)


class CloneFromArgModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return torch.clone(input)


@pytest.mark.parametrize(
    "input_shapes",
    [[(4, 4)]],
)
def test_clone_from_arg(device, input_shapes):
    m = CloneFromArgModule()
    inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes]
    result_before = m.forward(*inputs)
    option = torch_ttnn.TorchTtnnOption(device=device)
    option.gen_graphviz = False
    # The compilation is lazy, so we need to run forward once to trigger the compilation
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
    result_after = m.forward(*inputs)
    option._out_fx_graphs[0].print_tabular()

    # Check that the graph has been rewritten and contains ttnn ops
    nodes = list(option._out_fx_graphs[0].nodes)
    assert [node.target for node in nodes].count(ttnn.clone) == 1
    # Check inference result
    assert torch.allclose(result_before, result_after)


@pytest.mark.parametrize(
    "input_shapes",
    [[(4, 4)]],
)
def test_clone_from_node(device, input_shapes):
    m = CloneFromNodeModule()
    inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes]
    result_before = m.forward(*inputs)
    option = torch_ttnn.TorchTtnnOption(device=device)
    option.gen_graphviz = False
    # The compilation is lazy, so we need to run forward once to trigger the compilation
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
    result_after = m.forward(*inputs)
    option._out_fx_graphs[0].print_tabular()

    # Check that the graph has been rewritten and contains ttnn ops
    nodes = list(option._out_fx_graphs[0].nodes)
    target = [node.target for node in nodes]
    assert target.count(ttnn.clone) == 1
    clone_arg_0 = nodes[target.index(ttnn.clone)].args[0].target
    assert isinstance(clone_arg_0, ttnn.decorators.FastOperation) or isinstance(
        clone_arg_0, ttnn.decorators.Operation
    )
    # Check inference result
    assert torch.allclose(result_before, result_after)

tests/lowering/creation/test_full.py

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
import torch
import torch_ttnn
import pytest
import ttnn


class FullModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, size, fill_value):
        return torch.full(size, fill_value)


@pytest.mark.parametrize(
    "input_shapes",
    [[(64, 128)]],
)
def test_full(device, input_shapes):
    m = FullModule()
    fill_value = 1.23
    result_before = m.forward(input_shapes[0], fill_value).to(torch.bfloat16)
    option = torch_ttnn.TorchTtnnOption(device=device)
    option.gen_graphviz = True
    # The compilation is lazy, so we need to run forward once to trigger the compilation
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
    result_after = m.forward(input_shapes[0], fill_value)
    option._out_fx_graphs[0].print_tabular()

    # Check that the graph has been rewritten and contains ttnn ops
    nodes = list(option._out_fx_graphs[0].nodes)
    assert [node.target for node in nodes].count(ttnn.full) == 1
    # Check inference result
    assert torch.allclose(result_before, result_after)

tests/lowering/creation/test_ones.py

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
import torch
import torch_ttnn
import pytest
import ttnn


class OnesModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, shape):
        return torch.ones(shape)


@pytest.mark.parametrize(
    "input_shapes",
    [(32, 32)],
)
def test_ones(device, input_shapes):
    m = OnesModule()
    result_before = m.forward(input_shapes)
    result_before = result_before.to(torch.bfloat16)
    option = torch_ttnn.TorchTtnnOption(device=device)
    option.gen_graphviz = True
    # The compilation is lazy, so we need to run forward once to trigger the compilation
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
    result_after = m.forward(input_shapes)
    option._out_fx_graphs[0].print_tabular()

    # Check that the graph has been rewritten and contains ttnn ops
    nodes = list(option._out_fx_graphs[0].nodes)
    assert [node.target for node in nodes].count(ttnn.ones) == 1
    # Check inference result
    assert torch.allclose(result_before, result_after)
Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
import torch
import torch_ttnn
import pytest
import ttnn


class ToCopyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.to(torch.bfloat16)


class ToCopyWithOpAfterModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        to = x.to(torch.bfloat16)
        return torch.add(to, to)


@pytest.mark.parametrize(
    "input_shapes",
    [[(4, 4)]],
)
def test_to_copy(device, input_shapes):
    m = ToCopyModule()
    inputs = [torch.rand(shape) for shape in input_shapes]
    result_before = m.forward(*inputs)
    option = torch_ttnn.TorchTtnnOption(device=device)
    option.gen_graphviz = True
    # The compilation is lazy, so we need to run forward once to trigger the compilation
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
    result_after = m.forward(*inputs)
    option._out_fx_graphs[0].print_tabular()

    # Check that the graph has been rewritten and contains ttnn ops
    nodes = list(option._out_fx_graphs[0].nodes)
    assert [node.target for node in nodes].count(ttnn.as_tensor) == 1
    # Check inference result
    assert torch.allclose(result_before, result_after, rtol=0.2)


@pytest.mark.parametrize(
    "input_shapes",
    [[(4, 4)]],
)
def test_to_copy_with_op_after(device, input_shapes):
    m = ToCopyWithOpAfterModule()
    inputs = [torch.rand(shape) for shape in input_shapes]
    result_before = m.forward(*inputs)
    option = torch_ttnn.TorchTtnnOption(device=device)
    option.gen_graphviz = True
    # The compilation is lazy, so we need to run forward once to trigger the compilation
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
    result_after = m.forward(*inputs)
    option._out_fx_graphs[0].print_tabular()

    # Check that the graph has been rewritten and contains ttnn ops
    nodes = list(option._out_fx_graphs[0].nodes)
    target = [node.target for node in nodes]
    assert target.count(ttnn.as_tensor) == 1
    assert target.count(ttnn.add) == 1
    add_node = nodes[target.index(ttnn.add)]
    assert add_node.args[0].target == ttnn.as_tensor
    assert add_node.args[1].target == ttnn.as_tensor
    # Check inference result
    assert torch.allclose(result_before, result_after, rtol=0.2)
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
import torch
import torch_ttnn
import pytest
import ttnn


class ZerosLikeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return torch.zeros_like(input)


@pytest.mark.parametrize(
    "input_shapes",
    [[(4, 4)]],
)
def test_zeros_like(device, input_shapes):
    m = ZerosLikeModule()
    inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes]
    result_before = m.forward(*inputs)
    option = torch_ttnn.TorchTtnnOption(device=device)
    option.gen_graphviz = True
    # The compilation is lazy, so we need to run forward once to trigger the compilation
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
    result_after = m.forward(*inputs)
    option._out_fx_graphs[0].print_tabular()

    # Check that the graph has been rewritten and contains ttnn ops
    nodes = list(option._out_fx_graphs[0].nodes)
    assert [node.target for node in nodes].count(ttnn.zeros_like) == 1
    # Check inference result
    assert torch.allclose(result_before, result_after)
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
import torch
import torch_ttnn
import pytest
import ttnn


class AddModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x + x


@pytest.mark.parametrize(
    "input_shapes",
    [
        [(4, 4)],
    ],
)
def test_add(device, input_shapes):
    m = AddModule()
    inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes]
    result_before = m.forward(*inputs)
    option = torch_ttnn.TorchTtnnOption(device=device)
    # The compilation is lazy, so we need to run forward once to trigger the compilation
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
    result_after = m.forward(*inputs)
    assert 1 == len(option._out_fx_graphs)
    option._out_fx_graphs[0].print_tabular()
    # Check that the graph has been rewritten and contains ttnn ops
    nodes = list(option._out_fx_graphs[0].nodes)
    assert [node.target for node in nodes].count(ttnn.add) == 1
    # Check inference result
    assert torch.allclose(result_before, result_after)
