From ec6a3cbf690c76d44bb46b717d02f494a0826163 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 5 Dec 2020 20:14:11 +0000 Subject: [PATCH 1/3] add 1st order cky implementation and marginals even when potentials don't require grad --- torch_struct/distributions.py | 50 ++++++++-------- torch_struct/full_cky_crf.py | 104 ++++++++++++++++++++++++++++++++++ torch_struct/helpers.py | 73 +++++++++--------------- 3 files changed, 155 insertions(+), 72 deletions(-) create mode 100644 torch_struct/full_cky_crf.py diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py index 7bc525de..b9edcd3e 100644 --- a/torch_struct/distributions.py +++ b/torch_struct/distributions.py @@ -7,6 +7,7 @@ from .alignment import Alignment from .deptree import DepTree, deptree_nonproj, deptree_part from .cky_crf import CKY_CRF +from .full_cky_crf import Full_CKY_CRF from .semirings import ( LogSemiring, MaxSemiring, @@ -19,7 +20,6 @@ ) - class StructDistribution(Distribution): r""" Base structured distribution class. @@ -69,7 +69,6 @@ def log_prob(self, value): batch_dims=batch_dims, ) - return v - self.partition @lazy_property @@ -128,9 +127,7 @@ def kmax(self, k): kmax (*k x batch_shape*) """ with torch.enable_grad(): - return self._struct(KMaxSemiring(k)).sum( - self.log_potentials, self.lengths, _raw=True - ) + return self._struct(KMaxSemiring(k)).sum(self.log_potentials, self.lengths, _raw=True) def topk(self, k): r""" @@ -140,9 +137,7 @@ def topk(self, k): kmax (*k x batch_shape x event_shape*) """ with torch.enable_grad(): - return self._struct(KMaxSemiring(k)).marginals( - self.log_potentials, self.lengths, _raw=True - ) + return self._struct(KMaxSemiring(k)).marginals(self.log_potentials, self.lengths, _raw=True) @lazy_property def mode(self): @@ -166,10 +161,8 @@ def marginals(self): def count(self): "Compute the log-partition function." 
ones = torch.ones_like(self.log_potentials) - ones[self.log_potentials.eq(-float('inf'))] = 0 - return self._struct(StdSemiring).sum( - ones, self.lengths - ) + ones[self.log_potentials.eq(-float("inf"))] = 0 + return self._struct(StdSemiring).sum(ones, self.lengths) # @constraints.dependent_property # def support(self): @@ -199,9 +192,7 @@ def sample(self, sample_shape=torch.Size()): samples = [] for k in range(nsamples): if k % 10 == 0: - sample = self._struct(MultiSampledSemiring).marginals( - self.log_potentials, lengths=self.lengths - ) + sample = self._struct(MultiSampledSemiring).marginals(self.log_potentials, lengths=self.lengths) sample = sample.detach() tmp_sample = MultiSampledSemiring.to_discrete(sample, (k % 10) + 1) samples.append(tmp_sample) @@ -222,9 +213,7 @@ def enumerate_support(self, expand=True): Returns: (enum, enum_lengths) - (*tuple cardinality x batch_shape x event_shape*) """ - _, _, edges, enum_lengths = self._struct().enumerate( - self.log_potentials, self.lengths - ) + _, _, edges, enum_lengths = self._struct().enumerate(self.log_potentials, self.lengths) # if expand: # edges = edges.unsqueeze(1).expand(edges.shape[:1] + self.batch_shape[:1] + edges.shape[1:]) return edges, enum_lengths @@ -295,9 +284,7 @@ def __init__(self, log_potentials, local=False, lengths=None, max_gap=None): super().__init__(log_potentials, lengths) def _struct(self, sr=None): - return self.struct( - sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap - ) + return self.struct(sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap) class HMM(StructDistribution): @@ -379,7 +366,6 @@ def __init__(self, log_potentials, lengths=None, args={}, multiroot=True): setattr(self.struct, "multiroot", multiroot) - class TreeCRF(StructDistribution): r""" Represents a 0th-order span parser with NT nonterminals. Implemented using a @@ -406,6 +392,22 @@ class TreeCRF(StructDistribution): struct = CKY_CRF +class FullTreeCRF(StructDistribution): + r""" + Represents a 1st-order span parser with NT nonterminals. Implemented using a vectorized inside algorithm. + For a mathematical description see: + * Inside-Outside Algorithm, by Michael Collins: http://www.cs.columbia.edu/~mcollins/io.pdf + + Parameters: + log_potentials (tensor) : event_shape (*N x N x N x NT x NT x NT*), e.g. + :math:`\phi(i, j, k, A_i^j \rightarrow B_i^k C_{k+1}^j)` + lengths (long tensor) : batch shape integers for length masking. 
+ + Compact representation: *N x N x N xNT x NT x NT* long tensor (Same) + """ + struct = Full_CKY_CRF + + class SentCFG(StructDistribution): """ Represents a full generative context-free grammar with @@ -435,9 +437,7 @@ def __init__(self, log_potentials, lengths=None): event_shape = log_potentials[0].shape[1:] self.log_potentials = log_potentials self.lengths = lengths - super(StructDistribution, self).__init__( - batch_shape=batch_shape, event_shape=event_shape - ) + super(StructDistribution, self).__init__(batch_shape=batch_shape, event_shape=event_shape) class NonProjectiveDependencyCRF(StructDistribution): diff --git a/torch_struct/full_cky_crf.py b/torch_struct/full_cky_crf.py new file mode 100644 index 00000000..16347c60 --- /dev/null +++ b/torch_struct/full_cky_crf.py @@ -0,0 +1,104 @@ +import torch +from .helpers import _Struct, Chart +from tqdm import tqdm + +A, B = 0, 1 + + +class Full_CKY_CRF(_Struct): + def _check_potentials(self, edge, lengths=None): + batch, N, N1, N2, NT, NT1, NT2 = self._get_dimension(edge) + assert ( + N == N1 == N2 and NT == NT1 == NT2 + ), f"Want N:{N} == N1:{N1} == N2:{N2} and NT:{NT} == NT1:{NT1} == NT2:{NT2}" + edge = self.semiring.convert(edge) + semiring_shape = edge.shape[:-7] + if lengths is None: + lengths = torch.LongTensor([N] * batch).to(edge.device) + + return edge, semiring_shape, batch, N, NT, lengths + + def _dp(self, scores, lengths=None, force_grad=False, cache=True): + sr = self.semiring + + # Scores.shape = *sshape, B, N, N, N, NT, NT, NT + # w/ semantics [ *semiring stuff, b, i, j, k, A, B, C] + # where b is batch index, i is left endpoint, j is right endpoint, k is splitpoint, with rule A -> B C + scores, sshape, batch, N, NT, lengths = self._check_potentials(scores, lengths) + sshape, sdims = list(sshape), list(range(len(sshape))) # usually [0] + S, b = len(sdims), batch + + # Initialize data structs + LEFT, RIGHT = 0, 1 + L_DIM, R_DIM = S + 1, S + 2 # one and two to the right of the batch dim + + # Initialize the base cases with scores from diagonal i=j=k, A=B=C + term_scores = ( + scores.diagonal(0, L_DIM, R_DIM) # diag i,j now at dim -1 + .diagonal(0, L_DIM, -1) # diag of k with that gives i=j=k, now at dim -1 + .diagonal(0, -4, -3) # diag of A, B, now at dim -1, ijk moves to -2 + .diagonal(0, -3, -1) # diag of C with that gives A=B=C + ) + assert term_scores.shape[S + 1 :] == (N, NT), f"{term_scores.shape[S + 1 :]} == {(N, NT)}" + alpha_left = term_scores + alpha_right = term_scores + alphas = [[alpha_left], [alpha_right]] + + # Run vectorized inside alg + for w in range(1, N): + # Scores + # What we want is a tensor with: + # shape: *sshape, batch, (N-w), NT, w, NT, NT + # w/ semantics: [...batch, (i,j=i+w), A, k, B, C] + # where (i,j=i+w) means the diagonal of trees nodes with width w + # Shape: *sshape, batch, N, NT, NT, NT, (N-w) w/ semantics [ ...batch, k, A, B, C, (i,j=i+w)] + score = scores.diagonal(w, L_DIM, R_DIM) # get diagonal scores + + score = score.permute(sdims + [-6, -1, -4, -5, -3, -2]) # move diag (-1) dim and head NT (-4) dim to front + score = score[..., :w, :, :] # remove illegal splitpoints + assert score.shape[S:] == (batch, N - w, NT, w, NT, NT), f"{score.shape[S:]} == {(b, N-w, NT, w, NT, NT)}" + + # Sums of left subtrees + # Shape: *sshape, batch, (N-w), w, NT + # where L[..., i, d, B] is the sum of subtrees up to (i,j=(i+d),B) + left = slice(None, N - w) # left indices + L = torch.stack(alphas[LEFT][:w], dim=-2)[..., left, :, :] + + # Sums of right subtrees + # Shape: *sshape, batch, (N-w), w, NT + # 
where R[..., h, d, C] is the sum of subtrees up to (i=(N-h-d),j=(N-h),C) + right = slice(w, None) # right indices + R = torch.stack(list(reversed(alphas[RIGHT][:w])), dim=-2)[..., right, :, :] + + # Broadcast them both to match missing dims in score + # Left B is duplicated for all head and right symbols A C + L_bcast = L.reshape(list(sshape) + [b, N - w, 1, w, NT, 1]).repeat(S * [1] + [1, 1, NT, 1, 1, NT]) + # Right C is duplicated for all head and left symbols A B + R_bcast = R.reshape(list(sshape) + [b, N - w, 1, w, 1, NT]).repeat(S * [1] + [1, 1, NT, 1, NT, 1]) + assert score.shape == L_bcast.shape == R_bcast.shape == tuple(list(sshape) + [b, N - w, NT, w, NT, NT]) + + # Now multiply all the scores and sum over k, B, C dimensions (the last three dims) + assert sr.times(score, L_bcast, R_bcast).shape == tuple(list(sshape) + [b, N - w, NT, w, NT, NT]) + sum_prod_w = sr.sum(sr.sum(sr.sum(sr.times(score, L_bcast, R_bcast)))) + assert sum_prod_w.shape[S:] == (b, N - w, NT), f"{sum_prod_w.shape[S:]} == {(b,N-w, NT)}" + + pad = sr.zero_(torch.ones(sshape + [b, w, NT]).to(sum_prod_w.device)) + sum_prod_w_left = torch.cat([sum_prod_w, pad], dim=-2) + sum_prod_w_right = torch.cat([pad, sum_prod_w], dim=-2) + alphas[LEFT].append(sum_prod_w_left) + alphas[RIGHT].append(sum_prod_w_right) + + final = sr.sum(torch.stack(alphas[LEFT], dim=-2))[..., 0, :] # sum out root symbol + log_Z = final[:, torch.arange(batch), lengths - 1] + return log_Z, [scores], alphas + + @staticmethod + def _rand(): + batch = torch.randint(2, 5, (1,)) + N = torch.randint(2, 5, (1,)) + NT = torch.randint(2, 5, (1,)) + scores = torch.rand(batch, N, N, N, NT, NT, NT) + return scores, (batch.item(), N.item()) + + def enumerate(self, scores, lengths=None): + raise NotImplementedError \ No newline at end of file diff --git a/torch_struct/helpers.py b/torch_struct/helpers.py index bdad6678..174e4c20 100644 --- a/torch_struct/helpers.py +++ b/torch_struct/helpers.py @@ -35,11 +35,7 @@ def backward(ctx, grad_output): class Chart: def __init__(self, size, potentials, semiring, cache=True): self.data = semiring.zero_( - torch.zeros( - *((semiring.size(),) + size), - dtype=potentials.dtype, - device=potentials.device - ) + torch.zeros(*((semiring.size(),) + size), dtype=potentials.dtype, device=potentials.device) ) self.grad = self.data.detach().clone().fill_(0.0) self.cache = cache @@ -95,11 +91,7 @@ def _make_chart(self, N, size, potentials, force_grad=False): return [ ( self.semiring.zero_( - torch.zeros( - *((self.semiring.size(),) + size), - dtype=potentials.dtype, - device=potentials.device - ) + torch.zeros(*((self.semiring.size(),) + size), dtype=potentials.dtype, device=potentials.device) ).requires_grad_(force_grad and not potentials.requires_grad) ) for _ in range(N) @@ -117,11 +109,7 @@ def sum(self, edge, lengths=None, _autograd=True, _raw=False): v: b tensor of total sum """ - if ( - _autograd - or self.semiring is not LogSemiring - or not hasattr(self, "_dp_backward") - ): + if _autograd or self.semiring is not LogSemiring or not hasattr(self, "_dp_backward"): v = self._dp(edge, lengths)[0] if _raw: @@ -139,9 +127,7 @@ def forward(ctx, input): @staticmethod def backward(ctx, grad_v): marginals = self._dp_backward(edge, lengths, alpha) - return marginals.mul( - grad_v.view((grad_v.shape[0],) + tuple([1] * marginals.dim())) - ) + return marginals.mul(grad_v.view((grad_v.shape[0],) + tuple([1] * marginals.dim()))) return DPManual.apply(edge) @@ -156,35 +142,28 @@ def marginals(self, edge, lengths=None, _autograd=True, 
_raw=False): marginals: b x (N-1) x C x C table """ - if ( - _autograd - or self.semiring is not LogSemiring - or not hasattr(self, "_dp_backward") - ): - v, edges, _ = self._dp( - edge, lengths=lengths, force_grad=True, cache=not _raw - ) - if _raw: - all_m = [] - for k in range(v.shape[0]): - obj = v[k].sum(dim=0) - - marg = torch.autograd.grad( - obj, - edges, - create_graph=True, - only_inputs=True, - allow_unused=False, - ) - all_m.append(self.semiring.unconvert(self._arrange_marginals(marg))) - return torch.stack(all_m, dim=0) - else: - obj = self.semiring.unconvert(v).sum(dim=0) - marg = torch.autograd.grad( - obj, edges, create_graph=True, only_inputs=True, allow_unused=False - ) - a_m = self._arrange_marginals(marg) - return self.semiring.unconvert(a_m) + if _autograd or self.semiring is not LogSemiring or not hasattr(self, "_dp_backward"): + with torch.enable_grad(): # allows marginals even when input tensors don't need grad + v, edges, _ = self._dp(edge, lengths=lengths, force_grad=True, cache=not _raw) + if _raw: + all_m = [] + for k in range(v.shape[0]): + obj = v[k].sum(dim=0) + + marg = torch.autograd.grad( + obj, + edges, + create_graph=True, + only_inputs=True, + allow_unused=False, + ) + all_m.append(self.semiring.unconvert(self._arrange_marginals(marg))) + return torch.stack(all_m, dim=0) + else: + obj = self.semiring.unconvert(v).sum(dim=0) + marg = torch.autograd.grad(obj, edges, create_graph=True, only_inputs=True, allow_unused=False) + a_m = self._arrange_marginals(marg) + return self.semiring.unconvert(a_m) else: v, _, alpha = self._dp(edge, lengths=lengths, force_grad=True) return self._dp_backward(edge, lengths, alpha) From 93ebb94496e133ca615b16589bcfb589383bb62d Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 5 Dec 2020 20:29:43 +0000 Subject: [PATCH 2/3] fix formatting clobber --- torch_struct/helpers.py | 38 ++++++++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/torch_struct/helpers.py b/torch_struct/helpers.py index 174e4c20..e41f51e2 100644 --- a/torch_struct/helpers.py +++ b/torch_struct/helpers.py @@ -35,7 +35,11 @@ def backward(ctx, grad_output): class Chart: def __init__(self, size, potentials, semiring, cache=True): self.data = semiring.zero_( - torch.zeros(*((semiring.size(),) + size), dtype=potentials.dtype, device=potentials.device) + torch.zeros( + *((semiring.size(),) + size), + dtype=potentials.dtype, + device=potentials.device + ) ) self.grad = self.data.detach().clone().fill_(0.0) self.cache = cache @@ -91,7 +95,11 @@ def _make_chart(self, N, size, potentials, force_grad=False): return [ ( self.semiring.zero_( - torch.zeros(*((self.semiring.size(),) + size), dtype=potentials.dtype, device=potentials.device) + torch.zeros( + *((self.semiring.size(),) + size), + dtype=potentials.dtype, + device=potentials.device + ) ).requires_grad_(force_grad and not potentials.requires_grad) ) for _ in range(N) @@ -109,7 +117,11 @@ def sum(self, edge, lengths=None, _autograd=True, _raw=False): v: b tensor of total sum """ - if _autograd or self.semiring is not LogSemiring or not hasattr(self, "_dp_backward"): + if ( + _autograd + or self.semiring is not LogSemiring + or not hasattr(self, "_dp_backward") + ): v = self._dp(edge, lengths)[0] if _raw: @@ -127,7 +139,9 @@ def forward(ctx, input): @staticmethod def backward(ctx, grad_v): marginals = self._dp_backward(edge, lengths, alpha) - return marginals.mul(grad_v.view((grad_v.shape[0],) + tuple([1] * marginals.dim()))) + return marginals.mul( + 
grad_v.view((grad_v.shape[0],) + tuple([1] * marginals.dim())) + ) return DPManual.apply(edge) @@ -142,9 +156,15 @@ def marginals(self, edge, lengths=None, _autograd=True, _raw=False): marginals: b x (N-1) x C x C table """ - if _autograd or self.semiring is not LogSemiring or not hasattr(self, "_dp_backward"): + if ( + _autograd + or self.semiring is not LogSemiring + or not hasattr(self, "_dp_backward") + ): with torch.enable_grad(): # allows marginals even when input tensors don't need grad - v, edges, _ = self._dp(edge, lengths=lengths, force_grad=True, cache=not _raw) + v, edges, _ = self._dp( + edge, lengths=lengths, force_grad=True, cache=not _raw + ) if _raw: all_m = [] for k in range(v.shape[0]): @@ -161,7 +181,9 @@ def marginals(self, edge, lengths=None, _autograd=True, _raw=False): return torch.stack(all_m, dim=0) else: obj = self.semiring.unconvert(v).sum(dim=0) - marg = torch.autograd.grad(obj, edges, create_graph=True, only_inputs=True, allow_unused=False) + marg = torch.autograd.grad( + obj, edges, create_graph=True, only_inputs=True, allow_unused=False + ) a_m = self._arrange_marginals(marg) return self.semiring.unconvert(a_m) else: @@ -177,4 +199,4 @@ def from_parts(spans): return spans, None def _arrange_marginals(self, marg): - return marg[0] + return marg[0] \ No newline at end of file From a105c097281c86e5fd9d026f547011a4aee349e8 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 5 Dec 2020 20:35:04 +0000 Subject: [PATCH 3/3] fix formatting clobber --- torch_struct/distributions.py | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py index b9edcd3e..b80cd91a 100644 --- a/torch_struct/distributions.py +++ b/torch_struct/distributions.py @@ -90,7 +90,9 @@ def cross_entropy(self, other): cross entropy (*batch_shape*) """ - return self._struct(CrossEntropySemiring).sum([self.log_potentials, other.log_potentials], self.lengths) + return self._struct(CrossEntropySemiring).sum( + [self.log_potentials, other.log_potentials], self.lengths + ) def kl(self, other): """ @@ -99,7 +101,9 @@ def kl(self, other): Returns: cross entropy (*batch_shape*) """ - return self._struct(KLDivergenceSemiring).sum([self.log_potentials, other.log_potentials], self.lengths) + return self._struct(KLDivergenceSemiring).sum( + [self.log_potentials, other.log_potentials], self.lengths + ) @lazy_property def max(self): @@ -127,7 +131,9 @@ def kmax(self, k): kmax (*k x batch_shape*) """ with torch.enable_grad(): - return self._struct(KMaxSemiring(k)).sum(self.log_potentials, self.lengths, _raw=True) + return self._struct(KMaxSemiring(k)).sum( + self.log_potentials, self.lengths, _raw=True + ) def topk(self, k): r""" @@ -137,7 +143,9 @@ def topk(self, k): kmax (*k x batch_shape x event_shape*) """ with torch.enable_grad(): - return self._struct(KMaxSemiring(k)).marginals(self.log_potentials, self.lengths, _raw=True) + return self._struct(KMaxSemiring(k)).marginals( + self.log_potentials, self.lengths, _raw=True + ) @lazy_property def mode(self): @@ -192,7 +200,9 @@ def sample(self, sample_shape=torch.Size()): samples = [] for k in range(nsamples): if k % 10 == 0: - sample = self._struct(MultiSampledSemiring).marginals(self.log_potentials, lengths=self.lengths) + sample = self._struct(MultiSampledSemiring).marginals( + self.log_potentials, lengths=self.lengths + ) sample = sample.detach() tmp_sample = MultiSampledSemiring.to_discrete(sample, (k % 10) + 1) samples.append(tmp_sample) @@ -213,7 +223,9 @@ 
def enumerate_support(self, expand=True): Returns: (enum, enum_lengths) - (*tuple cardinality x batch_shape x event_shape*) """ - _, _, edges, enum_lengths = self._struct().enumerate(self.log_potentials, self.lengths) + _, _, edges, enum_lengths = self._struct().enumerate( + self.log_potentials, self.lengths + ) # if expand: # edges = edges.unsqueeze(1).expand(edges.shape[:1] + self.batch_shape[:1] + edges.shape[1:]) return edges, enum_lengths @@ -284,7 +296,9 @@ def __init__(self, log_potentials, local=False, lengths=None, max_gap=None): super().__init__(log_potentials, lengths) def _struct(self, sr=None): - return self.struct(sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap) + return self.struct( + sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap + ) class HMM(StructDistribution): @@ -437,7 +451,9 @@ def __init__(self, log_potentials, lengths=None): event_shape = log_potentials[0].shape[1:] self.log_potentials = log_potentials self.lengths = lengths - super(StructDistribution, self).__init__(batch_shape=batch_shape, event_shape=event_shape) + super(StructDistribution, self).__init__( + batch_shape=batch_shape, event_shape=event_shape + ) class NonProjectiveDependencyCRF(StructDistribution):
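
Notes (usage sketches added alongside the series; not part of the patches above).

A minimal example of the new FullTreeCRF distribution. It is defined in torch_struct/distributions.py by patch 1; whether it is also re-exported from the package's top-level __init__ is not shown in this series, so the import below goes through the distributions module. Shapes follow the docstring: one log-potential per span (i, j), split point k and rule A -> B C.

import torch
from torch_struct.distributions import FullTreeCRF

batch, N, NT = 2, 5, 3
# event shape N x N x N x NT x NT x NT, i.e. phi(i, j, k, A_i^j -> B_i^k C_{k+1}^j)
log_potentials = torch.randn(batch, N, N, N, NT, NT, NT)
lengths = torch.tensor([5, 4])     # second sentence is shorter than N

dist = FullTreeCRF(log_potentials, lengths=lengths)
print(dist.partition.shape)        # (batch,) log-partition from the inside pass
print(dist.marginals.shape)        # same shape as log_potentials: span/split/rule marginals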
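
The torch.enable_grad() wrapper added to _Struct.marginals is what the second half of the subject line refers to: marginals are obtained by differentiating the (semiring) log-partition with respect to the potentials, so a graph must be built even when the caller's tensors were created with requires_grad=False or the call happens inside torch.no_grad(). A small sketch of the intended call pattern, using LinearChainCRF since every StructDistribution goes through the same code path:

import torch
from torch_struct import LinearChainCRF

log_potentials = torch.randn(2, 6, 3, 3)   # batch=2, N-1=6, C=3; requires_grad is False
with torch.no_grad():                      # e.g. an evaluation loop
    dist = LinearChainCRF(log_potentials)
    m = dist.marginals                     # the enable_grad wrapper lets the inside pass
                                           # build a graph for the autograd.grad call
print(m.shape)                             # torch.Size([2, 6, 3, 3])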
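
For the new inside pass itself, a deliberately naive reference may help when reading the vectorized _dp in full_cky_crf.py. The helper below is hypothetical (not part of the patch): it runs the inside recursion implied by the FullTreeCRF docstring and by the base case used in _dp (the diagonal i = j = k, A = B = C), reading the third index of the potentials as the absolute split point k, i.e. beta[i, i, A] = phi[i, i, i, A, A, A] and beta[i, j, A] = logsumexp over k, B, C of phi[i, j, k, A, B, C] + beta[i, k, B] + beta[k+1, j, C]. Comparing it to Full_CKY_CRF on tiny inputs is a cheap way to confirm that the vectorized code reads the k index the same way.

import torch
from torch_struct.full_cky_crf import Full_CKY_CRF
from torch_struct.semirings import LogSemiring


def naive_inside(phi):
    """O(N^3 NT^3) inside pass for one sentence; phi: (N, N, N, NT, NT, NT) in log space."""
    N, _, _, NT, _, _ = phi.shape
    beta = torch.full((N, N, NT), float("-inf"))
    for i in range(N):                       # width-0 base case: i = j = k, A = B = C
        for A in range(NT):
            beta[i, i, A] = phi[i, i, i, A, A, A]
    for w in range(1, N):                    # widths, smallest first
        for i in range(N - w):
            j = i + w
            for A in range(NT):
                terms = [
                    phi[i, j, k, A, B, C] + beta[i, k, B] + beta[k + 1, j, C]
                    for k in range(i, j)     # absolute split points
                    for B in range(NT)
                    for C in range(NT)
                ]
                beta[i, j, A] = torch.logsumexp(torch.stack(terms), dim=0)
    return torch.logsumexp(beta[0, N - 1], dim=0)   # sum over the root symbol


phi = torch.randn(1, 4, 4, 4, 2, 2, 2)
print(Full_CKY_CRF(LogSemiring).sum(phi))   # vectorized inside pass from patch 1
print(naive_inside(phi[0]))                 # should match if k is indexed absolutely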