Skip to content

Commit 29cc28f

Browse files
authored
Check cases in docs/operations/aten.lt.Scalar.md (#624)
1 parent 8edefa3 commit 29cc28f

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

tests/lowering/eltwise/binary/test_lt.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,23 @@ def test_lt_tensor(device, input_shapes):
4141
assert torch.allclose(result_before, result_after.to(torch.bool))
4242

4343

44-
@pytest.mark.parametrize("input_shape", ((64, 128), (1, 1, 256), (2, 377, 355)))
44+
@pytest.mark.parametrize(
45+
"input_shape",
46+
(
47+
(64, 128),
48+
(1, 1, 256),
49+
(2, 377, 355),
50+
(1, 1),
51+
(10, 10),
52+
(15, 15),
53+
(17, 17),
54+
(2, 2),
55+
),
56+
)
4557
def test_lt_scalar(device, input_shape):
4658
m = LtModule()
4759
input = torch.randint(0, 10, input_shape, dtype=torch.bfloat16)
48-
scalar = torch.randint(0, 10, ()).item()
60+
scalar = 5
4961
result_before = m.forward(input, scalar)
5062
option = torch_ttnn.TorchTtnnOption(device=device)
5163
option.gen_graphviz = True
@@ -55,8 +67,9 @@ def test_lt_scalar(device, input_shape):
5567
option._out_fx_graphs[0].print_tabular()
5668

5769
# Check the graph has be rewritten and contain ttnn ops
58-
nodes = list(option._out_fx_graphs[0].nodes)
59-
assert [node.target for node in nodes].count(ttnn.lt) == 1
70+
nodes = [node.target for node in option._out_fx_graphs[0].nodes]
71+
assert torch.ops.aten.lt.Scalar not in nodes
72+
assert nodes.count(ttnn.lt) == 1
6073

6174
# Check inference result
6275
assert torch.allclose(result_before, result_after.to(torch.bool))

0 commit comments

Comments
 (0)