Value expectation and 1st order CKY #93
base: master
Changes from 10 commits
torch_struct/distributions.py

@@ -7,6 +7,7 @@
 from .alignment import Alignment
 from .deptree import DepTree, deptree_nonproj, deptree_part
 from .cky_crf import CKY_CRF
+from .full_cky_crf import Full_CKY_CRF
 from .semirings import (
     LogSemiring,
     MaxSemiring,

@@ -17,6 +18,7 @@
     KMaxSemiring,
     StdSemiring,
     GumbelCRFSemiring,
+    ValueExpectationSemiring,
 )
@@ -179,11 +181,49 @@ def marginals(self):
     @lazy_property
     def count(self):
-        "Compute the log-partition function."
+        "Compute the total number of structures in the CRF support set."
         ones = torch.ones_like(self.log_potentials)
         ones[self.log_potentials.eq(-float("inf"))] = 0
         return self._struct(StdSemiring).sum(ones, self.lengths)
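For illustration, a quick sanity check of what `count` computes; the `LinearChainCRF` choice and the sizes are assumptions for the example, not part of this diff:

```python
import torch
from torch_struct import LinearChainCRF

# A linear chain over N positions with C labels has C ** N structures.
batch, N, C = 1, 4, 3
dist = LinearChainCRF(torch.randn(batch, N - 1, C, C))
assert int(dist.count.item()) == C ** N  # 81
```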
 
+    def expected_value(self, values):
+        """
+        Compute expected value for distribution :math:`E_z[f(z)]` where f decomposes additively over the factors of p_z.
+
+        Params:
Review comment: This should be "Parameters:"
+            values (*batch_shape x *event_shape, *value_shape): torch.FloatTensor that assigns a value to each part
Review comment: Let's put the types in the first parens, and use :class:
+                of the structure. `values` can have 0 or more trailing dimensions in addition to the `event_shape`,
+                which allows for computing the expected value of, say, a vector-valued function
+                (or a vector of scalar functions).
+
+        Returns:
+            expected value (*batch_shape, *value_shape)
+        """
+        # Handle value function dimensionality
+        phi_shape = self.log_potentials.shape
+        extra_dims = len(values.shape) - len(phi_shape)
+        if extra_dims:
+            # Extra dims get flattened and put in front
+            out_val_shape = values.shape[len(phi_shape) :]
+            values = values.reshape(*phi_shape, -1)
+            values = values.permute([-1] + list(range(len(phi_shape))))
+            k = values.shape[0]
+        else:
+            out_val_shape = None
+            k = 1
+
+        # Compute expected value
+        val = self._struct(ValueExpectationSemiring(k)).sum(
+            [self.log_potentials, values], self.lengths
+        )
+
+        # Reformat dimensions to match input dimensions
+        val = val.permute(list(range(1, len(val.shape))) + [0])
+        if out_val_shape is not None:
+            val = val.reshape(*val.shape[:-1] + out_val_shape)
+        else:
+            val = val.squeeze(-1)
+        return val
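A usage sketch for the new method (sizes are illustrative, and `LinearChainCRF` is just one distribution it applies to):

```python
import torch
from torch_struct import LinearChainCRF

batch, N, C = 2, 6, 4
dist = LinearChainCRF(torch.randn(batch, N - 1, C, C))

# Scalar-valued f: one value per transition part -> one expectation per batch.
values = torch.randn(batch, N - 1, C, C)
e = dist.expected_value(values)          # shape: (batch,)

# Vector-valued f: trailing value dims are carried through.
vec_values = torch.randn(batch, N - 1, C, C, 5)
e_vec = dist.expected_value(vec_values)  # shape: (batch, 5)
```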
 
     def gumbel_crf(self, temperature=1.0):
         with torch.enable_grad():
             st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(
@@ -204,26 +244,30 @@ def partition(self):
         "Compute the log-partition function."
         return self._struct(LogSemiring).sum(self.log_potentials, self.lengths)
 
-    def sample(self, sample_shape=torch.Size()):
+    def sample(self, sample_shape=torch.Size(), batch_size=10):
         r"""
         Compute structured samples from the distribution :math:`z \sim p(z)`.
 
         Parameters:
             sample_shape (int): number of samples
+            batch_size (int): number of samples to compute at a time
 
         Returns:
             samples (*sample_shape x batch_shape x event_shape*)
         """
-        assert len(sample_shape) == 1
-        nsamples = sample_shape[0]
+        if type(sample_shape) == int:
+            nsamples = sample_shape
+        else:
+            assert len(sample_shape) == 1
+            nsamples = sample_shape[0]
         samples = []
         for k in range(nsamples):
-            if k % 10 == 0:
+            if k % batch_size == 0:
Review comment: Oh yeah, sorry this is my fault. 10 is a global constant. Let's put it on MultiSampledSemiring.
                 sample = self._struct(MultiSampledSemiring).marginals(
                     self.log_potentials, lengths=self.lengths
                 )
                 sample = sample.detach()
-            tmp_sample = MultiSampledSemiring.to_discrete(sample, (k % 10) + 1)
+            tmp_sample = MultiSampledSemiring.to_discrete(sample, (k % batch_size) + 1)
             samples.append(tmp_sample)
         return torch.stack(samples)
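A usage sketch of the new signature (sizes are made up); both an int and a 1-element tuple are now accepted for `sample_shape`:

```python
import torch
from torch_struct import LinearChainCRF

dist = LinearChainCRF(torch.randn(2, 5, 3, 3))

# Re-runs the MultiSampledSemiring pass every `batch_size` draws.
s1 = dist.sample((25,), batch_size=5)  # shape: (25, 2, 5, 3, 3)
s2 = dist.sample(8)                    # int sample_shape now works too
```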
 
@@ -411,6 +455,32 @@ class TreeCRF(StructDistribution):
     struct = CKY_CRF
 
 
+class FullTreeCRF(StructDistribution):
+    r"""
+    Represents a 1st-order span parser with NT nonterminals. Implemented using a
+    fast CKY algorithm.
+
+    For a description see:
+
+    * Inside-Outside Algorithm, by Michael Collins
+
+    Event shape is of the form: *N x N x N x NT x NT x NT*
+
+    Parameters:
+        log_potentials (tensor) : event_shape (*N x N x N x NT x NT x NT*), e.g.
+            :math:`\phi(i, j, k, A_i^j \rightarrow B_i^k C_{k+1}^j)`
+        lengths (long tensor) : batch shape integers for length masking.
+
+    Implementation uses width-batched, forward-pass only
+
+    * Parallel Time: :math:`O(N)` parallel merges.
+    * Forward Memory: :math:`O(N^2)`
Review comment: This can't be right... isn't the event shape O(N^3) alone?
Reply: Oops yeah that's from modifying the CKYCRF class
+
+    Compact representation: *N x N x N x NT x NT x NT* long tensor (Same)
+    """
+
+    struct = Full_CKY_CRF
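A construction sketch for the new distribution (sizes are made up, and it assumes `FullTreeCRF` gets exported from the package top level):

```python
import torch
from torch_struct import FullTreeCRF

batch, N, NT = 2, 5, 3
# phi[b, i, j, k, A, B, C]: score of A spanning (i, j), split at k into B, C.
log_potentials = torch.randn(batch, N, N, N, NT, NT, NT)
dist = FullTreeCRF(log_potentials, lengths=torch.tensor([5, 4]))
log_Z = dist.partition      # shape: (batch,)
marginals = dist.marginals  # same shape as log_potentials
```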
 
 
 class SentCFG(StructDistribution):
     """
     Represents a full generative context-free grammar with
torch_struct/full_cky_crf.py

@@ -0,0 +1,114 @@
+import torch
+from .helpers import _Struct, Chart
+from tqdm import tqdm
Review comment: Be sure to run
+
+A, B = 0, 1
+
+
+class Full_CKY_CRF(_Struct):
+    def _check_potentials(self, edge, lengths=None):
+        batch, N, N1, N2, NT, NT1, NT2 = self._get_dimension(edge)
+        assert (
+            N == N1 == N2 and NT == NT1 == NT2
+        ), f"Want N:{N} == N1:{N1} == N2:{N2} and NT:{NT} == NT1:{NT1} == NT2:{NT2}"
+        edge = self.semiring.convert(edge)
+        semiring_shape = edge.shape[:-7]
+        if lengths is None:
+            lengths = torch.LongTensor([N] * batch).to(edge.device)
+
+        return edge, semiring_shape, batch, N, NT, lengths
+
+    def logpartition(self, scores, lengths=None, force_grad=False, cache=True):
+        sr = self.semiring
+
+        # Scores.shape = *sshape, B, N, N, N, NT, NT, NT
+        # w/ semantics [ *semiring stuff, b, i, j, k, A, B, C]
+        # where b is batch index, i is left endpoint, j is right endpoint, k is splitpoint, with rule A -> B C
+        scores, sshape, batch, N, NT, lengths = self._check_potentials(scores, lengths)
+        sshape, sdims = list(sshape), list(range(len(sshape)))  # usually [0]
+        S, b = len(sdims), batch
+
+        # Initialize data structs
+        LEFT, RIGHT = 0, 1
+        L_DIM, R_DIM = S + 1, S + 2  # one and two to the right of the batch dim
+
+        # Initialize the base cases with scores from diagonal i=j=k, A=B=C
+        term_scores = (
+            scores.diagonal(0, L_DIM, R_DIM)  # diag i,j now at dim -1
+            .diagonal(0, L_DIM, -1)  # diag of k with that gives i=j=k, now at dim -1
+            .diagonal(0, -4, -3)  # diag of A, B, now at dim -1, ijk moves to -2
+            .diagonal(0, -3, -1)  # diag of C with that gives A=B=C
+        )
+        assert term_scores.shape[S + 1 :] == (
+            N,
+            NT,
+        ), f"{term_scores.shape[S + 1 :]} == {(N, NT)}"
+        alpha_left = term_scores
+        alpha_right = term_scores
+        alphas = [[alpha_left], [alpha_right]]
+
+        # Run vectorized inside alg
+        for w in range(1, N):
+            # Scores
+            # What we want is a tensor with:
+            #   shape: *sshape, batch, (N-w), NT, w, NT, NT
+            #   w/ semantics: [...batch, (i,j=i+w), A, k, B, C]
+            #   where (i,j=i+w) means the diagonal of tree nodes with width w
+            # Shape: *sshape, batch, N, NT, NT, NT, (N-w) w/ semantics [ ...batch, k, A, B, C, (i,j=i+w)]
+            score = scores.diagonal(w, L_DIM, R_DIM)  # get diagonal scores
+
+            score = score.permute(
+                sdims + [-6, -1, -4, -5, -3, -2]
+            )  # move diag (-1) dim and head NT (-4) dim to front
+            score = score[..., :w, :, :]  # remove illegal splitpoints
+            assert score.shape[S:] == (
+                batch,
+                N - w,
+                NT,
+                w,
+                NT,
+                NT,
+            ), f"{score.shape[S:]} == {(b, N - w, NT, w, NT, NT)}"
+
+            # Sums of left subtrees
+            # Shape: *sshape, batch, (N-w), w, NT
+            # where L[..., i, d, B] is the sum of subtrees up to (i,j=(i+d),B)
+            left = slice(None, N - w)  # left indices
+            L = torch.stack(alphas[LEFT][:w], dim=-2)[..., left, :, :]
+
+            # Sums of right subtrees
+            # Shape: *sshape, batch, (N-w), w, NT
+            # where R[..., h, d, C] is the sum of subtrees up to (i=(N-h-d),j=(N-h),C)
+            right = slice(w, None)  # right indices
+            R = torch.stack(list(reversed(alphas[RIGHT][:w])), dim=-2)[..., right, :, :]
+
+            # Broadcast them both to match missing dims in score
+            # Left B is duplicated for all head and right symbols A C
+            L_bcast = L.reshape(list(sshape) + [b, N - w, 1, w, NT, 1])
+
+            # Right C is duplicated for all head and left symbols A B
+            R_bcast = R.reshape(list(sshape) + [b, N - w, 1, w, 1, NT])
+
+            # Now multiply all the scores and sum over k, B, C dimensions (the last three dims)
+            assert sr.times(score, L_bcast, R_bcast).shape == tuple(
+                list(sshape) + [b, N - w, NT, w, NT, NT]
+            )
+            sum_prod_w = sr.sum(
+                sr.times(score, L_bcast, R_bcast).reshape(*score.shape[:-3], -1)
+            )
+            assert sum_prod_w.shape[S:] == (
+                b,
+                N - w,
+                NT,
+            ), f"{sum_prod_w.shape[S:]} == {(b, N - w, NT)}"
+
+            pad = sr.zero_(torch.ones(sshape + [b, w, NT]).to(sum_prod_w.device))
+            sum_prod_w_left = torch.cat([sum_prod_w, pad], dim=-2)
+            sum_prod_w_right = torch.cat([pad, sum_prod_w], dim=-2)
+            alphas[LEFT].append(sum_prod_w_left)
+            alphas[RIGHT].append(sum_prod_w_right)
+
+        final = sr.sum(torch.stack(alphas[LEFT], dim=-2))[
+            ..., 0, :
+        ]  # sum out root symbol
+        log_Z = final[:, torch.arange(batch), lengths - 1]
+        return log_Z, [scores]
torch_struct/helpers.py

@@ -24,15 +24,44 @@ def __setitem__(self, ind, new):
 
 
 class _Struct:
+    """`_Struct` is a base class used to represent the graphical structure of a model.
+
+    Subclasses should implement a `logpartition` method which computes the partition function (under the standard `_BaseSemiring`).
+    Different `StructDistribution` methods will instantiate the `_Struct` subclasses.
+    """
 
     def __init__(self, semiring=LogSemiring):
         self.semiring = semiring
 
     def logpartition(self, scores, lengths=None, force_grad=False):
+        """Implement computation equivalent to computing the log partition constant logZ (if self.semiring == `_BaseSemiring`).
+
+        Parameters:
+            scores (torch.FloatTensor) : log potential scores for each factor of the model. Shape (* x batch size x *event_shape )
+            lengths (torch.LongTensor) : = None, lengths of batch padded examples. Shape = ( * x batch size )
+            force_grad: bool = False
+
+        Returns:
+            v (torch.Tensor) : the resulting output of the dynamic program
+            edges (List[torch.Tensor]): the log edge potentials of the model.
Review comment: changing this to
+                When `scores` is already in a log_potential format for the distribution (typical), this will be
+                [scores], as in `Alignment`, `LinearChain`, `SemiMarkov`, `CKY_CRF`.
+                An exceptional case is the `CKY` struct, which takes log potential parameters from production rules
+                for a PCFG, which are by definition independent of position in the sequence.
+            charts: Optional[List[Chart]] = None, the charts used in computing the dp. They are needed if we want to run the
Review comment: Going to remove this for simplicity.
+                "backward" dynamic program and compute things like marginals w/o autograd.
+        """
+        raise NotImplementedError
 
     def score(self, potentials, parts, batch_dims=[0]):
-        score = torch.mul(potentials, parts)
+        """Score for entire structure is product of potentials for all activated "parts"."""
+        score = torch.mul(potentials, parts)  # mask potentials by activated "parts"
         batch = tuple((score.shape[b] for b in batch_dims))
         return self.semiring.prod(score.view(batch + (-1,)))
 
     def _bin_length(self, length):
+        """Find least upper bound for lengths that is a power of 2. Used in parallel scans."""
         log_N = int(math.ceil(math.log(length, 2)))
         bin_N = int(math.pow(2, log_N))
         return log_N, bin_N
@@ -92,28 +121,31 @@ def marginals(self, logpotentials, lengths=None, _raw=False):
             marginals: b x (N-1) x C x C table
 
         """
-        v, edges = self.logpartition(logpotentials, lengths=lengths, force_grad=True)
-        if _raw:
-            all_m = []
-            for k in range(v.shape[0]):
-                obj = v[k].sum(dim=0)
-
-                marg = torch.autograd.grad(
-                    obj,
-                    edges,
-                    create_graph=True,
-                    only_inputs=True,
-                    allow_unused=False,
-                )
-                all_m.append(self.semiring.unconvert(self._arrange_marginals(marg)))
-            return torch.stack(all_m, dim=0)
-        else:
-            obj = self.semiring.unconvert(v).sum(dim=0)
-            marg = torch.autograd.grad(
-                obj,
-                edges,
-                create_graph=True,
-                only_inputs=True,
-                allow_unused=False,
-            )
-            a_m = self._arrange_marginals(marg)
-            return self.semiring.unconvert(a_m)
+        with torch.autograd.enable_grad():  # in case input potentials don't have grads enabled.
Review comment: Cool
+            v, edges = self.logpartition(
+                logpotentials, lengths=lengths, force_grad=True
+            )
+            if _raw:
+                all_m = []
+                for k in range(v.shape[0]):
+                    obj = v[k].sum(dim=0)
+
+                    marg = torch.autograd.grad(
+                        obj, edges, create_graph=True, only_inputs=True, allow_unused=False
+                    )
+                    all_m.append(self.semiring.unconvert(self._arrange_marginals(marg)))
+                return torch.stack(all_m, dim=0)
+            else:
+                obj = self.semiring.unconvert(v).sum(dim=0)
+                marg = torch.autograd.grad(
+                    obj, edges, create_graph=True, only_inputs=True, allow_unused=False
+                )
+                a_m = self._arrange_marginals(marg)
+                return self.semiring.unconvert(a_m)
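The identity this relies on, in miniature: for a log-linear model, the gradient of logZ with respect to the log-potentials is exactly the vector of marginals. A standalone check (not torch-struct specific):

```python
import torch

phi = torch.randn(4, requires_grad=True)  # four mutually exclusive structures
log_Z = torch.logsumexp(phi, dim=0)
(marg,) = torch.autograd.grad(log_Z, phi)
assert torch.allclose(marg, phi.softmax(dim=0))
```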
 
     @staticmethod
     def to_parts(spans, extra, lengths=None):
Review comment: Just curious. Why not just make this the implementation of expected value? It seems just as good and perhaps more efficient.

Reply: Sorry, maybe I'm confused, but isn't this enumerating over all possible structures explicitly?

Reply: Oh sorry, my comment is confusing. I think a valid way of computing an expectation over any "part-level value" is to first compute the marginals (.marginals()), then do an elementwise mul (.mul), and then sum. Doesn't that give you the same thing as the semiring?
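A sketch of the marginals-based expectation described above (assuming a `LinearChainCRF` and an additively decomposable f):

```python
import torch
from torch_struct import LinearChainCRF

batch, N, C = 2, 6, 3
dist = LinearChainCRF(torch.randn(batch, N - 1, C, C))
values = torch.randn(batch, N - 1, C, C)

# E_z[f(z)] = sum over parts of p(part) * value(part).
expectation = (dist.marginals * values).reshape(batch, -1).sum(-1)
```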
Reply: Oh wow, I didn't realize this! I just tested it out and it appears to be more efficient for larger structure sizes. I guess this is due to the fast log semiring implementation? I'll update things to use this approach instead.

Reply: Yeah, I think that is right... I haven't thought about this too much, but my guess is that this is just better on GPU hardware since the expectation is batched at the end. But it seems worth understanding when this works. I don't think you can compute entropy this way? (but I might be wrong)

Reply: Makes sense. I also don't think entropy can be done this way -- I just tested it out and the results didn't match the semiring. I will switch to this implementation in the latest commit and get rid of the value semiring.

Fwiw I ran a quick speed comparison you might be interested in:

Results from running w/ genbmm
Results from running w/o genbmm