full_cky_crf.py (forked from harvardnlp/pytorch-struct)

import torch
from .helpers import _Struct, Chart
from tqdm import tqdm

A, B = 0, 1


class Full_CKY_CRF(_Struct):
    def _check_potentials(self, edge, lengths=None):
        batch, N, N1, N2, NT, NT1, NT2 = self._get_dimension(edge)
        assert (
            N == N1 == N2 and NT == NT1 == NT2
        ), f"Want N:{N} == N1:{N1} == N2:{N2} and NT:{NT} == NT1:{NT1} == NT2:{NT2}"
        edge = self.semiring.convert(edge)
        semiring_shape = edge.shape[:-7]
        if lengths is None:
            lengths = torch.LongTensor([N] * batch).to(edge.device)

        return edge, semiring_shape, batch, N, NT, lengths
    def _dp(self, scores, lengths=None, force_grad=False, cache=True):
        sr = self.semiring

        # Scores.shape = *sshape, B, N, N, N, NT, NT, NT
        # w/ semantics [ *semiring stuff, b, i, j, k, A, B, C]
        # where b is batch index, i is left endpoint, j is right endpoint, k is splitpoint, with rule A -> B C
        scores, sshape, batch, N, NT, lengths = self._check_potentials(scores, lengths)
        sshape, sdims = list(sshape), list(range(len(sshape)))  # usually [0]
        S, b = len(sdims), batch

        # Initialize data structs
        LEFT, RIGHT = 0, 1
        L_DIM, R_DIM = S + 1, S + 2  # one and two to the right of the batch dim

        # Initialize the base cases with scores from diagonal i=j=k, A=B=C
        term_scores = (
            scores.diagonal(0, L_DIM, R_DIM)  # diag i,j now at dim -1
            .diagonal(0, L_DIM, -1)  # diag of k with that gives i=j=k, now at dim -1
            .diagonal(0, -4, -3)  # diag of A, B, now at dim -1, ijk moves to -2
            .diagonal(0, -3, -1)  # diag of C with that gives A=B=C
        )
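        # Equivalently: term_scores[..., i, A] == scores[..., i, i, i, A, A, A];
        # the base case keeps only single-position spans (i = j = k) whose rule
        # has all three symbols equal (A = B = C).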
        assert term_scores.shape[S + 1 :] == (N, NT), f"{term_scores.shape[S + 1 :]} == {(N, NT)}"
        alpha_left = term_scores
        alpha_right = term_scores
        alphas = [[alpha_left], [alpha_right]]

        # Run the vectorized inside algorithm
        for w in range(1, N):
            # Scores
            # What we want is a tensor with:
            #  shape: *sshape, batch, (N-w), NT, w, NT, NT
            #  w/ semantics: [...batch, (i,j=i+w), A, k, B, C]
            #  where (i,j=i+w) means the diagonal of tree nodes with width w
            # Shape: *sshape, batch, N, NT, NT, NT, (N-w) w/ semantics [ ...batch, k, A, B, C, (i,j=i+w)]
            score = scores.diagonal(w, L_DIM, R_DIM)  # get diagonal scores

            score = score.permute(sdims + [-6, -1, -4, -5, -3, -2])  # move diag (-1) dim and head NT (-4) dim to front
            score = score[..., :w, :, :]  # remove illegal splitpoints
            assert score.shape[S:] == (batch, N - w, NT, w, NT, NT), f"{score.shape[S:]} == {(b, N - w, NT, w, NT, NT)}"
            # Sums of left subtrees
            # Shape: *sshape, batch, (N-w), w, NT
            # where L[..., i, d, B] is the sum of subtrees up to (i,j=(i+d),B)
            left = slice(None, N - w)  # left indices
            L = torch.stack(alphas[LEFT][:w], dim=-2)[..., left, :, :]

            # Sums of right subtrees
            # Shape: *sshape, batch, (N-w), w, NT
            # where R[..., h, d, C] is the sum of subtrees up to (i=(N-h-d),j=(N-h),C)
            right = slice(w, None)  # right indices
            R = torch.stack(list(reversed(alphas[RIGHT][:w])), dim=-2)[..., right, :, :]

            # Broadcast them both to match missing dims in score
            # Left B is duplicated for all head and right symbols A C
            L_bcast = L.reshape(list(sshape) + [b, N - w, 1, w, NT, 1]).repeat(S * [1] + [1, 1, NT, 1, 1, NT])

            # Right C is duplicated for all head and left symbols A B
            R_bcast = R.reshape(list(sshape) + [b, N - w, 1, w, 1, NT]).repeat(S * [1] + [1, 1, NT, 1, NT, 1])
            assert score.shape == L_bcast.shape == R_bcast.shape == tuple(list(sshape) + [b, N - w, NT, w, NT, NT])

            # Now multiply all the scores and sum over k, B, C dimensions (the last three dims)
            assert sr.times(score, L_bcast, R_bcast).shape == tuple(list(sshape) + [b, N - w, NT, w, NT, NT])
            sum_prod_w = sr.sum(sr.sum(sr.sum(sr.times(score, L_bcast, R_bcast))))
            assert sum_prod_w.shape[S:] == (b, N - w, NT), f"{sum_prod_w.shape[S:]} == {(b, N - w, NT)}"
            # Pad the new width-w entries back up to length N along the span dim:
            # the LEFT chart is indexed by the left endpoint i (spans (i, i+w), so the
            # last w positions are padding), the RIGHT chart by the right endpoint j
            # (spans (j-w, j), so the first w positions are padding).
            pad = sr.zero_(torch.ones(sshape + [b, w, NT]).to(sum_prod_w.device))
            sum_prod_w_left = torch.cat([sum_prod_w, pad], dim=-2)
            sum_prod_w_right = torch.cat([pad, sum_prod_w], dim=-2)
            alphas[LEFT].append(sum_prod_w_left)
            alphas[RIGHT].append(sum_prod_w_right)

        final = sr.sum(torch.stack(alphas[LEFT], dim=-2))[..., 0, :]  # sum out root symbol, keep spans starting at i=0
        log_Z = final[:, torch.arange(batch), lengths - 1]  # full-span (0, length-1) score for each batch element
        return log_Z, [scores], alphas
    @staticmethod
    def _rand():
        batch = torch.randint(2, 5, (1,))
        N = torch.randint(2, 5, (1,))
        NT = torch.randint(2, 5, (1,))
        scores = torch.rand(batch, N, N, N, NT, NT, NT)
        return scores, (batch.item(), N.item())

    def enumerate(self, scores, lengths=None):
        raise NotImplementedError
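

# A minimal usage sketch (illustrative only; assumes this module sits inside the
# torch_struct package and that _Struct defaults to a log-space semiring, as in
# harvardnlp/pytorch-struct): draw random potentials with _rand() and run the
# inside algorithm to obtain the log-partition over all full binary trees.
if __name__ == "__main__":
    torch.manual_seed(0)
    scores, (batch, N) = Full_CKY_CRF._rand()  # potentials of shape (batch, N, N, N, NT, NT, NT)
    struct = Full_CKY_CRF()  # default (log-space) semiring from _Struct.__init__
    log_Z, _, _ = struct._dp(scores)  # one semiring-converted log-partition value per batch element
    print(log_Z)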