add incremental linear-chain CRF #119

Open · wants to merge 2 commits into master
5 changes: 5 additions & 0 deletions torch_struct/distributions.py
@@ -2,6 +2,7 @@
from torch.distributions.distribution import Distribution
from torch.distributions.utils import lazy_property
from .linearchain import LinearChain
from .incr_linearchain import IncrementLinearChainFwBw
from .cky import CKY
from .semimarkov import SemiMarkov
from .alignment import Alignment
@@ -240,6 +241,10 @@ def _struct(self, sr=None):
        return self.struct(sr if sr is not None else LogSemiring)


class ARLinearChainCRF(StructDistribution):
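    r"""
    Linear-chain CRF that exposes incremental (prefix) marginals: given edge
    log-potentials for N positions, it yields p(z_i = c | x_{1:t}) for every
    prefix length t and every position i <= t, computed by
    IncrementLinearChainFwBw.
    """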

    struct = IncrementLinearChainFwBw


class LinearChainCRF(StructDistribution):
    r"""
    Represents structured linear-chain CRFs with C classes.
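For orientation, a minimal usage sketch of the new distribution (a hypothetical example, not part of this diff; it assumes ARLinearChainCRF is imported from torch_struct.distributions and uses made-up sizes):

import torch
from torch_struct.distributions import ARLinearChainCRF

batch, N, C = 2, 6, 4
log_potentials = torch.randn(batch, N - 1, C, C)

dist = ARLinearChainCRF(log_potentials)
# marginals[b, t, i, c] is p(z_i = c | x_{1:t}); entries with i > t are zero
marginals = dist.marginals  # shape (batch, N - 1, N - 1, C)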
125 changes: 125 additions & 0 deletions torch_struct/incr_linearchain.py
@@ -0,0 +1,125 @@

import torch
from torch_struct.helpers import _Struct
from einops import repeat
import math

class IncrementLinearChainFwBw(_Struct):
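    r"""
    Incremental forward-backward for linear-chain CRFs.

    ``marginals`` takes edge log-potentials of shape (batch, N - 1, C, C)
    and returns p(z_i = c | x_{1:t}) for every prefix length t and every
    position i <= t.  All prefixes are computed together by a doubling
    scan over an interval chart, using O(log N) sequential batched
    semiring matmuls.
    """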

    def _check_potentials(self, edge, lengths=None):
        batch, N_1, C, C2 = self._get_dimension(edge)
        edge = self.semiring.convert(edge)
        N = N_1 + 1

        if lengths is None:
            lengths = torch.LongTensor([N] * batch).to(edge.device)
        else:
            assert max(lengths) <= N, "Length longer than edge scores"
            assert max(lengths) == N, "One length must be at least N"
        assert C == C2, "Transition shape doesn't match"
        return edge, batch, N, C, lengths

    def marginals(self, log_potentials, lengths=None, force_grad=True):
        semiring = self.semiring
        ssize = semiring.size()
        log_potentials, batch, N, C, lengths = self._check_potentials(
            log_potentials, lengths
        )
        N1 = N - 1
        log_N, bin_N = self._bin_length(N - 1)
        chart = self._chart((batch, N1, C, C), log_potentials, force_grad)

        init = torch.zeros(ssize, batch, N1, C, C).bool().to(log_potentials.device)
        init.diagonal(0, -2, -1).fill_(True)
        chart = semiring.fill(chart, init, semiring.one)

        # Mask out edges past each sequence's length: padded edge slots keep the
        # semiring identity while valid slots receive the true log-potentials.
        big = torch.zeros(
            ssize,
            batch,
            N1,
            C,
            C,
            dtype=log_potentials.dtype,
            device=log_potentials.device,
        )
        big[:, :, :N1] = log_potentials
        c = chart[:, :, :].view(ssize, batch * N1, C, C)
        lp = big[:, :, :].view(ssize, batch * N1, C, C)
        mask = torch.arange(N1).view(1, N1).expand(batch, N1).type_as(c)
        mask = mask >= (lengths - 1).view(batch, 1)
        mask = mask.view(batch * N1, 1, 1).to(lp.device)
        lp.data[:] = semiring.fill(lp.data, mask, semiring.zero)
        c.data[:] = semiring.fill(c.data, ~mask, semiring.zero)

        c[:] = semiring.sum(torch.stack([c.data, lp], dim=-1))

        # initialize interval chart
        interval_chart = self._chart((batch, N1, N1, C, C), log_potentials, force_grad)
        init = torch.zeros(ssize, batch, N1, N1, C, C).bool().to(log_potentials.device)
        init.diagonal(0, -2, -1).fill_(True)
        interval_chart = semiring.fill(interval_chart, init, semiring.one)

        # first row of interval chart
        interval_chart[:, :, 0] = chart
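        # interval_chart[:, :, k, j] accumulates the semiring product of edge
        # potentials j .. j + k, i.e. the score of the span of length k + 1
        # starting at edge j.  Each pass of the loop below doubles the longest
        # available span length by composing spans of length 2^n with the
        # shorter spans already computed, so the chart is filled with
        # O(log N) sequential (batched) semiring matmuls.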

        for n in range(0, log_N):

            index = int(math.pow(2, n))

            # special case for the last iteration
            height = min(index, N1 - index)

            l2 = interval_chart[:, :, :height, index:].contiguous()
            l1 = interval_chart[:, :, index - 1, : N1 - index].unsqueeze(2).expand_as(l2).contiguous()
            l_update = semiring.matmul(l2, l1)
            interval_chart[:, :, index : index + height, : N1 - index] = l_update

        # calculate marginals using the forward-backward property
        # p(z_i=c | x_1:t) = exp( l_1:i(c,.) + l_i+1:t(.,c) - l_1:t(.,.) ) = exp( L_prefix + L_interval - L_total )
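        # L_p, L_i_final and L_t below hold these three terms on a
        # (prefix, position) grid; tril_mask keeps only the valid entries
        # with position <= prefix length.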

#mask_value = - float("inf")
tril_mask = torch.tril(torch.ones(N1, N1)).to(interval_chart.device).unsqueeze(0).unsqueeze(0).unsqueeze(-1)

L_p = semiring.sum(interval_chart[:, :, :, 0].contiguous(), dim=-1).unsqueeze(2).expand(*interval_chart.shape[:-1])
#print(L_p.size(), tril_mask.size())
L_p = L_p * tril_mask.float()
#print(L_p)

        L_i = semiring.sum(interval_chart, dim=-2)
        index_matrix = torch.stack([torch.arange(-i, N1 - i) % (N1) for i in range(N1)], dim=1).to(L_i.device)
        index_matrix = repeat(index_matrix, "i j -> () () i j ()").expand_as(L_i)
        # re-index into tril
        L_i = torch.gather(L_i, 2, index_matrix)
        L_i = L_i * tril_mask.float()
        # shift
        L_i_final = torch.zeros_like(L_p, device=L_p.device)
        L_i_final[:, :, :, :-1] = L_i[:, :, :, 1:]

        L_t = semiring.sum(semiring.sum(interval_chart[:, :, :, 0].contiguous())).unsqueeze(3).expand(*interval_chart.shape[:-2])
        # don't mask denominator

        # The result is lower-triangular in (prefix, position): entry [b, t, i, c]
        # is p(z_i = c | x_1:t); entries with i > t are zero.
        return ((L_p + L_i_final - L_t.unsqueeze(-1)).exp() * tril_mask.float()).squeeze(0)
9 changes: 7 additions & 2 deletions torch_struct/semirings/semirings.py
@@ -143,8 +143,13 @@ class LogSemiring(_BaseLog):

    @classmethod
    def matmul(cls, a, b):
        if has_genbmm and isinstance(a, genbmm.BandedMatrix):
            return b.multiply_log(a.transpose())
if has_genbmm and a.device != "cpu":
if isinstance(a, genbmm.BandedMatrix):
return b.multiply_log(a.transpose())
return genbmm.logbmm(
a.contiguous(),#.view(-1, a.size(-1), a.size(-1)),
b.contiguous()#.view(-1, b.size(-1), b.size(-1))
)#.view(*a.shape)
        else:
            return _BaseLog.matmul(a, b)
