diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py
index 7bc525de..b80cd91a 100644
--- a/torch_struct/distributions.py
+++ b/torch_struct/distributions.py
@@ -7,6 +7,7 @@
 from .alignment import Alignment
 from .deptree import DepTree, deptree_nonproj, deptree_part
 from .cky_crf import CKY_CRF
+from .full_cky_crf import Full_CKY_CRF
 from .semirings import (
     LogSemiring,
     MaxSemiring,
@@ -19,7 +20,6 @@
 )
 
 
-
 class StructDistribution(Distribution):
     r"""
     Base structured distribution class.
@@ -69,7 +69,6 @@ def log_prob(self, value):
             batch_dims=batch_dims,
         )
 
-
         return v - self.partition
 
     @lazy_property
@@ -91,7 +90,9 @@ def cross_entropy(self, other):
             cross entropy (*batch_shape*)
         """
 
-        return self._struct(CrossEntropySemiring).sum([self.log_potentials, other.log_potentials], self.lengths)
+        return self._struct(CrossEntropySemiring).sum(
+            [self.log_potentials, other.log_potentials], self.lengths
+        )
 
     def kl(self, other):
         """
@@ -100,7 +101,9 @@ def kl(self, other):
         Returns:
            cross entropy (*batch_shape*)
         """
-        return self._struct(KLDivergenceSemiring).sum([self.log_potentials, other.log_potentials], self.lengths)
+        return self._struct(KLDivergenceSemiring).sum(
+            [self.log_potentials, other.log_potentials], self.lengths
+        )
 
     @lazy_property
     def max(self):
@@ -166,10 +169,8 @@ def marginals(self):
     def count(self):
         "Compute the log-partition function."
         ones = torch.ones_like(self.log_potentials)
-        ones[self.log_potentials.eq(-float('inf'))] = 0
-        return self._struct(StdSemiring).sum(
-            ones, self.lengths
-        )
+        ones[self.log_potentials.eq(-float("inf"))] = 0
+        return self._struct(StdSemiring).sum(ones, self.lengths)
 
     # @constraints.dependent_property
     # def support(self):
@@ -379,7 +380,6 @@ def __init__(self, log_potentials, lengths=None, args={}, multiroot=True):
         setattr(self.struct, "multiroot", multiroot)
 
 
-
 class TreeCRF(StructDistribution):
     r"""
     Represents a 0th-order span parser with NT nonterminals. Implemented using a
@@ -406,6 +406,22 @@ class TreeCRF(StructDistribution):
     struct = CKY_CRF
 
 
+class FullTreeCRF(StructDistribution):
+    r"""
+    Represents a 1st-order span parser with NT nonterminals. Implemented using a vectorized inside algorithm.
+    For a mathematical description see:
+    * Inside-Outside Algorithm, by Michael Collins: http://www.cs.columbia.edu/~mcollins/io.pdf
+
+    Parameters:
+        log_potentials (tensor) : event_shape (*N x N x N x NT x NT x NT*), e.g.
+                                  :math:`\phi(i, j, k, A_i^j \rightarrow B_i^k C_{k+1}^j)`
+        lengths (long tensor) : batch shape integers for length masking.
+
+    Compact representation: *N x N x N x NT x NT x NT* long tensor (Same)
+    """
+    struct = Full_CKY_CRF
+
+
 class SentCFG(StructDistribution):
     """
     Represents a full generative context-free grammar with
diff --git a/torch_struct/full_cky_crf.py b/torch_struct/full_cky_crf.py
new file mode 100644
index 00000000..16347c60
--- /dev/null
+++ b/torch_struct/full_cky_crf.py
@@ -0,0 +1,104 @@
+import torch
+from .helpers import _Struct, Chart
+from tqdm import tqdm
+
+A, B = 0, 1
+
+
+class Full_CKY_CRF(_Struct):
+    def _check_potentials(self, edge, lengths=None):
+        batch, N, N1, N2, NT, NT1, NT2 = self._get_dimension(edge)
+        assert (
+            N == N1 == N2 and NT == NT1 == NT2
+        ), f"Want N:{N} == N1:{N1} == N2:{N2} and NT:{NT} == NT1:{NT1} == NT2:{NT2}"
+        edge = self.semiring.convert(edge)
+        semiring_shape = edge.shape[:-7]
+        if lengths is None:
+            lengths = torch.LongTensor([N] * batch).to(edge.device)
+
+        return edge, semiring_shape, batch, N, NT, lengths
+
+    def _dp(self, scores, lengths=None, force_grad=False, cache=True):
+        sr = self.semiring
+
+        # Scores.shape = *sshape, B, N, N, N, NT, NT, NT
+        # w/ semantics [ *semiring stuff, b, i, j, k, A, B, C]
+        # where b is batch index, i is left endpoint, j is right endpoint, k is splitpoint, with rule A -> B C
+        scores, sshape, batch, N, NT, lengths = self._check_potentials(scores, lengths)
+        sshape, sdims = list(sshape), list(range(len(sshape)))  # usually [0]
+        S, b = len(sdims), batch
+
+        # Initialize data structs
+        LEFT, RIGHT = 0, 1
+        L_DIM, R_DIM = S + 1, S + 2  # one and two to the right of the batch dim
+
+        # Initialize the base cases with scores from diagonal i=j=k, A=B=C
+        term_scores = (
+            scores.diagonal(0, L_DIM, R_DIM)  # diag i,j now at dim -1
+            .diagonal(0, L_DIM, -1)  # diag of k with that gives i=j=k, now at dim -1
+            .diagonal(0, -4, -3)  # diag of A, B, now at dim -1, ijk moves to -2
+            .diagonal(0, -3, -1)  # diag of C with that gives A=B=C
+        )
+        assert term_scores.shape[S + 1 :] == (N, NT), f"{term_scores.shape[S + 1 :]} == {(N, NT)}"
+        alpha_left = term_scores
+        alpha_right = term_scores
+        alphas = [[alpha_left], [alpha_right]]
+
+        # Run vectorized inside alg
+        for w in range(1, N):
+            # Scores
+            # What we want is a tensor with:
+            #  shape: *sshape, batch, (N-w), NT, w, NT, NT
+            #  w/ semantics: [...batch, (i,j=i+w), A, k, B, C]
+            #  where (i,j=i+w) means the diagonal of tree nodes with width w
+            # Shape: *sshape, batch, N, NT, NT, NT, (N-w) w/ semantics [ ...batch, k, A, B, C, (i,j=i+w)]
+            score = scores.diagonal(w, L_DIM, R_DIM)  # get diagonal scores
+
+            score = score.permute(sdims + [-6, -1, -4, -5, -3, -2])  # move diag (-1) dim and head NT (-4) dim to front
+            score = score[..., :w, :, :]  # remove illegal splitpoints
+            assert score.shape[S:] == (batch, N - w, NT, w, NT, NT), f"{score.shape[S:]} == {(b, N-w, NT, w, NT, NT)}"
+
+            # Sums of left subtrees
+            # Shape: *sshape, batch, (N-w), w, NT
+            # where L[..., i, d, B] is the sum of subtrees up to (i,j=(i+d),B)
+            left = slice(None, N - w)  # left indices
+            L = torch.stack(alphas[LEFT][:w], dim=-2)[..., left, :, :]
+
+            # Sums of right subtrees
+            # Shape: *sshape, batch, (N-w), w, NT
+            # where R[..., h, d, C] is the sum of subtrees up to (i=(N-h-d),j=(N-h),C)
+            right = slice(w, None)  # right indices
+            R = torch.stack(list(reversed(alphas[RIGHT][:w])), dim=-2)[..., right, :, :]
+
+            # Broadcast them both to match missing dims in score
+            # Left B is duplicated for all head and right symbols A C
+            L_bcast = L.reshape(list(sshape) + [b, N - w, 1, w, NT, 1]).repeat(S * [1] + [1, 1, NT, 1, 1, NT])
+            # Right C is duplicated for all head and left symbols A B
+            R_bcast = R.reshape(list(sshape) + [b, N - w, 1, w, 1, NT]).repeat(S * [1] + [1, 1, NT, 1, NT, 1])
+            assert score.shape == L_bcast.shape == R_bcast.shape == tuple(list(sshape) + [b, N - w, NT, w, NT, NT])
+
+            # Now multiply all the scores and sum over k, B, C dimensions (the last three dims)
+            assert sr.times(score, L_bcast, R_bcast).shape == tuple(list(sshape) + [b, N - w, NT, w, NT, NT])
+            sum_prod_w = sr.sum(sr.sum(sr.sum(sr.times(score, L_bcast, R_bcast))))
+            assert sum_prod_w.shape[S:] == (b, N - w, NT), f"{sum_prod_w.shape[S:]} == {(b,N-w, NT)}"
+
+            pad = sr.zero_(torch.ones(sshape + [b, w, NT]).to(sum_prod_w.device))
+            sum_prod_w_left = torch.cat([sum_prod_w, pad], dim=-2)
+            sum_prod_w_right = torch.cat([pad, sum_prod_w], dim=-2)
+            alphas[LEFT].append(sum_prod_w_left)
+            alphas[RIGHT].append(sum_prod_w_right)
+
+        final = sr.sum(torch.stack(alphas[LEFT], dim=-2))[..., 0, :]  # sum out root symbol
+        log_Z = final[:, torch.arange(batch), lengths - 1]
+        return log_Z, [scores], alphas
+
+    @staticmethod
+    def _rand():
+        batch = torch.randint(2, 5, (1,))
+        N = torch.randint(2, 5, (1,))
+        NT = torch.randint(2, 5, (1,))
+        scores = torch.rand(batch, N, N, N, NT, NT, NT)
+        return scores, (batch.item(), N.item())
+
+    def enumerate(self, scores, lengths=None):
+        raise NotImplementedError
\ No newline at end of file
diff --git a/torch_struct/helpers.py b/torch_struct/helpers.py
index bdad6678..e41f51e2 100644
--- a/torch_struct/helpers.py
+++ b/torch_struct/helpers.py
@@ -161,30 +161,31 @@ def marginals(self, edge, lengths=None, _autograd=True, _raw=False):
             or self.semiring is not LogSemiring
             or not hasattr(self, "_dp_backward")
         ):
-            v, edges, _ = self._dp(
-                edge, lengths=lengths, force_grad=True, cache=not _raw
-            )
-            if _raw:
-                all_m = []
-                for k in range(v.shape[0]):
-                    obj = v[k].sum(dim=0)
-
+            with torch.enable_grad():  # allows marginals even when input tensors don't need grad
+                v, edges, _ = self._dp(
+                    edge, lengths=lengths, force_grad=True, cache=not _raw
+                )
+                if _raw:
+                    all_m = []
+                    for k in range(v.shape[0]):
+                        obj = v[k].sum(dim=0)
+
+                        marg = torch.autograd.grad(
+                            obj,
+                            edges,
+                            create_graph=True,
+                            only_inputs=True,
+                            allow_unused=False,
+                        )
+                        all_m.append(self.semiring.unconvert(self._arrange_marginals(marg)))
+                    return torch.stack(all_m, dim=0)
+                else:
+                    obj = self.semiring.unconvert(v).sum(dim=0)
                     marg = torch.autograd.grad(
-                        obj,
-                        edges,
-                        create_graph=True,
-                        only_inputs=True,
-                        allow_unused=False,
+                        obj, edges, create_graph=True, only_inputs=True, allow_unused=False
                     )
-                    all_m.append(self.semiring.unconvert(self._arrange_marginals(marg)))
-                return torch.stack(all_m, dim=0)
-            else:
-                obj = self.semiring.unconvert(v).sum(dim=0)
-                marg = torch.autograd.grad(
-                    obj, edges, create_graph=True, only_inputs=True, allow_unused=False
-                )
-                a_m = self._arrange_marginals(marg)
-                return self.semiring.unconvert(a_m)
+                    a_m = self._arrange_marginals(marg)
+                    return self.semiring.unconvert(a_m)
         else:
            v, _, alpha = self._dp(edge, lengths=lengths, force_grad=True)
            return self._dp_backward(edge, lengths, alpha)
@@ -198,4 +199,4 @@ def from_parts(spans):
         return spans, None
 
     def _arrange_marginals(self, marg):
-        return marg[0]
+        return marg[0]
\ No newline at end of file
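
Reviewer note, not part of the patch: a minimal usage sketch of the new distribution. It assumes only what the diff adds, with FullTreeCRF imported from torch_struct.distributions (whether it is re-exported at the package top level is not shown here); the sizes, the requires_grad flag, and the shape comments are illustrative, and .partition / .marginals come from the StructDistribution base class.

import torch
from torch_struct.distributions import FullTreeCRF

batch, N, NT = 2, 5, 3

# One log-potential phi(i, j, k, A -> B C) per span (i, j), split point k,
# and nonterminal triple, matching the event shape N x N x N x NT x NT x NT.
log_potentials = torch.randn(batch, N, N, N, NT, NT, NT, requires_grad=True)
lengths = torch.tensor([5, 4])  # per-example lengths for masking

dist = FullTreeCRF(log_potentials, lengths=lengths)

print(dist.partition.shape)  # expected (batch,): log-partition from the vectorized inside pass
print(dist.marginals.shape)  # expected (batch, N, N, N, NT, NT, NT): span/rule marginals

The marginals are computed by differentiating the inside pass with respect to the potentials; the torch.enable_grad() wrapper added to helpers.py makes this work even when the input tensors do not require grad, per its comment.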