-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Separate tests into individual files and switch to Pytest
- Loading branch information
Showing
53 changed files
with
2,544 additions
and
2,808 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import pytest | ||
import ttnn | ||
import torch | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def device(): | ||
device = ttnn.open_device(device_id=0) | ||
yield device | ||
ttnn.close_device(device) | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def reset_torch_dynamo(): | ||
# PyTorch caches models. Start a fresh compile for each parameter of the test case. | ||
torch._dynamo.reset() | ||
yield |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import torch | ||
import torch_ttnn | ||
import pytest | ||
import ttnn | ||
|
||
|
||
class ArangeModule(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, end): | ||
# start = 0, step = 1 | ||
return torch.arange(end) | ||
|
||
|
||
class ArangeStartModule(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, start, end): | ||
# step = 1 | ||
return torch.arange(start, end) | ||
|
||
|
||
class ArangeStartStepModule(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, start, end, step): | ||
return torch.arange(start, end, step) | ||
|
||
|
||
# NOTE(kevinwuTT) This test fails because ttnn.arange does not support start value of 0. | ||
@pytest.mark.xfail | ||
@pytest.mark.parametrize( | ||
"input_shapes", | ||
[[100]], | ||
) | ||
def test_arange(device, input_shapes): | ||
m = ArangeModule() | ||
result_before = m.forward(*input_shapes).to(torch.bfloat16) | ||
option = torch_ttnn.TorchTtnnOption(device=device) | ||
option.gen_graphviz = True | ||
# The compilation is lazy, so we need to run forward once to trigger the compilation | ||
m = torch.compile(m, backend=torch_ttnn.backend, options=option) | ||
result_after = m.forward(*input_shapes) | ||
option._out_fx_graphs[0].print_tabular() | ||
|
||
# Check the graph has be rewritten and contain ttnn ops | ||
nodes = list(option._out_fx_graphs[0].nodes) | ||
assert [node.target for node in nodes].count(ttnn.arange) == 1 | ||
# Check inference result | ||
assert torch.allclose(result_before, result_after) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"input_shapes", | ||
[[2, 100]], | ||
) | ||
def test_arange_start(device, input_shapes): | ||
m = ArangeStartModule() | ||
result_before = m.forward(*input_shapes).to(torch.bfloat16) | ||
option = torch_ttnn.TorchTtnnOption(device=device) | ||
option.gen_graphviz = True | ||
# The compilation is lazy, so we need to run forward once to trigger the compilation | ||
m = torch.compile(m, backend=torch_ttnn.backend, options=option) | ||
result_after = m.forward(*input_shapes) | ||
option._out_fx_graphs[0].print_tabular() | ||
|
||
# Check the graph has be rewritten and contain ttnn ops | ||
nodes = list(option._out_fx_graphs[0].nodes) | ||
assert [node.target for node in nodes].count(ttnn.arange) == 1 | ||
# Check inference result | ||
assert torch.allclose(result_before, result_after) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"input_shapes", | ||
[[4, 100, 3]], | ||
) | ||
def test_arange_start_step(device, input_shapes): | ||
m = ArangeStartStepModule() | ||
result_before = m.forward(*input_shapes).to(torch.bfloat16) | ||
option = torch_ttnn.TorchTtnnOption(device=device) | ||
option.gen_graphviz = True | ||
# The compilation is lazy, so we need to run forward once to trigger the compilation | ||
m = torch.compile(m, backend=torch_ttnn.backend, options=option) | ||
result_after = m.forward(*input_shapes) | ||
option._out_fx_graphs[0].print_tabular() | ||
|
||
# Check the graph has be rewritten and contain ttnn ops | ||
nodes = list(option._out_fx_graphs[0].nodes) | ||
assert [node.target for node in nodes].count(ttnn.arange) == 1 | ||
# Check inference result | ||
assert torch.allclose(result_before, result_after) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import torch | ||
import torch_ttnn | ||
import pytest | ||
import ttnn | ||
|
||
from tests.utils import check_with_pcc | ||
|
||
|
||
class CloneFromNodeModule(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, input): | ||
a = input + input | ||
return torch.clone(a) | ||
|
||
|
||
class CloneFromArgModule(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, input): | ||
return torch.clone(input) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"input_shapes", | ||
[[(4, 4)]], | ||
) | ||
def test_clone_from_arg(device, input_shapes): | ||
m = CloneFromArgModule() | ||
inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] | ||
result_before = m.forward(*inputs) | ||
option = torch_ttnn.TorchTtnnOption(device=device) | ||
option.gen_graphviz = False | ||
# The compilation is lazy, so we need to run forward once to trigger the compilation | ||
m = torch.compile(m, backend=torch_ttnn.backend, options=option) | ||
result_after = m.forward(*inputs) | ||
option._out_fx_graphs[0].print_tabular() | ||
|
||
# Check the graph has be rewritten and contain ttnn ops | ||
nodes = list(option._out_fx_graphs[0].nodes) | ||
assert [node.target for node in nodes].count(ttnn.clone) == 1 | ||
# Check inference result | ||
assert torch.allclose(result_before, result_after) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"input_shapes", | ||
[[(4, 4)]], | ||
) | ||
def test_clone_from_node(device, input_shapes): | ||
m = CloneFromNodeModule() | ||
inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] | ||
result_before = m.forward(*inputs) | ||
option = torch_ttnn.TorchTtnnOption(device=device) | ||
option.gen_graphviz = False | ||
# The compilation is lazy, so we need to run forward once to trigger the compilation | ||
m = torch.compile(m, backend=torch_ttnn.backend, options=option) | ||
result_after = m.forward(*inputs) | ||
option._out_fx_graphs[0].print_tabular() | ||
|
||
# Check the graph has be rewritten and contain ttnn ops | ||
nodes = list(option._out_fx_graphs[0].nodes) | ||
target = [node.target for node in nodes] | ||
assert target.count(ttnn.clone) == 1 | ||
clone_arg_0 = nodes[target.index(ttnn.clone)].args[0].target | ||
assert isinstance(clone_arg_0, ttnn.decorators.FastOperation) or isinstance( | ||
clone_arg_0, ttnn.decorators.Operation | ||
) | ||
# Check inference result | ||
assert torch.allclose(result_before, result_after) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import torch | ||
import torch_ttnn | ||
import pytest | ||
import ttnn | ||
|
||
|
||
class FullModule(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, size, fill_value): | ||
return torch.full(size, fill_value) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"input_shapes", | ||
[[(64, 128)]], | ||
) | ||
def test_full(device, input_shapes): | ||
m = FullModule() | ||
fill_value = 1.23 | ||
result_before = m.forward(input_shapes[0], fill_value).to(torch.bfloat16) | ||
option = torch_ttnn.TorchTtnnOption(device=device) | ||
option.gen_graphviz = True | ||
# The compilation is lazy, so we need to run forward once to trigger the compilation | ||
m = torch.compile(m, backend=torch_ttnn.backend, options=option) | ||
result_after = m.forward(input_shapes[0], fill_value) | ||
option._out_fx_graphs[0].print_tabular() | ||
|
||
# Check the graph has be rewritten and contain ttnn ops | ||
nodes = list(option._out_fx_graphs[0].nodes) | ||
assert [node.target for node in nodes].count(ttnn.full) == 1 | ||
# Check inference result | ||
assert torch.allclose(result_before, result_after) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import torch | ||
import torch_ttnn | ||
import pytest | ||
import ttnn | ||
|
||
|
||
class OnesModule(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, shape): | ||
return torch.ones(shape) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"input_shapes", | ||
[(32, 32)], | ||
) | ||
def test_ones(device, input_shapes): | ||
m = OnesModule() | ||
result_before = m.forward(input_shapes) | ||
result_before = result_before.to(torch.bfloat16) | ||
option = torch_ttnn.TorchTtnnOption(device=device) | ||
option.gen_graphviz = True | ||
# The compilation is lazy, so we need to run forward once to trigger the compilation | ||
m = torch.compile(m, backend=torch_ttnn.backend, options=option) | ||
result_after = m.forward(input_shapes) | ||
option._out_fx_graphs[0].print_tabular() | ||
|
||
# Check the graph has be rewritten and contain ttnn ops | ||
nodes = list(option._out_fx_graphs[0].nodes) | ||
assert [node.target for node in nodes].count(ttnn.ones) == 1 | ||
# Check inference result | ||
assert torch.allclose(result_before, result_after) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import torch | ||
import torch_ttnn | ||
import pytest | ||
import ttnn | ||
|
||
|
||
class ToCopyModule(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, x): | ||
return x.to(torch.bfloat16) | ||
|
||
|
||
class ToCopyWithOpAfterModule(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, x): | ||
to = x.to(torch.bfloat16) | ||
return torch.add(to, to) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"input_shapes", | ||
[[(4, 4)]], | ||
) | ||
def test_to_copy(device, input_shapes): | ||
m = ToCopyModule() | ||
inputs = [torch.rand(shape) for shape in input_shapes] | ||
result_before = m.forward(*inputs) | ||
option = torch_ttnn.TorchTtnnOption(device=device) | ||
option.gen_graphviz = True | ||
# The compilation is lazy, so we need to run forward once to trigger the compilation | ||
m = torch.compile(m, backend=torch_ttnn.backend, options=option) | ||
result_after = m.forward(*inputs) | ||
option._out_fx_graphs[0].print_tabular() | ||
|
||
# Check the graph has be rewritten and contain ttnn ops | ||
nodes = list(option._out_fx_graphs[0].nodes) | ||
assert [node.target for node in nodes].count(ttnn.as_tensor) == 1 | ||
# Check inference result | ||
assert torch.allclose(result_before, result_after, rtol=0.2) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"input_shapes", | ||
[[(4, 4)]], | ||
) | ||
def test_to_copy_with_op_after(device, input_shapes): | ||
m = ToCopyWithOpAfterModule() | ||
inputs = [torch.rand(shape) for shape in input_shapes] | ||
result_before = m.forward(*inputs) | ||
option = torch_ttnn.TorchTtnnOption(device=device) | ||
option.gen_graphviz = True | ||
# The compilation is lazy, so we need to run forward once to trigger the compilation | ||
m = torch.compile(m, backend=torch_ttnn.backend, options=option) | ||
result_after = m.forward(*inputs) | ||
option._out_fx_graphs[0].print_tabular() | ||
|
||
# Check the graph has be rewritten and contain ttnn ops | ||
nodes = list(option._out_fx_graphs[0].nodes) | ||
target = [node.target for node in nodes] | ||
assert target.count(ttnn.as_tensor) == 1 | ||
assert target.count(ttnn.add) == 1 | ||
add_node = nodes[target.index(ttnn.add)] | ||
assert add_node.args[0].target == ttnn.as_tensor | ||
assert add_node.args[1].target == ttnn.as_tensor | ||
# Check inference result | ||
assert torch.allclose(result_before, result_after, rtol=0.2) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import torch | ||
import torch_ttnn | ||
import pytest | ||
import ttnn | ||
|
||
|
||
class ZerosLikeModule(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, input): | ||
return torch.zeros_like(input) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"input_shapes", | ||
[[(4, 4)]], | ||
) | ||
def test_zeros_like(device, input_shapes): | ||
m = ZerosLikeModule() | ||
inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] | ||
result_before = m.forward(*inputs) | ||
option = torch_ttnn.TorchTtnnOption(device=device) | ||
option.gen_graphviz = True | ||
# The compilation is lazy, so we need to run forward once to trigger the compilation | ||
m = torch.compile(m, backend=torch_ttnn.backend, options=option) | ||
result_after = m.forward(*inputs) | ||
option._out_fx_graphs[0].print_tabular() | ||
|
||
# Check the graph has be rewritten and contain ttnn ops | ||
nodes = list(option._out_fx_graphs[0].nodes) | ||
assert [node.target for node in nodes].count(ttnn.zeros_like) == 1 | ||
# Check inference result | ||
assert torch.allclose(result_before, result_after) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import torch | ||
import torch_ttnn | ||
import pytest | ||
import ttnn | ||
|
||
|
||
class AddModule(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, x): | ||
return x + x | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"input_shapes", | ||
[ | ||
[(4, 4)], | ||
], | ||
) | ||
def test_add(device, input_shapes): | ||
m = AddModule() | ||
inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes] | ||
result_before = m.forward(*inputs) | ||
option = torch_ttnn.TorchTtnnOption(device=device) | ||
# The compilation is lazy, so we need to run forward once to trigger the compilation | ||
m = torch.compile(m, backend=torch_ttnn.backend, options=option) | ||
result_after = m.forward(*inputs) | ||
assert 1 == len(option._out_fx_graphs) | ||
option._out_fx_graphs[0].print_tabular() | ||
# Check the graph has be rewritten and contain ttnn ops | ||
nodes = list(option._out_fx_graphs[0].nodes) | ||
assert [node.target for node in nodes].count(ttnn.add) == 1 | ||
# Check inference result | ||
assert torch.allclose(result_before, result_after) |
Oops, something went wrong.