1st order cky implementation #83

Closed
wants to merge 3 commits into from
34 changes: 25 additions & 9 deletions torch_struct/distributions.py
@@ -7,6 +7,7 @@
from .alignment import Alignment
from .deptree import DepTree, deptree_nonproj, deptree_part
from .cky_crf import CKY_CRF
from .full_cky_crf import Full_CKY_CRF
from .semirings import (
LogSemiring,
MaxSemiring,
@@ -19,7 +20,6 @@
)



class StructDistribution(Distribution):
r"""
Base structured distribution class.
@@ -69,7 +69,6 @@ def log_prob(self, value):
batch_dims=batch_dims,
)


return v - self.partition

@lazy_property
@@ -91,7 +90,9 @@ def cross_entropy(self, other):
cross entropy (*batch_shape*)
"""

return self._struct(CrossEntropySemiring).sum([self.log_potentials, other.log_potentials], self.lengths)
return self._struct(CrossEntropySemiring).sum(
[self.log_potentials, other.log_potentials], self.lengths
)

def kl(self, other):
"""
@@ -100,7 +101,9 @@ def kl(self, other):
Returns:
            kl divergence (*batch_shape*)
"""
return self._struct(KLDivergenceSemiring).sum([self.log_potentials, other.log_potentials], self.lengths)
return self._struct(KLDivergenceSemiring).sum(
[self.log_potentials, other.log_potentials], self.lengths
)
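
For orientation, a quick sketch (not part of this diff) of how cross_entropy and kl are called; TreeCRF stands in here for any pair of distributions with matching event shapes, and the output shapes follow the docstrings above:

    import torch
    from torch_struct import TreeCRF

    # Two distributions over the same event shape: batch x N x N x NT.
    p = TreeCRF(torch.randn(2, 6, 6, 3))
    q = TreeCRF(torch.randn(2, 6, 6, 3))

    print(p.cross_entropy(q).shape)   # (2,)  H(p, q) via CrossEntropySemiring
    print(p.kl(q).shape)              # (2,)  KL(p || q) via KLDivergenceSemiring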

@lazy_property
def max(self):
@@ -166,10 +169,8 @@ def marginals(self):
def count(self):
"Compute the log-partition function."
ones = torch.ones_like(self.log_potentials)
ones[self.log_potentials.eq(-float('inf'))] = 0
return self._struct(StdSemiring).sum(
ones, self.lengths
)
ones[self.log_potentials.eq(-float("inf"))] = 0
return self._struct(StdSemiring).sum(ones, self.lengths)

# @constraints.dependent_property
# def support(self):
@@ -379,7 +380,6 @@ def __init__(self, log_potentials, lengths=None, args={}, multiroot=True):
setattr(self.struct, "multiroot", multiroot)



class TreeCRF(StructDistribution):
r"""
Represents a 0th-order span parser with NT nonterminals. Implemented using a
@@ -406,6 +406,22 @@ class TreeCRF(StructDistribution):
struct = CKY_CRF


class FullTreeCRF(StructDistribution):
r"""
Represents a 1st-order span parser with NT nonterminals. Implemented using a vectorized inside algorithm.
For a mathematical description see:
* Inside-Outside Algorithm, by Michael Collins: http://www.cs.columbia.edu/~mcollins/io.pdf
Parameters:
log_potentials (tensor) : event_shape (*N x N x N x NT x NT x NT*), e.g.
:math:`\phi(i, j, k, A_i^j \rightarrow B_i^k C_{k+1}^j)`
lengths (long tensor) : batch shape integers for length masking.
    Compact representation: *N x N x N x NT x NT x NT* long tensor (Same)
"""
struct = Full_CKY_CRF
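
A minimal usage sketch for the new distribution (not part of the diff), assuming FullTreeCRF is imported directly from torch_struct.distributions; the event shape follows the docstring above:

    import torch
    from torch_struct.distributions import FullTreeCRF

    batch, N, NT = 2, 5, 3
    # phi(i, j, k, A -> B C): one log-potential per span (i, j), split point k, and labeled rule.
    log_potentials = torch.randn(batch, N, N, N, NT, NT, NT, requires_grad=True)

    dist = FullTreeCRF(log_potentials)
    print(dist.partition.shape)   # (batch,): log-partition from the vectorized inside pass
    print(dist.marginals.shape)   # (batch, N, N, N, NT, NT, NT): span/rule marginals via autograd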


class SentCFG(StructDistribution):
"""
Represents a full generative context-free grammar with
104 changes: 104 additions & 0 deletions torch_struct/full_cky_crf.py
@@ -0,0 +1,104 @@
import torch
from .helpers import _Struct, Chart
from tqdm import tqdm

A, B = 0, 1


class Full_CKY_CRF(_Struct):
def _check_potentials(self, edge, lengths=None):
batch, N, N1, N2, NT, NT1, NT2 = self._get_dimension(edge)
assert (
N == N1 == N2 and NT == NT1 == NT2
), f"Want N:{N} == N1:{N1} == N2:{N2} and NT:{NT} == NT1:{NT1} == NT2:{NT2}"
edge = self.semiring.convert(edge)
semiring_shape = edge.shape[:-7]
if lengths is None:
lengths = torch.LongTensor([N] * batch).to(edge.device)

return edge, semiring_shape, batch, N, NT, lengths

def _dp(self, scores, lengths=None, force_grad=False, cache=True):
sr = self.semiring

# Scores.shape = *sshape, B, N, N, N, NT, NT, NT
# w/ semantics [ *semiring stuff, b, i, j, k, A, B, C]
# where b is batch index, i is left endpoint, j is right endpoint, k is splitpoint, with rule A -> B C
scores, sshape, batch, N, NT, lengths = self._check_potentials(scores, lengths)
sshape, sdims = list(sshape), list(range(len(sshape))) # usually [0]
S, b = len(sdims), batch

# Initialize data structs
LEFT, RIGHT = 0, 1
L_DIM, R_DIM = S + 1, S + 2 # one and two to the right of the batch dim

# Initialize the base cases with scores from diagonal i=j=k, A=B=C
term_scores = (
scores.diagonal(0, L_DIM, R_DIM) # diag i,j now at dim -1
.diagonal(0, L_DIM, -1) # diag of k with that gives i=j=k, now at dim -1
.diagonal(0, -4, -3) # diag of A, B, now at dim -1, ijk moves to -2
.diagonal(0, -3, -1) # diag of C with that gives A=B=C
)
assert term_scores.shape[S + 1 :] == (N, NT), f"{term_scores.shape[S + 1 :]} == {(N, NT)}"
alpha_left = term_scores
alpha_right = term_scores
alphas = [[alpha_left], [alpha_right]]

# Run vectorized inside alg
for w in range(1, N):
# Scores
# What we want is a tensor with:
# shape: *sshape, batch, (N-w), NT, w, NT, NT
# w/ semantics: [...batch, (i,j=i+w), A, k, B, C]
# where (i,j=i+w) means the diagonal of trees nodes with width w
# Shape: *sshape, batch, N, NT, NT, NT, (N-w) w/ semantics [ ...batch, k, A, B, C, (i,j=i+w)]
score = scores.diagonal(w, L_DIM, R_DIM) # get diagonal scores

score = score.permute(sdims + [-6, -1, -4, -5, -3, -2]) # move diag (-1) dim and head NT (-4) dim to front
score = score[..., :w, :, :] # remove illegal splitpoints
assert score.shape[S:] == (batch, N - w, NT, w, NT, NT), f"{score.shape[S:]} == {(b, N-w, NT, w, NT, NT)}"

# Sums of left subtrees
# Shape: *sshape, batch, (N-w), w, NT
# where L[..., i, d, B] is the sum of subtrees up to (i,j=(i+d),B)
left = slice(None, N - w) # left indices
L = torch.stack(alphas[LEFT][:w], dim=-2)[..., left, :, :]

# Sums of right subtrees
# Shape: *sshape, batch, (N-w), w, NT
# where R[..., h, d, C] is the sum of subtrees up to (i=(N-h-d),j=(N-h),C)
right = slice(w, None) # right indices
R = torch.stack(list(reversed(alphas[RIGHT][:w])), dim=-2)[..., right, :, :]

# Broadcast them both to match missing dims in score
# Left B is duplicated for all head and right symbols A C
L_bcast = L.reshape(list(sshape) + [b, N - w, 1, w, NT, 1]).repeat(S * [1] + [1, 1, NT, 1, 1, NT])
# Right C is duplicated for all head and left symbols A B
R_bcast = R.reshape(list(sshape) + [b, N - w, 1, w, 1, NT]).repeat(S * [1] + [1, 1, NT, 1, NT, 1])
assert score.shape == L_bcast.shape == R_bcast.shape == tuple(list(sshape) + [b, N - w, NT, w, NT, NT])

# Now multiply all the scores and sum over k, B, C dimensions (the last three dims)
assert sr.times(score, L_bcast, R_bcast).shape == tuple(list(sshape) + [b, N - w, NT, w, NT, NT])
sum_prod_w = sr.sum(sr.sum(sr.sum(sr.times(score, L_bcast, R_bcast))))
assert sum_prod_w.shape[S:] == (b, N - w, NT), f"{sum_prod_w.shape[S:]} == {(b,N-w, NT)}"

pad = sr.zero_(torch.ones(sshape + [b, w, NT]).to(sum_prod_w.device))
sum_prod_w_left = torch.cat([sum_prod_w, pad], dim=-2)
sum_prod_w_right = torch.cat([pad, sum_prod_w], dim=-2)
alphas[LEFT].append(sum_prod_w_left)
alphas[RIGHT].append(sum_prod_w_right)

final = sr.sum(torch.stack(alphas[LEFT], dim=-2))[..., 0, :] # sum out root symbol
log_Z = final[:, torch.arange(batch), lengths - 1]
return log_Z, [scores], alphas

@staticmethod
def _rand():
batch = torch.randint(2, 5, (1,))
N = torch.randint(2, 5, (1,))
NT = torch.randint(2, 5, (1,))
scores = torch.rand(batch, N, N, N, NT, NT, NT)
return scores, (batch.item(), N.item())

def enumerate(self, scores, lengths=None):
raise NotImplementedError
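
The trickiest part of _dp above is the chain of .diagonal calls that builds the base case. Below is a small standalone check (not from the PR), run with no extra semiring dimensions so that S = 0 and L_DIM, R_DIM = 1, 2, showing that the chain picks out exactly the entries phi(i, i, i, A -> A A):

    import torch

    b, N, NT = 2, 4, 3
    scores = torch.randn(b, N, N, N, NT, NT, NT)   # [b, i, j, k, A, B, C]

    # Chained diagonals, as in Full_CKY_CRF._dp's term_scores (with S = 0).
    term = (
        scores.diagonal(0, 1, 2)     # i = j; the shared index moves to the last dim
        .diagonal(0, 1, -1)          # k tied to that index -> i = j = k
        .diagonal(0, -4, -3)         # A = B
        .diagonal(0, -3, -1)         # C tied to that index -> A = B = C
    )
    assert term.shape == (b, N, NT)

    # Reference by explicit indexing: term[b, n, a] == scores[b, n, n, n, a, a, a].
    n_idx = torch.arange(N)[:, None]
    a_idx = torch.arange(NT)[None, :]
    assert torch.equal(term, scores[:, n_idx, n_idx, n_idx, a_idx, a_idx, a_idx])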
47 changes: 24 additions & 23 deletions torch_struct/helpers.py
@@ -161,30 +161,31 @@ def marginals(self, edge, lengths=None, _autograd=True, _raw=False):
or self.semiring is not LogSemiring
or not hasattr(self, "_dp_backward")
):
v, edges, _ = self._dp(
edge, lengths=lengths, force_grad=True, cache=not _raw
)
if _raw:
all_m = []
for k in range(v.shape[0]):
obj = v[k].sum(dim=0)

with torch.enable_grad(): # allows marginals even when input tensors don't need grad
v, edges, _ = self._dp(
edge, lengths=lengths, force_grad=True, cache=not _raw
)
if _raw:
all_m = []
for k in range(v.shape[0]):
obj = v[k].sum(dim=0)

marg = torch.autograd.grad(
obj,
edges,
create_graph=True,
only_inputs=True,
allow_unused=False,
)
all_m.append(self.semiring.unconvert(self._arrange_marginals(marg)))
return torch.stack(all_m, dim=0)
else:
obj = self.semiring.unconvert(v).sum(dim=0)
marg = torch.autograd.grad(
obj,
edges,
create_graph=True,
only_inputs=True,
allow_unused=False,
obj, edges, create_graph=True, only_inputs=True, allow_unused=False
)
all_m.append(self.semiring.unconvert(self._arrange_marginals(marg)))
return torch.stack(all_m, dim=0)
else:
obj = self.semiring.unconvert(v).sum(dim=0)
marg = torch.autograd.grad(
obj, edges, create_graph=True, only_inputs=True, allow_unused=False
)
a_m = self._arrange_marginals(marg)
return self.semiring.unconvert(a_m)
a_m = self._arrange_marginals(marg)
return self.semiring.unconvert(a_m)
else:
v, _, alpha = self._dp(edge, lengths=lengths, force_grad=True)
return self._dp_backward(edge, lengths, alpha)
@@ -198,4 +199,4 @@ def from_parts(spans):
return spans, None

def _arrange_marginals(self, marg):
return marg[0]
return marg[0]
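
Finally, a sketch of one situation the new torch.enable_grad() wrapper is meant to handle, per its inline comment: asking for marginals while the caller sits inside a torch.no_grad() block. Illustrative only; TreeCRF and the shapes here are stand-ins:

    import torch
    from torch_struct import TreeCRF

    log_potentials = torch.randn(2, 6, 6, 4, requires_grad=True)   # batch x N x N x NT

    with torch.no_grad():                  # e.g. an evaluation loop
        dist = TreeCRF(log_potentials)
        marginals = dist.marginals         # _dp re-runs under enable_grad, so autograd.grad still sees a graph

    print(marginals.shape)                 # torch.Size([2, 6, 6, 4])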