
Commit

Separate tests into individual files and switch to Pytest
kevinwuTT committed Jul 22, 2024
1 parent ffd89d1 commit 344f844
Showing 53 changed files with 2,544 additions and 2,808 deletions.
17 changes: 17 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,17 @@
import pytest
import ttnn
import torch


@pytest.fixture(scope="session")
def device():
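    # Open a single TT device for the whole test session and close it when the session ends.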
    device = ttnn.open_device(device_id=0)
    yield device
    ttnn.close_device(device)


@pytest.fixture(autouse=True)
def reset_torch_dynamo():
    # PyTorch caches models. Start a fresh compile for each parameter of the test case.
    torch._dynamo.reset()
    yield
95 changes: 95 additions & 0 deletions tests/lowering/creation/test_arange.py
@@ -0,0 +1,95 @@
import torch
import torch_ttnn
import pytest
import ttnn


class ArangeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, end):
        # start = 0, step = 1
        return torch.arange(end)


class ArangeStartModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, start, end):
        # step = 1
        return torch.arange(start, end)


class ArangeStartStepModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, start, end, step):
        return torch.arange(start, end, step)


# NOTE(kevinwuTT) This test fails because ttnn.arange does not support a start value of 0.
@pytest.mark.xfail
@pytest.mark.parametrize(
    "input_shapes",
    [[100]],
)
def test_arange(device, input_shapes):
    m = ArangeModule()
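    # Cast the eager reference to bfloat16 so it matches the precision of the TTNN result.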
    result_before = m.forward(*input_shapes).to(torch.bfloat16)
    option = torch_ttnn.TorchTtnnOption(device=device)
    option.gen_graphviz = True
    # The compilation is lazy, so we need to run forward once to trigger the compilation
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
    result_after = m.forward(*input_shapes)
    option._out_fx_graphs[0].print_tabular()

    # Check the graph has been rewritten and contains ttnn ops
    nodes = list(option._out_fx_graphs[0].nodes)
    assert [node.target for node in nodes].count(ttnn.arange) == 1
    # Check inference result
    assert torch.allclose(result_before, result_after)


@pytest.mark.parametrize(
    "input_shapes",
    [[2, 100]],
)
def test_arange_start(device, input_shapes):
    m = ArangeStartModule()
    result_before = m.forward(*input_shapes).to(torch.bfloat16)
    option = torch_ttnn.TorchTtnnOption(device=device)
    option.gen_graphviz = True
    # The compilation is lazy, so we need to run forward once to trigger the compilation
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
    result_after = m.forward(*input_shapes)
    option._out_fx_graphs[0].print_tabular()

    # Check the graph has been rewritten and contains ttnn ops
    nodes = list(option._out_fx_graphs[0].nodes)
    assert [node.target for node in nodes].count(ttnn.arange) == 1
    # Check inference result
    assert torch.allclose(result_before, result_after)


@pytest.mark.parametrize(
    "input_shapes",
    [[4, 100, 3]],
)
def test_arange_start_step(device, input_shapes):
    m = ArangeStartStepModule()
    result_before = m.forward(*input_shapes).to(torch.bfloat16)
    option = torch_ttnn.TorchTtnnOption(device=device)
    option.gen_graphviz = True
    # The compilation is lazy, so we need to run forward once to trigger the compilation
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
    result_after = m.forward(*input_shapes)
    option._out_fx_graphs[0].print_tabular()

    # Check the graph has been rewritten and contains ttnn ops
    nodes = list(option._out_fx_graphs[0].nodes)
    assert [node.target for node in nodes].count(ttnn.arange) == 1
    # Check inference result
    assert torch.allclose(result_before, result_after)
72 changes: 72 additions & 0 deletions tests/lowering/creation/test_clone.py
@@ -0,0 +1,72 @@
import torch
import torch_ttnn
import pytest
import ttnn

from tests.utils import check_with_pcc


class CloneFromNodeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        a = input + input
        return torch.clone(a)


class CloneFromArgModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return torch.clone(input)


@pytest.mark.parametrize(
    "input_shapes",
    [[(4, 4)]],
)
def test_clone_from_arg(device, input_shapes):
    m = CloneFromArgModule()
    inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes]
    result_before = m.forward(*inputs)
    option = torch_ttnn.TorchTtnnOption(device=device)
    option.gen_graphviz = False
    # The compilation is lazy, so we need to run forward once to trigger the compilation
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
    result_after = m.forward(*inputs)
    option._out_fx_graphs[0].print_tabular()

    # Check the graph has been rewritten and contains ttnn ops
    nodes = list(option._out_fx_graphs[0].nodes)
    assert [node.target for node in nodes].count(ttnn.clone) == 1
    # Check inference result
    assert torch.allclose(result_before, result_after)


@pytest.mark.parametrize(
    "input_shapes",
    [[(4, 4)]],
)
def test_clone_from_node(device, input_shapes):
    m = CloneFromNodeModule()
    inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes]
    result_before = m.forward(*inputs)
    option = torch_ttnn.TorchTtnnOption(device=device)
    option.gen_graphviz = False
    # The compilation is lazy, so we need to run forward once to trigger the compilation
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
    result_after = m.forward(*inputs)
    option._out_fx_graphs[0].print_tabular()

    # Check the graph has been rewritten and contains ttnn ops
    nodes = list(option._out_fx_graphs[0].nodes)
    target = [node.target for node in nodes]
    assert target.count(ttnn.clone) == 1
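    # The input to ttnn.clone should itself be a lowered ttnn op (the add), not a plain torch node.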
    clone_arg_0 = nodes[target.index(ttnn.clone)].args[0].target
    assert isinstance(clone_arg_0, ttnn.decorators.FastOperation) or isinstance(
        clone_arg_0, ttnn.decorators.Operation
    )
    # Check inference result
    assert torch.allclose(result_before, result_after)
34 changes: 34 additions & 0 deletions tests/lowering/creation/test_full.py
@@ -0,0 +1,34 @@
import torch
import torch_ttnn
import pytest
import ttnn


class FullModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, size, fill_value):
        return torch.full(size, fill_value)


@pytest.mark.parametrize(
    "input_shapes",
    [[(64, 128)]],
)
def test_full(device, input_shapes):
    m = FullModule()
    fill_value = 1.23
    result_before = m.forward(input_shapes[0], fill_value).to(torch.bfloat16)
    option = torch_ttnn.TorchTtnnOption(device=device)
    option.gen_graphviz = True
    # The compilation is lazy, so we need to run forward once to trigger the compilation
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
    result_after = m.forward(input_shapes[0], fill_value)
    option._out_fx_graphs[0].print_tabular()

    # Check the graph has been rewritten and contains ttnn ops
    nodes = list(option._out_fx_graphs[0].nodes)
    assert [node.target for node in nodes].count(ttnn.full) == 1
    # Check inference result
    assert torch.allclose(result_before, result_after)
34 changes: 34 additions & 0 deletions tests/lowering/creation/test_ones.py
@@ -0,0 +1,34 @@
import torch
import torch_ttnn
import pytest
import ttnn


class OnesModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, shape):
        return torch.ones(shape)


@pytest.mark.parametrize(
    "input_shapes",
    [(32, 32)],
)
def test_ones(device, input_shapes):
    m = OnesModule()
    result_before = m.forward(input_shapes)
    result_before = result_before.to(torch.bfloat16)
    option = torch_ttnn.TorchTtnnOption(device=device)
    option.gen_graphviz = True
    # The compilation is lazy, so we need to run forward once to trigger the compilation
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
    result_after = m.forward(input_shapes)
    option._out_fx_graphs[0].print_tabular()

    # Check the graph has been rewritten and contains ttnn ops
    nodes = list(option._out_fx_graphs[0].nodes)
    assert [node.target for node in nodes].count(ttnn.ones) == 1
    # Check inference result
    assert torch.allclose(result_before, result_after)
70 changes: 70 additions & 0 deletions tests/lowering/creation/test_to_copy.py
@@ -0,0 +1,70 @@
import torch
import torch_ttnn
import pytest
import ttnn


class ToCopyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.to(torch.bfloat16)


class ToCopyWithOpAfterModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        to = x.to(torch.bfloat16)
        return torch.add(to, to)


@pytest.mark.parametrize(
    "input_shapes",
    [[(4, 4)]],
)
def test_to_copy(device, input_shapes):
    m = ToCopyModule()
    inputs = [torch.rand(shape) for shape in input_shapes]
    result_before = m.forward(*inputs)
    option = torch_ttnn.TorchTtnnOption(device=device)
    option.gen_graphviz = True
    # The compilation is lazy, so we need to run forward once to trigger the compilation
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
    result_after = m.forward(*inputs)
    option._out_fx_graphs[0].print_tabular()

    # Check the graph has been rewritten and contains ttnn ops
    nodes = list(option._out_fx_graphs[0].nodes)
    assert [node.target for node in nodes].count(ttnn.as_tensor) == 1
    # Check inference result
    assert torch.allclose(result_before, result_after, rtol=0.2)


@pytest.mark.parametrize(
    "input_shapes",
    [[(4, 4)]],
)
def test_to_copy_with_op_after(device, input_shapes):
    m = ToCopyWithOpAfterModule()
    inputs = [torch.rand(shape) for shape in input_shapes]
    result_before = m.forward(*inputs)
    option = torch_ttnn.TorchTtnnOption(device=device)
    option.gen_graphviz = True
    # The compilation is lazy, so we need to run forward once to trigger the compilation
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
    result_after = m.forward(*inputs)
    option._out_fx_graphs[0].print_tabular()

    # Check the graph has been rewritten and contains ttnn ops
    nodes = list(option._out_fx_graphs[0].nodes)
    target = [node.target for node in nodes]
    assert target.count(ttnn.as_tensor) == 1
    assert target.count(ttnn.add) == 1
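    # Both operands of the lowered add should be the ttnn.as_tensor conversion node.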
    add_node = nodes[target.index(ttnn.add)]
    assert add_node.args[0].target == ttnn.as_tensor
    assert add_node.args[1].target == ttnn.as_tensor
    # Check inference result
    assert torch.allclose(result_before, result_after, rtol=0.2)
34 changes: 34 additions & 0 deletions tests/lowering/creation/test_zeros_like.py
@@ -0,0 +1,34 @@
import torch
import torch_ttnn
import pytest
import ttnn


class ZerosLikeModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return torch.zeros_like(input)


@pytest.mark.parametrize(
    "input_shapes",
    [[(4, 4)]],
)
def test_zeros_like(device, input_shapes):
    m = ZerosLikeModule()
    inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes]
    result_before = m.forward(*inputs)
    option = torch_ttnn.TorchTtnnOption(device=device)
    option.gen_graphviz = True
    # The compilation is lazy, so we need to run forward once to trigger the compilation
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
    result_after = m.forward(*inputs)
    option._out_fx_graphs[0].print_tabular()

    # Check the graph has been rewritten and contains ttnn ops
    nodes = list(option._out_fx_graphs[0].nodes)
    assert [node.target for node in nodes].count(ttnn.zeros_like) == 1
    # Check inference result
    assert torch.allclose(result_before, result_after)
35 changes: 35 additions & 0 deletions tests/lowering/eltwise/binary/test_add.py
@@ -0,0 +1,35 @@
import torch
import torch_ttnn
import pytest
import ttnn


class AddModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x + x


@pytest.mark.parametrize(
    "input_shapes",
    [
        [(4, 4)],
    ],
)
def test_add(device, input_shapes):
    m = AddModule()
    inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes]
    result_before = m.forward(*inputs)
    option = torch_ttnn.TorchTtnnOption(device=device)
    # The compilation is lazy, so we need to run forward once to trigger the compilation
    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
    result_after = m.forward(*inputs)
    assert 1 == len(option._out_fx_graphs)
    option._out_fx_graphs[0].print_tabular()
    # Check the graph has been rewritten and contains ttnn ops
    nodes = list(option._out_fx_graphs[0].nodes)
    assert [node.target for node in nodes].count(ttnn.add) == 1
    # Check inference result
    assert torch.allclose(result_before, result_after)