Skip to content

Commit

Permalink
add test for spider
Browse files Browse the repository at this point in the history
  • Loading branch information
kinianlo committed Sep 27, 2024
1 parent 7d0b1d6 commit cfd0a70
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions tests/backend/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
import lambeq.backend.grammar as grammar
from lambeq.backend.tensor import *

@pytest.fixture
def spider():
return Spider(Dim(3), 5, 3)

def test_Ty():
assert Dim(1,1,1,1,1) == Dim(1) == Dim()
Expand Down Expand Up @@ -85,3 +88,22 @@ def test_lambdify():

assert bx1.lambdify(a,b,c,d)(1,2,3,4) == bx1_concrete
assert (bx1 >> bx2).lambdify(a,b,c,d)(1,2,3,4) == bx1_concrete >> bx2_concrete

def test_to_tn_spider_unfuse(spider):
nodes, edges = spider.to_tn()

assert len(edges) == spider.n_legs_in + spider.n_legs_out
assert all(node.rank <= 3 for node in nodes)

def test_spider_eval(spider):
n_legs = 8
dim = 3

expected = np.zeros(tuple(dim for _ in range(n_legs)))
for i in range(dim):
expected[tuple(i for _ in range(n_legs))] = 1

result = spider.eval()

assert result.shape == tuple(dim for _ in range(n_legs))
assert np.allclose(result, expected)

0 comments on commit cfd0a70

Please sign in to comment.