# helpers.py (forked from harvardnlp/pytorch-struct)
import torch
import math
from .semirings import LogSemiring
from torch.autograd import Function


class Get(torch.autograd.Function):
    @staticmethod
    def forward(ctx, chart, grad_chart, indices):
        ctx.save_for_backward(grad_chart)
        out = chart[indices]
        ctx.indices = indices
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (grad_chart,) = ctx.saved_tensors
        grad_chart[ctx.indices] += grad_output
        return grad_chart, None, None


class Set(torch.autograd.Function):
    @staticmethod
    def forward(ctx, chart, indices, vals):
        chart[indices] = vals
        ctx.indices = indices
        return chart

    @staticmethod
    def backward(ctx, grad_output):
        z = grad_output[ctx.indices]
        return None, None, z
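
# Note on the pair above (descriptive comment, added for readability):
# Set.forward writes `vals` into `chart[indices]` in place and returns the
# chart; its backward slices the incoming gradient at the same indices and
# routes it back to `vals` only. Get.forward reads `chart[indices]`; its
# backward accumulates the incoming gradient into the separately managed
# `grad_chart` buffer and returns that buffer as the chart's gradient.
# Chart below uses them together (when cache=True) so that in-place chart
# writes still leave a gradient path from later reads back to the written
# values.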


class Chart:
    """Dense chart tensor with a leading semiring dimension.

    When `cache=True`, reads and writes go through the Get/Set autograd
    functions above and gradients are accumulated in `self.grad`.
    """

    def __init__(self, size, potentials, semiring, cache=True):
        self.data = semiring.zero_(
            torch.zeros(
                *((semiring.size(),) + size),
                dtype=potentials.dtype,
                device=potentials.device
            )
        )
        self.grad = self.data.detach().clone().fill_(0.0)
        self.cache = cache

    def __getitem__(self, ind):
        I = slice(None)
        if self.cache:
            return Get.apply(self.data, self.grad, (I, I) + ind)
        else:
            return self.data[(I, I) + ind]

    def __setitem__(self, ind, new):
        I = slice(None)
        if self.cache:
            self.data = Set.apply(self.data, (I, I) + ind, new)
        else:
            self.data[(I, I) + ind] = new

    def get(self, ind):
        return Get.apply(self.data, self.grad, ind)

    def set(self, ind, new):
        self.data = Set.apply(self.data, ind, new)
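
# Indexing example (illustrative comment, derived from the code above): with
# `size=(batch, N)` the underlying tensor has shape
# `(semiring.size(), batch, N)`, and `chart[0,]` reads `data[:, :, 0]`,
# because two `slice(None)` entries are prepended to cover the semiring and
# batch axes. Indices passed to `__getitem__`/`__setitem__` must therefore be
# tuples addressing the dimensions after those two axes.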


class _Struct:
    """Base class for structured models.

    Subclasses provide `_dp(...)`, which is called below as
    `_dp(edge, lengths, force_grad, cache)` and is expected to return a
    `(v, edges, alpha)` triple; `sum` and `marginals` are written against
    that interface.
    """

    def __init__(self, semiring=LogSemiring):
        self.semiring = semiring

    def score(self, potentials, parts, batch_dims=[0]):
        """Score a structure (or batch of structures) `parts` under `potentials`."""
        score = torch.mul(potentials, parts)
        batch = tuple((score.shape[b] for b in batch_dims))
        return self.semiring.prod(score.view(batch + (-1,)))

    def _bin_length(self, length):
        """Return (ceil(log2(length)), next power of two >= length), e.g. 5 -> (3, 8)."""
        log_N = int(math.ceil(math.log(length, 2)))
        bin_N = int(math.pow(2, log_N))
        return log_N, bin_N

    def _get_dimension(self, edge):
        if isinstance(edge, list):
            for t in edge:
                t.requires_grad_(True)
            return edge[0].shape
        else:
            edge.requires_grad_(True)
            return edge.shape

    def _chart(self, size, potentials, force_grad):
        return self._make_chart(1, size, potentials, force_grad)[0]

    def _make_chart(self, N, size, potentials, force_grad=False):
        return [
            (
                self.semiring.zero_(
                    torch.zeros(
                        *((self.semiring.size(),) + size),
                        dtype=potentials.dtype,
                        device=potentials.device
                    )
                ).requires_grad_(force_grad and not potentials.requires_grad)
            )
            for _ in range(N)
        ]

    def sum(self, edge, lengths=None, _autograd=True, _raw=False):
        """
        Compute the (semiring) sum over all structures of the model.

        Parameters:
            edge : generic params (see class)
            lengths: None or b long tensor mask

        Returns:
            v: b tensor of total sum
        """
        if (
            _autograd
            or self.semiring is not LogSemiring
            or not hasattr(self, "_dp_backward")
        ):
            v = self._dp(edge, lengths)[0]
            if _raw:
                return v
            return self.semiring.unconvert(v)
        else:
            # Manual backward path: wrap the DP result in a custom Function
            # whose backward calls _dp_backward directly.
            v, _, alpha = self._dp(edge, lengths, False)

            class DPManual(Function):
                @staticmethod
                def forward(ctx, input):
                    return v

                @staticmethod
                def backward(ctx, grad_v):
                    marginals = self._dp_backward(edge, lengths, alpha)
                    return marginals.mul(
                        grad_v.view((grad_v.shape[0],) + tuple([1] * marginals.dim()))
                    )

            return DPManual.apply(edge)

    def marginals(self, edge, lengths=None, _autograd=True, _raw=False):
        """
        Compute the marginals of a structured model.

        Parameters:
            edge : generic params (see class)
            lengths: None or b long tensor mask

        Returns:
            marginals: b x (N-1) x C x C table
        """
        if (
            _autograd
            or self.semiring is not LogSemiring
            or not hasattr(self, "_dp_backward")
        ):
            # enable_grad allows marginals even when the input tensors
            # themselves do not require grad.
            with torch.enable_grad():
                v, edges, _ = self._dp(
                    edge, lengths=lengths, force_grad=True, cache=not _raw
                )
                if _raw:
                    all_m = []
                    for k in range(v.shape[0]):
                        obj = v[k].sum(dim=0)
                        marg = torch.autograd.grad(
                            obj,
                            edges,
                            create_graph=True,
                            only_inputs=True,
                            allow_unused=False,
                        )
                        all_m.append(
                            self.semiring.unconvert(self._arrange_marginals(marg))
                        )
                    return torch.stack(all_m, dim=0)
                else:
                    obj = self.semiring.unconvert(v).sum(dim=0)
                    marg = torch.autograd.grad(
                        obj, edges, create_graph=True, only_inputs=True, allow_unused=False
                    )
                    a_m = self._arrange_marginals(marg)
                    return self.semiring.unconvert(a_m)
        else:
            v, _, alpha = self._dp(edge, lengths=lengths, force_grad=True)
            return self._dp_backward(edge, lengths, alpha)

    @staticmethod
    def to_parts(spans, extra, lengths=None):
        return spans

    @staticmethod
    def from_parts(spans):
        return spans, None

    def _arrange_marginals(self, marg):
        return marg[0]
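

if __name__ == "__main__":
    # Minimal usage sketch of the cached Chart (an illustrative example added
    # here, not part of the library). Assumptions: the package is importable
    # (run e.g. as `python -m <package>.helpers` so the relative LogSemiring
    # import resolves), and the in-place write pattern in Set behaves as it
    # does in the PyTorch versions this library targets. Shapes (batch=2,
    # N=4) are arbitrary.
    batch, N = 2, 4
    potentials = torch.randn(LogSemiring.size(), batch, N, requires_grad=True)
    chart = Chart((batch, N), potentials, LogSemiring)

    # Write one column through Set, read it back through Get, and check that
    # gradients reach `potentials` via the manually managed grad buffer.
    chart[(0,)] = potentials[:, :, 0]
    col = chart[(0,)]
    col.sum().backward()
    print(potentials.grad[:, :, 0])  # expected: ones at the written column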