Skip to content

Commit e7f1936

Browse files
committed
Change input shapes of some unit test to match exceptions in current state of lowering
1 parent 3047406 commit e7f1936

File tree

3 files changed

+11
-6
lines changed

3 files changed

+11
-6
lines changed

tests/lowering/eltwise/binary/test_div.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@ def forward(self, numerator, denominator):
1515

1616
@pytest.mark.parametrize(
1717
"input_shapes",
18-
[[(4, 4), (4, 4)]],
18+
# [[(4, 4), (4, 4)]],
19+
[[(64, 128), (64, 128)]],
1920
)
2021
def test_div(device, input_shapes):
2122
m = DivModule()
22-
inputs = [torch.rand(shape, dtype=torch.bfloat16) for shape in input_shapes]
23+
inputs = [torch.randint(1, 15, shape).to(torch.bfloat16) for shape in input_shapes]
2324
result_before = m.forward(*inputs)
2425
option = torch_ttnn.TorchTtnnOption(device=device)
2526
option.gen_graphviz = True
@@ -45,7 +46,8 @@ def test_div(device, input_shapes):
4546

4647
@pytest.mark.parametrize(
4748
"input_shapes",
48-
[[(4, 4)]],
49+
# [[(4, 4)]],
50+
[[(32, 32)]],
4951
)
5052
def test_div_scalar_denom(device, input_shapes):
5153
m = DivModule()

tests/lowering/eltwise/binary/test_sub.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ def test_rsub(device, input_shapes):
8080

8181
@pytest.mark.parametrize(
8282
"input_shapes",
83-
[[(4, 4)]],
83+
# [[(4, 4)]],
84+
[[(32, 32)]],
8485
)
8586
def test_rsub_scalar(device, input_shapes):
8687
m = RSubScalarModule()

tests/lowering/tensor_manipulation/test_unsqueeze.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ def forward(self, x, y):
1717

1818
@pytest.mark.parametrize(
1919
"input_shape, dim",
20-
[((5, 2, 4, 3), 1)],
20+
# [((5, 2, 4, 3), 1)],
21+
[((32, 32, 32, 32), 1)],
2122
)
2223
def test_unsqueeze1(device, input_shape, dim):
2324
mod = UnsqueezeModule()
@@ -64,7 +65,8 @@ def test_unsqueeze2(device, input_shape, dim):
6465

6566
@pytest.mark.parametrize(
6667
"input_shape, dim",
67-
[((5, 2, 4, 3), -2)],
68+
# [((5, 2, 4, 3), -2)],
69+
[((32, 32, 32, 32), -2)],
6870
)
6971
def test_unsqueeze3(device, input_shape, dim):
7072
mod = UnsqueezeModule()

0 commit comments

Comments
 (0)