-
Notifications
You must be signed in to change notification settings - Fork 92
/
Copy pathdistributions.py
541 lines (390 loc) · 16.8 KB
/
distributions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
import torch
from torch.distributions.distribution import Distribution
from torch.distributions.utils import lazy_property
from .linearchain import LinearChain
from .cky import CKY
from .semimarkov import SemiMarkov
from .alignment import Alignment
from .deptree import DepTree, deptree_nonproj, deptree_part
from .cky_crf import CKY_CRF
from .full_cky_crf import Full_CKY_CRF
from .semirings import (
LogSemiring,
MaxSemiring,
EntropySemiring,
CrossEntropySemiring,
KLDivergenceSemiring,
MultiSampledSemiring,
KMaxSemiring,
StdSemiring,
GumbelCRFSemiring,
)
class StructDistribution(Distribution):
    r"""
    Base structured distribution class.

    Dynamic distribution for length N of structures :math:`p(z)`.

    Implemented based on gradient identities from:

    * Inside-outside and forward-backward algorithms are just backprop :cite:`eisner2016inside`
    * Semiring Parsing :cite:`goodman1999semiring`
    * First-and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first`

    Subclasses provide a ``struct`` attribute (a structure class that is
    instantiated with a semiring) and may override ``_struct`` to pass extra
    options to it.

    Parameters:
        log_potentials (tensor, batch_shape x event_shape) : log-potentials :math:`\phi`
        lengths (long tensor, batch_shape) : integers for length masking
        args (dict, optional) : extra options stored as ``self.args``; a fresh
            empty dict is used when omitted.
    """

    def __init__(self, log_potentials, lengths=None, args=None):
        # Dimension 0 is the batch; all remaining dimensions form the event.
        batch_shape = log_potentials.shape[:1]
        event_shape = log_potentials.shape[1:]
        self.log_potentials = log_potentials
        self.lengths = lengths
        # A literal ``{}`` default would be one dict shared by every instance.
        self.args = {} if args is None else args
        super().__init__(batch_shape=batch_shape, event_shape=event_shape)

    def _new(self, *args, **kwargs):
        """Return ``log_potentials.new(*args, **kwargs)`` (same dtype and device)."""
        return self.log_potentials.new(*args, **kwargs)

    def log_prob(self, value):
        """
        Compute log probability over values :math:`p(z)`.

        Parameters:
            value (tensor): One-hot events (*sample_shape x batch_shape x event_shape*)

        Returns:
            log_probs (*sample_shape x batch_shape*)
        """
        d = value.dim()
        # Every leading dimension that is not part of the event is a batch dim.
        batch_dims = range(d - len(self.event_shape))
        v = self._struct().score(
            self.log_potentials,
            value.type_as(self.log_potentials),
            batch_dims=batch_dims,
        )
        return v - self.partition

    @lazy_property
    def entropy(self):
        """
        Compute entropy for distribution :math:`H[z]`.

        Returns:
            entropy (*batch_shape*)
        """
        return self._struct(EntropySemiring).sum(self.log_potentials, self.lengths)

    def cross_entropy(self, other):
        """
        Compute cross-entropy for distribution p(self) and q(other) :math:`H[p, q]`.

        Parameters:
            other : Comparison distribution

        Returns:
            cross entropy (*batch_shape*)
        """
        return self._struct(CrossEntropySemiring).sum(
            [self.log_potentials, other.log_potentials], self.lengths
        )

    def kl(self, other):
        """
        Compute KL-divergence for distribution p(self) and q(other) :math:`KL[p || q] = H[p, q] - H[p]`.

        Parameters:
            other : Comparison distribution

        Returns:
            kl divergence (*batch_shape*)
        """
        return self._struct(KLDivergenceSemiring).sum(
            [self.log_potentials, other.log_potentials], self.lengths
        )

    @lazy_property
    def max(self):
        r"""
        Compute the max for distribution :math:`\max p(z)`.

        Returns:
            max (*batch_shape*)
        """
        return self._struct(MaxSemiring).sum(self.log_potentials, self.lengths)

    @lazy_property
    def argmax(self):
        r"""
        Compute an argmax for distribution :math:`\arg\max p(z)`.

        Returns:
            argmax (*batch_shape x event_shape*)
        """
        return self._struct(MaxSemiring).marginals(self.log_potentials, self.lengths)

    def kmax(self, k):
        r"""
        Compute the k-max for distribution :math:`k\max p(z)`.

        Parameters:
            k : Number of solutions to return

        Returns:
            kmax (*k x batch_shape*)
        """
        with torch.enable_grad():
            return self._struct(KMaxSemiring(k)).sum(
                self.log_potentials, self.lengths, _raw=True
            )

    def topk(self, k):
        r"""
        Compute the k-argmax for distribution :math:`k\max p(z)`.

        Parameters:
            k : Number of solutions to return

        Returns:
            kmax (*k x batch_shape x event_shape*)
        """
        with torch.enable_grad():
            return self._struct(KMaxSemiring(k)).marginals(
                self.log_potentials, self.lengths, _raw=True
            )

    @lazy_property
    def mode(self):
        """Alias for ``argmax``."""
        return self.argmax

    @lazy_property
    def marginals(self):
        """
        Compute marginals for distribution :math:`p(z_t)` under the log semiring.

        Returns:
            marginals (*batch_shape x event_shape*)
        """
        return self._struct(LogSemiring).marginals(self.log_potentials, self.lengths)

    @lazy_property
    def count(self):
        "Compute the total number of structures in the CRF support set."
        # Parts with -inf potential are excluded from the support (weight 0);
        # every other part counts with weight 1 under the standard semiring.
        ones = torch.ones_like(self.log_potentials)
        ones[self.log_potentials.eq(-float("inf"))] = 0
        return self._struct(StdSemiring).sum(ones, self.lengths)

    def expected_value(self, values):
        """
        Compute expected value for distribution :math:`E_z[f(z)]` where f decomposes additively over the factors of p_z.

        Parameters:
            values (:class: torch.FloatTensor): (*batch_shape x *event_shape x *value_shape), assigns a value to each
                part of the structure. `values` can have 0 or more trailing dimensions in addition to the `event_shape`,
                which allows for computing the expected value of, say, a vector valued function.

        Returns:
            expected value (*batch_shape, *value_shape)
        """
        # For these "part-level" expectations, this can be computed by multiplying the marginals element-wise
        # on the values and summing. This is faster than the semiring because of FastLogSemiring.
        # (w/o genbmm it's about the same.)
        ps = self.marginals
        # Append singleton dims so the marginals broadcast over value_shape.
        ps_bcast = ps.reshape(*ps.shape, *((1,) * (len(values.shape) - len(ps.shape))))
        # Flatten the event dims (keeping batch dim 0 and value_shape) and sum them out.
        return ps_bcast.mul(values).reshape(ps.shape[0], -1, *values.shape[len(ps.shape) :]).sum(1)

    def gumbel_crf(self, temperature=1.0):
        """
        Compute marginals under ``GumbelCRFSemiring(temperature)``.

        NOTE(review): the result name suggests a straight-through Gumbel relaxed
        sample; confirm against the semiring's implementation.

        Parameters:
            temperature (float) : temperature passed to the semiring

        Returns:
            tensor (*batch_shape x event_shape*)
        """
        with torch.enable_grad():
            st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(
                self.log_potentials, self.lengths
            )
            return st_gumbel

    @lazy_property
    def partition(self):
        "Compute the log-partition function."
        return self._struct(LogSemiring).sum(self.log_potentials, self.lengths)

    def sample(self, sample_shape=torch.Size()):
        r"""
        Compute structured samples from the distribution :math:`z \sim p(z)`.

        Parameters:
            sample_shape (int or 1-element torch.Size): number of samples

        Returns:
            samples (*sample_shape x batch_shape x event_shape*)
        """
        batch_size = MultiSampledSemiring.batch_size
        if isinstance(sample_shape, int):
            nsamples = sample_shape
        else:
            # NOTE(review): the default empty torch.Size() fails this check, so
            # callers must pass an explicit sample count.
            assert len(sample_shape) == 1
            nsamples = sample_shape[0]
        samples = []
        for k in range(nsamples):
            # One MultiSampledSemiring pass is reused for ``batch_size``
            # consecutive samples; recompute only when that batch is exhausted.
            if k % batch_size == 0:
                sample = self._struct(MultiSampledSemiring).marginals(
                    self.log_potentials, lengths=self.lengths
                )
                sample = sample.detach()
            # The index handed to to_discrete is 1-based.
            tmp_sample = MultiSampledSemiring.to_discrete(sample, (k % batch_size) + 1)
            samples.append(tmp_sample)
        return torch.stack(samples)

    def to_event(self, sequence, extra, lengths=None):
        "Convert simple representation to event."
        # Forward the caller's lengths (previously they were always dropped).
        return self.struct.to_parts(sequence, extra, lengths=lengths)

    def from_event(self, event):
        "Convert event to simple representation."
        return self.struct.from_parts(event)

    def _struct(self, sr=None):
        """Instantiate ``self.struct`` with semiring ``sr`` (LogSemiring by default)."""
        return self.struct(sr if sr is not None else LogSemiring)
class LinearChainCRF(StructDistribution):
    r"""
    Structured linear-chain CRF over C classes.

    Background reading:

    * An introduction to conditional random fields :cite:`sutton2012introduction`

    Example application:

    * Bidirectional LSTM-CRF Models for Sequence Tagging :cite:`huang2015bidirectional`

    Parameters:
        log_potentials (tensor) : event shape (*(N-1) x C x C*), e.g.
            :math:`\phi(n, z_{n+1}, z_{n})`
        lengths (long tensor) : batch_shape integers for length masking.

    Compact representation: N long tensor in [0, ..., C-1]

    The implementation is a forward-only linear scan:

    * Parallel Time: :math:`O(\log(N))` parallel merges.
    * Forward Memory: :math:`O(N \log(N) C^2)`
    """

    # Structure class instantiated with a semiring by StructDistribution._struct.
    struct = LinearChain
class AlignmentCRF(StructDistribution):
    r"""
    Basic alignment model covering dynamic-time warping, Needleman-Wunsch and
    Smith-Waterman.

    Parameters:
        log_potentials (tensor) : event_shape (*N x M x 3*), e.g.
            :math:`\phi(i, j, op)`.
            Ops are 0 -> j-1, 1 -> i-1,j-1, and 2 -> i-1
        local (bool) : if true computes local alignment (Smith-Waterman),
            otherwise global alignment (Needleman-Wunsch)
        max_gap (int or None) : the maximum gap to allow in the dynamic program
        lengths (long tensor) : batch shape integers for length masking.

    Implemented with convolution plus a linear scan; set ``max_gap`` for long
    sequences.

    * Parallel Time: :math:`O(\log (M + N))` parallel merges.
    * Forward Memory: :math:`O((M+N)^2)`
    """

    struct = Alignment

    def __init__(self, log_potentials, local=False, lengths=None, max_gap=None):
        # Alignment options consumed later by ``_struct``.
        self.max_gap = max_gap
        self.local = local
        super().__init__(log_potentials, lengths)

    def _struct(self, sr=None):
        """Build the Alignment struct with the chosen semiring and alignment options."""
        semiring = LogSemiring if sr is None else sr
        return self.struct(semiring, self.local, max_gap=self.max_gap)
class HMM(StructDistribution):
    r"""
    Hidden-Markov smoothing over C hidden states.

    Parameters:
        transition (tensor): log-probabilities (*C X C*) :math:`p(z_n | z_{n-1})`
        emission (tensor): log-probabilities (*V x C*) :math:`p(x_n | z_n)`
        init (tensor): log-probabilities (*C*) :math:`p(z_1)`
        observations (long tensor): indices (*batch x N*) between [0, V-1]

    Compact representation: N long tensor in [0, ..., C-1]

    Treated as a special case of the linear-chain CRF.
    """

    struct = LinearChain

    def __init__(self, transition, emission, init, observations, lengths=None):
        # Fold the HMM parameters into linear-chain log-potentials.
        super().__init__(
            HMM.struct.hmm(transition, emission, init, observations), lengths
        )
class SemiMarkovCRF(StructDistribution):
    r"""
    Semi-Markov (segmental) CRF with C classes and maximum segment width K.

    Parameters:
        log_potentials : event shape (*N x K x C x C*), e.g.
            :math:`\phi(n, k, z_{n+1}, z_{n})`
        lengths (long tensor) : batch shape integers for length masking.

    Compact representation: N long tensor in [-1, 0, ..., C-1]

    The implementation is a forward-only linear scan:

    * Parallel Time: :math:`O(\log(N))` parallel merges.
    * Forward Memory: :math:`O(N \log(N) C^2 K^2)`
    """

    # Structure class instantiated with a semiring by StructDistribution._struct.
    struct = SemiMarkov
class DependencyCRF(StructDistribution):
    r"""
    Represents a projective dependency CRF.

    Reference:

    * Bilexical grammars and their cubic-time parsing algorithms :cite:`eisner2000bilexical`

    Parameters:
        log_potentials (tensor) : event shape (*N x N*) head, child or (*N x N x L*) head,
            child, labels with arc scores with root scores on diagonal
            e.g. :math:`\phi(i, j)` where :math:`\phi(i, i)` is (root, i).
        lengths (long tensor) : batch shape integers for length masking.
        args (dict, optional) : extra options stored as ``self.args``.
        multiroot (bool) : set as the ``multiroot`` attribute of each DepTree
            struct this distribution builds (presumably allows more than one
            root attachment; confirm in DepTree).

    Compact representation: N long tensor in [0, .. N] (indexing is +1)

    Implementation uses linear-scan, forward-pass only.

    * Parallel Time: :math:`O(N)` parallel merges.
    * Forward Memory: :math:`O(N \log(N) C^2 K^2)`
    """

    def __init__(self, log_potentials, lengths=None, args=None, multiroot=True):
        # Fresh dict per instance instead of a shared mutable default.
        super().__init__(log_potentials, lengths, {} if args is None else args)
        self.struct = DepTree
        self.multiroot = multiroot

    def _struct(self, sr=None):
        """Build a DepTree struct for semiring ``sr`` carrying this instance's multiroot flag."""
        struct = self.struct(sr if sr is not None else LogSemiring)
        # Set the flag on the per-call struct object, not on the DepTree class,
        # so distributions with different multiroot settings don't overwrite
        # each other's configuration.
        struct.multiroot = self.multiroot
        return struct
class TreeCRF(StructDistribution):
    r"""
    0th-order span parser over NT nonterminals, built on a fast CKY algorithm.

    Example usage:

    * A Minimal Span-Based Neural Constituency Parser :cite:`stern2017minimal`

    Parameters:
        log_potentials (tensor) : event_shape (*N x N x NT*), e.g.
            :math:`\phi(i, j, nt)`
        lengths (long tensor) : batch shape integers for length masking.

    The implementation is width-batched and forward-pass only:

    * Parallel Time: :math:`O(N)` parallel merges.
    * Forward Memory: :math:`O(N^2)`

    Compact representation: *N x N x NT* long tensor (Same)
    """

    # Structure class instantiated with a semiring by StructDistribution._struct.
    struct = CKY_CRF
class FullTreeCRF(StructDistribution):
    r"""
    1st-order span parser over NT nonterminals, built on a fast CKY algorithm.

    Description:

    * Inside-Outside Algorithm, by Michael Collins

    Parameters:
        log_potentials (tensor) : event_shape (*N x N x N x NT x NT x NT*), e.g.
            :math:`\phi(i, j, k, A_i^j \rightarrow B_i^k C_{k+1}^j)`
        lengths (long tensor) : batch shape integers for length masking.

    The implementation is width-batched and forward-pass only:

    * Parallel Time: :math:`O(N)` parallel merges.
    * Forward Memory: :math:`O(N^3 NT^3)`

    Compact representation: *N x N x N x NT x NT x NT* long tensor (Same)
    """

    # Structure class instantiated with a semiring by StructDistribution._struct.
    struct = Full_CKY_CRF
class SentCFG(StructDistribution):
    """
    Represents a full generative context-free grammar with
    non-terminals NT and terminals T.

    Parameters:
        log_potentials (tuple) : event tuple with event shapes
            terms (*N x T*)
            rules (*NT x (NT+T) x (NT+T)*)
            root (*NT*)
        lengths (long tensor) : batch shape integers for length masking.

    Implementation uses width-batched, forward-pass only

    * Parallel Time: :math:`O(N)` parallel merges.
    * Forward Memory: :math:`O(N^2 (NT+T))`

    Compact representation: (*N x N x NT*) long tensor
    """

    struct = CKY

    def __init__(self, log_potentials, lengths=None):
        # ``log_potentials`` is a tuple, so StructDistribution.__init__ (which
        # reads ``log_potentials.shape``) is skipped; shapes come from the
        # first tuple element (terms, per the documented order).
        batch_shape = log_potentials[0].shape[:1]
        event_shape = log_potentials[0].shape[1:]
        self.log_potentials = log_potentials
        self.lengths = lengths
        # Provide the ``args`` attribute every other StructDistribution has.
        self.args = {}
        super(StructDistribution, self).__init__(
            batch_shape=batch_shape, event_shape=event_shape
        )
class NonProjectiveDependencyCRF(StructDistribution):
    r"""
    Represents a non-projective dependency CRF.

    For references see:

    * Non-projective dependency parsing using spanning tree algorithms :cite:`mcdonald2005non`
    * Structured prediction models via the matrix-tree theorem :cite:`koo2007structured`

    Parameters:
        log_potentials (tensor) : event shape (*N x N*) head, child with
            arc scores with root scores on diagonal e.g.
            :math:`\phi(i, j)` where :math:`\phi(i, i)` is (root, i).
        lengths (long tensor) : batch shape integers for length masking.
        args (dict, optional) : extra options stored as ``self.args``.
        multiroot (bool) : forwarded to ``deptree_nonproj`` / ``deptree_part``.

    Compact representation: N long tensor in [0, .. N] (indexing is +1)

    Note: Does not currently implement argmax (Chiu-Liu) or sampling; both
    raise ``NotImplementedError``.
    """

    def __init__(self, log_potentials, lengths=None, args=None, multiroot=False):
        # Fresh dict per instance instead of a shared mutable default.
        super().__init__(log_potentials, lengths, {} if args is None else args)
        self.multiroot = multiroot

    @lazy_property
    def marginals(self):
        """
        Compute marginals for distribution :math:`p(z_t)`.

        Algorithm is :math:`O(N^3)` but very fast on batched GPU.

        Returns:
            marginals (*batch_shape x event_shape*)
        """
        return deptree_nonproj(self.log_potentials, self.multiroot, self.lengths)

    def sample(self, sample_shape=torch.Size()):
        """Sampling is not implemented for non-projective trees."""
        raise NotImplementedError()

    @lazy_property
    def partition(self):
        """
        Compute the (log-)partition function.

        Returns:
            partition (*batch_shape*)
        """
        return deptree_part(self.log_potentials, self.multiroot, self.lengths)

    @lazy_property
    def argmax(self):
        """
        Use Chiu-Liu Algorithm. :math:`O(N^2)`

        Raises:
            NotImplementedError: the algorithm is not implemented yet.
        """
        raise NotImplementedError("Chiu-Liu argmax is not implemented for NonProjectiveDependencyCRF")

    @lazy_property
    def entropy(self):
        r"""
        Compute entropy for distribution :math:`H[z]`.

        With :math:`\log p(z) = \phi \cdot z - \log Z` (score minus partition,
        as in ``log_prob``), the entropy is
        :math:`H = \log Z - \sum_t \mu_t \phi_t` where :math:`\mu` are the marginals.

        Returns:
            entropy (*batch_shape*)
        """
        phi = self.log_potentials
        # Impossible arcs (-inf) have zero marginal; zero them first so that
        # 0 * -inf does not produce NaN.
        phi = phi.masked_fill(phi.eq(-float("inf")), 0.0)
        expected_score = (self.marginals * phi).reshape(phi.shape[0], -1).sum(-1)
        return self.partition - expected_score