Skip to content

Commit

Permalink
Change input shapes of some unit test to match exceptions in current …
Browse files Browse the repository at this point in the history
…state of lowering
  • Loading branch information
kevinwuTT committed Aug 12, 2024
1 parent 3047406 commit e7f1936
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
8 changes: 5 additions & 3 deletions tests/lowering/eltwise/binary/test_div.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@ def forward(self, numerator, denominator):

@pytest.mark.parametrize(
"input_shapes",
[[(4, 4), (4, 4)]],
# [[(4, 4), (4, 4)]],
[[(64, 128), (64, 128)]],
)
def test_div(device, input_shapes):
m = DivModule()
inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes]
inputs = [torch.randint(1, 15, shape).to(torch.bfloat16) for shape in input_shapes]
result_before = m.forward(*inputs)
option = torch_ttnn.TorchTtnnOption(device=device)
option.gen_graphviz = True
Expand All @@ -45,7 +46,8 @@ def test_div(device, input_shapes):

@pytest.mark.parametrize(
"input_shapes",
[[(4, 4)]],
# [[(4, 4)]],
[[(32, 32)]],
)
def test_div_scalar_denom(device, input_shapes):
m = DivModule()
Expand Down
3 changes: 2 additions & 1 deletion tests/lowering/eltwise/binary/test_sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def test_rsub(device, input_shapes):

@pytest.mark.parametrize(
"input_shapes",
[[(4, 4)]],
# [[(4, 4)]],
[[(32, 32)]],
)
def test_rsub_scalar(device, input_shapes):
m = RSubScalarModule()
Expand Down
6 changes: 4 additions & 2 deletions tests/lowering/tensor_manipulation/test_unsqueeze.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ def forward(self, x, y):

@pytest.mark.parametrize(
"input_shape, dim",
[((5, 2, 4, 3), 1)],
# [((5, 2, 4, 3), 1)],
[((32, 32, 32, 32), 1)],
)
def test_unsqueeze1(device, input_shape, dim):
mod = UnsqueezeModule()
Expand Down Expand Up @@ -64,7 +65,8 @@ def test_unsqueeze2(device, input_shape, dim):

@pytest.mark.parametrize(
"input_shape, dim",
[((5, 2, 4, 3), -2)],
# [((5, 2, 4, 3), -2)],
[((32, 32, 32, 32), -2)],
)
def test_unsqueeze3(device, input_shape, dim):
mod = UnsqueezeModule()
Expand Down

0 comments on commit e7f1936

Please sign in to comment.