Value expectation and 1st order CKY #93

Open
wants to merge 14 commits into base: master
12 changes: 12 additions & 0 deletions tests/test_distributions.py
@@ -76,6 +76,18 @@ def test_simple(data, seed):
dist.kmax(5)
dist.count

val_func = torch.rand(*vals.shape, 10)
E_val = dist.expected_value(val_func)
struct_vals = (
edges.unsqueeze(-1)
.mul(val_func.unsqueeze(0))
.reshape(*edges.shape[:2], -1, val_func.shape[-1])
.sum(2)
)
assert torch.isclose(
E_val, log_probs.exp().unsqueeze(-1).mul(struct_vals).sum(0)
Collaborator:

Just curious. Why not just make this the implementation of expected value? It seems just as good and perhaps more efficient.

Author:

Sorry, maybe I'm confused but isn't this enumerating over all possible structures explicitly?

Collaborator:

Oh sorry, my comment is confusing.

I think a valid way of computing an expectation over any "part-level value" is to first compute the marginals (.marginals()), then do an elementwise mul (.mul), and then sum. Doesn't that give you the same thing as the semiring?
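
(For concreteness, a minimal sketch of that marginals-based route; the sizes and the `vals` tensor below are assumed for illustration.)

import torch
from torch_struct import LinearChainCRF

B, N, C, V = 4, 20, 5, 10                    # assumed sizes: batch, length, classes, value dim
phis = torch.randn(B, N, C, C)               # log potentials, one per transition part
vals = torch.randn(B, N, C, C, V)            # value vector attached to each part

marg = LinearChainCRF(phis).marginals        # part marginals, same shape as phis
E_val = marg.unsqueeze(-1).mul(vals).reshape(B, -1, V).sum(1)   # E_z[f(z)], shape (B, V)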

Author:

Oh wow, I didn't realize this! I just tested it out and it appears to be more efficient for larger structure sizes. I guess this is due to the fast log semiring implementation? I'll update things to use this approach instead.

Collaborator:

Yeah, I think that is right... I haven't thought about this too much, but my guess is that this is just better on GPU hardware since the expectation is batched at the end. But it seems worth understanding when this works. I don't think you can compute entropy this way? (but I might be wrong)

Author:

Makes sense. I also don't think entropy can be done this way -- I just tested it out and the results didn't match the semiring. I will switch to this implementation in the latest commit and get rid of the value semiring.

Fwiw I ran a quick speed comparison you might be interested in:

B, N, C = 4, 200, 10

phis = torch.randn(B,N,C,C).cuda()
vals = torch.randn(B,N,C,C,10).cuda()

Results from running w/ genbmm

%%timeit
LinearChainCRF(phis).expected_value(vals)
>>> 100 loops, best of 3: 6.34 ms per loop

%%timeit
LinearChainCRF(phis).marginals.unsqueeze(-1).mul(vals).reshape(B,-1,vals.shape[-1]).sum(1)
>>> 100 loops, best of 3: 5.64 ms per loop

Results from running w/o genbmm

%%timeit
LinearChainCRF(phis).expected_value(vals)
>>> 100 loops, best of 3: 9.67 ms per loop

%%timeit
LinearChainCRF(phis).marginals.unsqueeze(-1).mul(vals).reshape(B,-1,vals.shape[-1]).sum(1)
>>> 100 loops, best of 3: 8.83 ms per loop

).all(), "Efficient expected value not equal to enumeration"


@given(data(), integers(min_value=1, max_value=20))
@settings(max_examples=50, deadline=None)
82 changes: 76 additions & 6 deletions torch_struct/distributions.py
@@ -7,6 +7,7 @@
from .alignment import Alignment
from .deptree import DepTree, deptree_nonproj, deptree_part
from .cky_crf import CKY_CRF
from .full_cky_crf import Full_CKY_CRF
from .semirings import (
LogSemiring,
MaxSemiring,
@@ -17,6 +18,7 @@
KMaxSemiring,
StdSemiring,
GumbelCRFSemiring,
ValueExpectationSemiring,
)


@@ -179,11 +181,49 @@ def marginals(self):

@lazy_property
def count(self):
"Compute the log-partition function."
"Compute the total number of structures in the CRF support set."
ones = torch.ones_like(self.log_potentials)
ones[self.log_potentials.eq(-float("inf"))] = 0
return self._struct(StdSemiring).sum(ones, self.lengths)

def expected_value(self, values):
"""
Compute expected value for the distribution :math:`E_z[f(z)]`, where f decomposes additively over the factors of p_z.

Params:
Collaborator:

This should be "Parameters:"

* values (*batch_shape x *event_shape, *value_shape): torch.FloatTensor that assigns a value to each part
Collaborator:

Let's put the types in the first parens, and use :class:torch.FloatTensor

of the structure. `values` can have 0 or more trailing dimensions in addition to the `event_shape`,
which allows for computing the expected value of, say, a vector valued function
(or a vector of scalar functions).
Returns:
expected value (*batch_shape, *value_shape)
"""
# Handle value function dimensionality
phi_shape = self.log_potentials.shape
extra_dims = len(values.shape) - len(phi_shape)
if extra_dims:
# Extra dims get flattened and put in front
out_val_shape = values.shape[len(phi_shape) :]
values = values.reshape(*phi_shape, -1)
values = values.permute([-1] + list(range(len(phi_shape))))
k = values.shape[0]
else:
out_val_shape = None
k = 1

# Compute expected value
val = self._struct(ValueExpectationSemiring(k)).sum(
[self.log_potentials, values], self.lengths
)

# Reformat dimensions to match input dimensions
val = val.permute(list(range(1, len(val.shape))) + [0])
if out_val_shape is not None:
val = val.reshape(*val.shape[:-1] + out_val_shape)
else:
val = val.squeeze(-1)
return val
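
A minimal usage sketch of `expected_value` (sizes and tensors assumed for illustration):

import torch
from torch_struct import LinearChainCRF

B, N, C, V = 2, 6, 3, 4                      # assumed batch, length, classes, value dim
phis = torch.randn(B, N, C, C)               # log potentials
vals = torch.randn(B, N, C, C, V)            # vector-valued function on each part

dist = LinearChainCRF(phis)
E_val = dist.expected_value(vals)            # expected value of shape (B, V)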

def gumbel_crf(self, temperature=1.0):
with torch.enable_grad():
st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(
@@ -204,26 +244,30 @@ def partition(self):
"Compute the log-partition function."
return self._struct(LogSemiring).sum(self.log_potentials, self.lengths)

def sample(self, sample_shape=torch.Size()):
def sample(self, sample_shape=torch.Size(), batch_size=10):
r"""
Compute structured samples from the distribution :math:`z \sim p(z)`.

Parameters:
sample_shape (int): number of samples
batch_size (int): number of samples to compute at a time

Returns:
samples (*sample_shape x batch_shape x event_shape*)
"""
assert len(sample_shape) == 1
nsamples = sample_shape[0]
if type(sample_shape) == int:
nsamples = sample_shape
else:
assert len(sample_shape) == 1
nsamples = sample_shape[0]
samples = []
for k in range(nsamples):
if k % 10 == 0:
if k % batch_size == 0:
Collaborator:

Oh yeah, sorry this is my fault. 10 is a global constant. Let's put it on MultiSampledSemiring.

sample = self._struct(MultiSampledSemiring).marginals(
self.log_potentials, lengths=self.lengths
)
sample = sample.detach()
tmp_sample = MultiSampledSemiring.to_discrete(sample, (k % 10) + 1)
tmp_sample = MultiSampledSemiring.to_discrete(sample, (k % batch_size) + 1)
samples.append(tmp_sample)
return torch.stack(samples)
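
A minimal sketch of sampling with the new `batch_size` argument (potentials and sizes assumed):

import torch
from torch_struct import LinearChainCRF

phis = torch.randn(2, 6, 3, 3)               # assumed log potentials
dist = LinearChainCRF(phis)
samples = dist.sample((25,), batch_size=5)   # 25 samples, re-running the forward pass every 5 draws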

@@ -411,6 +455,32 @@ class TreeCRF(StructDistribution):
struct = CKY_CRF


class FullTreeCRF(StructDistribution):
r"""
Represents a 1st-order span parser with NT nonterminals. Implemented using a
fast CKY algorithm.

For a description see:

* Inside-Outside Algorithm, by Michael Collins

Event shape is of the form:

Parameters:
log_potentials (tensor) : event_shape (*N x N x N x NT x NT x NT*), e.g.
:math:`\phi(i, j, k, A_i^j \rightarrow B_i^k C_{k+1}^j)`
lengths (long tensor) : batch shape integers for length masking.

Implementation uses width-batched, forward-pass only

* Parallel Time: :math:`O(N)` parallel merges.
* Forward Memory: :math:`O(N^2)`
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can't be right... isn't the event shape O(N^3) alone?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops yeah that's from modifying the CKYCRF class


Compact representation: *N x N x N x NT x NT x NT* long tensor (Same)
"""
struct = Full_CKY_CRF
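
A minimal construction sketch for this class (shapes follow the docstring above; the sizes are assumed):

import torch
from torch_struct.distributions import FullTreeCRF

B, N, NT = 2, 5, 3                                    # assumed batch size, sentence length, nonterminals
log_potentials = torch.randn(B, N, N, N, NT, NT, NT)  # phi(i, j, k, A -> B C)
dist = FullTreeCRF(log_potentials)
logZ = dist.partition                                 # log-partition computed by the Full_CKY_CRF inside pass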


class SentCFG(StructDistribution):
"""
Represents a full generative context-free grammar with
114 changes: 114 additions & 0 deletions torch_struct/full_cky_crf.py
@@ -0,0 +1,114 @@
import torch
from .helpers import _Struct, Chart
from tqdm import tqdm
Collaborator:

Be sure to run `python setup.py style` to run flake8. It will catch these errors.


A, B = 0, 1


class Full_CKY_CRF(_Struct):
def _check_potentials(self, edge, lengths=None):
batch, N, N1, N2, NT, NT1, NT2 = self._get_dimension(edge)
assert (
N == N1 == N2 and NT == NT1 == NT2
), f"Want N:{N} == N1:{N1} == N2:{N2} and NT:{NT} == NT1:{NT1} == NT2:{NT2}"
edge = self.semiring.convert(edge)
semiring_shape = edge.shape[:-7]
if lengths is None:
lengths = torch.LongTensor([N] * batch).to(edge.device)

return edge, semiring_shape, batch, N, NT, lengths

def logpartition(self, scores, lengths=None, force_grad=False, cache=True):
sr = self.semiring

# Scores.shape = *sshape, B, N, N, N, NT, NT, NT
# w/ semantics [ *semiring stuff, b, i, j, k, A, B, C]
# where b is batch index, i is left endpoint, j is right endpoint, k is splitpoint, with rule A -> B C
scores, sshape, batch, N, NT, lengths = self._check_potentials(scores, lengths)
sshape, sdims = list(sshape), list(range(len(sshape))) # usually [0]
S, b = len(sdims), batch

# Initialize data structs
LEFT, RIGHT = 0, 1
L_DIM, R_DIM = S + 1, S + 2 # one and two to the right of the batch dim

# Initialize the base cases with scores from diagonal i=j=k, A=B=C
term_scores = (
scores.diagonal(0, L_DIM, R_DIM) # diag i,j now at dim -1
.diagonal(0, L_DIM, -1) # diag of k with that gives i=j=k, now at dim -1
.diagonal(0, -4, -3) # diag of A, B, now at dim -1, ijk moves to -2
.diagonal(0, -3, -1) # diag of C with that gives A=B=C
)
assert term_scores.shape[S + 1 :] == (
N,
NT,
), f"{term_scores.shape[S + 1 :]} == {(N, NT)}"
alpha_left = term_scores
alpha_right = term_scores
alphas = [[alpha_left], [alpha_right]]

# Run vectorized inside alg
for w in range(1, N):
# Scores
# What we want is a tensor with:
# shape: *sshape, batch, (N-w), NT, w, NT, NT
# w/ semantics: [...batch, (i,j=i+w), A, k, B, C]
# where (i,j=i+w) means the diagonal of trees nodes with width w
# Shape: *sshape, batch, N, NT, NT, NT, (N-w) w/ semantics [ ...batch, k, A, B, C, (i,j=i+w)]
score = scores.diagonal(w, L_DIM, R_DIM) # get diagonal scores

score = score.permute(
sdims + [-6, -1, -4, -5, -3, -2]
) # move diag (-1) dim and head NT (-4) dim to front
score = score[..., :w, :, :] # remove illegal splitpoints
assert score.shape[S:] == (
batch,
N - w,
NT,
w,
NT,
NT,
), f"{score.shape[S:]} == {(b, N-w, NT, w, NT, NT)}"

# Sums of left subtrees
# Shape: *sshape, batch, (N-w), w, NT
# where L[..., i, d, B] is the sum of subtrees up to (i,j=(i+d),B)
left = slice(None, N - w) # left indices
L = torch.stack(alphas[LEFT][:w], dim=-2)[..., left, :, :]

# Sums of right subtrees
# Shape: *sshape, batch, (N-w), w, NT
# where R[..., h, d, C] is the sum of subtrees up to (i=(N-h-d),j=(N-h),C)
right = slice(w, None) # right indices
R = torch.stack(list(reversed(alphas[RIGHT][:w])), dim=-2)[..., right, :, :]

# Broadcast them both to match missing dims in score
# Left B is duplicated for all head and right symbols A C
L_bcast = L.reshape(list(sshape) + [b, N - w, 1, w, NT, 1])

# Right C is duplicated for all head and left symbols A B
R_bcast = R.reshape(list(sshape) + [b, N - w, 1, w, 1, NT])

# Now multiply all the scores and sum over k, B, C dimensions (the last three dims)
assert sr.times(score, L_bcast, R_bcast).shape == tuple(
list(sshape) + [b, N - w, NT, w, NT, NT]
)
# sum_prod_w = sr.sum(sr.sum(sr.sum(sr.times(score, L_bcast, R_bcast))))
sum_prod_w = sr.sum(sr.times(score, L_bcast, R_bcast).reshape(*score.shape[:-3], -1))
assert sum_prod_w.shape[S:] == (
b,
N - w,
NT,
), f"{sum_prod_w.shape[S:]} == {(b,N-w, NT)}"

pad = sr.zero_(torch.ones(sshape + [b, w, NT]).to(sum_prod_w.device))
sum_prod_w_left = torch.cat([sum_prod_w, pad], dim=-2)
sum_prod_w_right = torch.cat([pad, sum_prod_w], dim=-2)
alphas[LEFT].append(sum_prod_w_left)
alphas[RIGHT].append(sum_prod_w_right)

final = sr.sum(torch.stack(alphas[LEFT], dim=-2))[
..., 0, :
] # sum out root symbol
log_Z = final[:, torch.arange(batch), lengths - 1]
return log_Z, [scores]
74 changes: 53 additions & 21 deletions torch_struct/helpers.py
@@ -24,15 +24,44 @@ def __setitem__(self, ind, new):


class _Struct:
"""`_Struct` is the base class used to represent the graphical structure of a model.

Subclasses should implement a `logpartition` method which computes the partition function (under the standard `_BaseSemiring`).
Different `StructDistribution` methods will instantiate the `_Struct` subclasses.
"""

def __init__(self, semiring=LogSemiring):
self.semiring = semiring

def logpartition(self, scores, lengths=None, force_grad=False):
"""Implement the computation equivalent to computing the log partition constant logZ (if self.semiring == `_BaseSemiring`).

Parameters:
scores (torch.FloatTensor) : log potential scores for each factor of the model. Shape (* x batch size x *event_shape)
lengths (torch.LongTensor) : = None, lengths of batch-padded examples. Shape (* x batch size)
force_grad: bool = False

Returns:
v (torch.Tensor) : the resulting output of the dynamic program
edges (List[torch.Tensor]): the log edge potentials of the model.
Collaborator:

changing this to logpotentials throughout.

When `scores` is already in a log_potential format for the distribution (typical), this will be
[scores], as in `Alignment`, `LinearChain`, `SemiMarkov`, `CKY_CRF`.
An exceptional case is the `CKY` struct, which takes log potential parameters from production rules
for a PCFG, which are by definition independent of position in the sequence.
charts: Optional[List[Chart]] = None, the charts used in computing the dp. They are needed if we want to run the
Collaborator:

Going to remove this for simplicity.

"backward" dynamic program and compute things like marginals w/o autograd.

"""
raise NotImplementedError

def score(self, potentials, parts, batch_dims=[0]):
score = torch.mul(potentials, parts)
"""Score for entire structure is product of potentials for all activated "parts"."""
score = torch.mul(potentials, parts) # mask potentials by activated "parts"
batch = tuple((score.shape[b] for b in batch_dims))
return self.semiring.prod(score.view(batch + (-1,)))

def _bin_length(self, length):
"""Find least upper bound for lengths that is a power of 2. Used in parallel scans."""
log_N = int(math.ceil(math.log(length, 2)))
bin_N = int(math.pow(2, log_N))
return log_N, bin_N
@@ -92,28 +121,31 @@ def marginals(self, logpotentials, lengths=None, _raw=False):
marginals: b x (N-1) x C x C table

"""
v, edges = self.logpartition(logpotentials, lengths=lengths, force_grad=True)
if _raw:
all_m = []
for k in range(v.shape[0]):
obj = v[k].sum(dim=0)

with torch.autograd.enable_grad(): # in case input potentials don't have grads enabled.
Collaborator:

Cool

v, edges = self.logpartition(
logpotentials, lengths=lengths, force_grad=True
)
if _raw:
all_m = []
for k in range(v.shape[0]):
obj = v[k].sum(dim=0)

marg = torch.autograd.grad(
obj,
edges,
create_graph=True,
only_inputs=True,
allow_unused=False,
)
all_m.append(self.semiring.unconvert(self._arrange_marginals(marg)))
return torch.stack(all_m, dim=0)
else:
obj = self.semiring.unconvert(v).sum(dim=0)
marg = torch.autograd.grad(
obj,
edges,
create_graph=True,
only_inputs=True,
allow_unused=False,
obj, edges, create_graph=True, only_inputs=True, allow_unused=False
)
all_m.append(self.semiring.unconvert(self._arrange_marginals(marg)))
return torch.stack(all_m, dim=0)
else:
obj = self.semiring.unconvert(v).sum(dim=0)
marg = torch.autograd.grad(
obj, edges, create_graph=True, only_inputs=True, allow_unused=False
)
a_m = self._arrange_marginals(marg)
return self.semiring.unconvert(a_m)
a_m = self._arrange_marginals(marg)
return self.semiring.unconvert(a_m)

@staticmethod
def to_parts(spans, extra, lengths=None):
2 changes: 1 addition & 1 deletion torch_struct/linearchain.py
@@ -122,7 +122,7 @@ def from_parts(edge):
batch, N_1, C, _ = edge.shape
N = N_1 + 1
labels = torch.zeros(batch, N).long()
on = edge.nonzero()
on = edge.nonzero(as_tuple=False)
for i in range(on.shape[0]):
if on[i][1] == 0:
labels[on[i][0], on[i][1]] = on[i][3]