From cfd0a7045e5be716dfcd47e3ba702fe8accaeefa Mon Sep 17 00:00:00 2001 From: Kin Ian Lo Date: Fri, 27 Sep 2024 16:03:53 +0100 Subject: [PATCH] add test for spider --- tests/backend/test_tensor.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/backend/test_tensor.py b/tests/backend/test_tensor.py index cc189b73..22c9bd47 100644 --- a/tests/backend/test_tensor.py +++ b/tests/backend/test_tensor.py @@ -6,6 +6,9 @@ import lambeq.backend.grammar as grammar from lambeq.backend.tensor import * +@pytest.fixture +def spider(): + return Spider(Dim(3), 5, 3) def test_Ty(): assert Dim(1,1,1,1,1) == Dim(1) == Dim() @@ -85,3 +88,22 @@ def test_lambdify(): assert bx1.lambdify(a,b,c,d)(1,2,3,4) == bx1_concrete assert (bx1 >> bx2).lambdify(a,b,c,d)(1,2,3,4) == bx1_concrete >> bx2_concrete + +def test_to_tn_spider_unfuse(spider): + nodes, edges = spider.to_tn() + + assert len(edges) == spider.n_legs_in + spider.n_legs_out + assert all(node.rank <= 3 for node in nodes) + +def test_spider_eval(spider): + n_legs = 8 + dim = 3 + + expected = np.zeros(tuple(dim for _ in range(n_legs))) + for i in range(dim): + expected[tuple(i for _ in range(n_legs))] = 1 + + result = spider.eval() + + assert result.shape == tuple(dim for _ in range(n_legs)) + assert np.allclose(result, expected) \ No newline at end of file