diff --git a/tests/test_distributions.py b/tests/test_distributions.py
index 0e5f8ad3..8cf847c6 100644
--- a/tests/test_distributions.py
+++ b/tests/test_distributions.py
@@ -76,6 +76,18 @@ def test_simple(data, seed):
     dist.kmax(5)
     dist.count
+    val_func = torch.rand(*vals.shape, 10)
+    E_val = dist.expected_value(val_func)
+    struct_vals = (
+        edges.unsqueeze(-1)
+        .mul(val_func.unsqueeze(0))
+        .reshape(*edges.shape[:2], -1, val_func.shape[-1])
+        .sum(2)
+    )
+    assert torch.isclose(
+        E_val, log_probs.exp().unsqueeze(-1).mul(struct_vals).sum(0)
+    ).all(), "Efficient expected value not equal to enumeration"
+
 
 
 @given(data(), integers(min_value=1, max_value=20))
 @settings(max_examples=50, deadline=None)
diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py
index 2065b250..d8c0d3d2 100644
--- a/torch_struct/distributions.py
+++ b/torch_struct/distributions.py
@@ -7,6 +7,7 @@
 from .alignment import Alignment
 from .deptree import DepTree, deptree_nonproj, deptree_part
 from .cky_crf import CKY_CRF
+from .full_cky_crf import Full_CKY_CRF
 from .semirings import (
     LogSemiring,
     MaxSemiring,
@@ -91,9 +92,7 @@ def cross_entropy(self, other):
             cross entropy (*batch_shape*)
         """
 
-        return self._struct(CrossEntropySemiring).sum(
-            [self.log_potentials, other.log_potentials], self.lengths
-        )
+        return self._struct(CrossEntropySemiring).sum([self.log_potentials, other.log_potentials], self.lengths)
 
     def kl(self, other):
         """
@@ -105,9 +104,7 @@ def kl(self, other):
         Returns:
             cross entropy (*batch_shape*)
         """
-        return self._struct(KLDivergenceSemiring).sum(
-            [self.log_potentials, other.log_potentials], self.lengths
-        )
+        return self._struct(KLDivergenceSemiring).sum([self.log_potentials, other.log_potentials], self.lengths)
 
     @lazy_property
     def max(self):
@@ -140,9 +137,7 @@ def kmax(self, k):
             kmax (*k x batch_shape*)
         """
         with torch.enable_grad():
-            return self._struct(KMaxSemiring(k)).sum(
-                self.log_potentials, self.lengths, _raw=True
-            )
+            return self._struct(KMaxSemiring(k)).sum(self.log_potentials, self.lengths, _raw=True)
 
     def topk(self, k):
         r"""
@@ -155,9 +150,7 @@ def topk(self, k):
             kmax (*k x batch_shape x event_shape*)
         """
         with torch.enable_grad():
-            return self._struct(KMaxSemiring(k)).marginals(
-                self.log_potentials, self.lengths, _raw=True
-            )
+            return self._struct(KMaxSemiring(k)).marginals(self.log_potentials, self.lengths, _raw=True)
 
     @lazy_property
     def mode(self):
@@ -179,16 +172,33 @@ def marginals(self):
 
     @lazy_property
     def count(self):
-        "Compute the log-partition function."
+        "Compute the total number of structures in the CRF support set."
         ones = torch.ones_like(self.log_potentials)
         ones[self.log_potentials.eq(-float("inf"))] = 0
         return self._struct(StdSemiring).sum(ones, self.lengths)
 
+    def expected_value(self, values):
+        """
+        Compute the expected value :math:`E_z[f(z)]` under the distribution, where f decomposes additively over the factors of p(z).
+
+        Parameters:
+            values (:class:`torch.FloatTensor`): (*batch_shape x *event_shape x *value_shape*), assigns a value to each
+                part of the structure. `values` can have 0 or more trailing dimensions in addition to the `event_shape`,
+                which allows for computing the expected value of, say, a vector-valued function.
+
+        Returns:
+            expected value (*batch_shape x *value_shape*)
+        """
+        # For these "part-level" expectations, this can be computed by multiplying the marginals element-wise
+        # with the values and summing. This is faster than running a dedicated semiring because of FastLogSemiring.
+        # (w/o genbmm it's about the same.)
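+        # As a hypothetical shape example (not from the test suite): for a LinearChainCRF,
+        # marginals have shape (batch, N - 1, C, C); `values` of shape (batch, N - 1, C, C, V)
+        # then yields an expected value of shape (batch, V).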
+        ps = self.marginals
+        ps_bcast = ps.reshape(*ps.shape, *((1,) * (len(values.shape) - len(ps.shape))))
+        return ps_bcast.mul(values).reshape(ps.shape[0], -1, *values.shape[len(ps.shape) :]).sum(1)
+
     def gumbel_crf(self, temperature=1.0):
         with torch.enable_grad():
-            st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(
-                self.log_potentials, self.lengths
-            )
+            st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(self.log_potentials, self.lengths)
             return st_gumbel
 
     # @constraints.dependent_property
@@ -214,16 +224,18 @@ def sample(self, sample_shape=torch.Size()):
         Returns:
             samples (*sample_shape x batch_shape x event_shape*)
         """
-        assert len(sample_shape) == 1
-        nsamples = sample_shape[0]
+        batch_size = MultiSampledSemiring.batch_size
+        if type(sample_shape) == int:
+            nsamples = sample_shape
+        else:
+            assert len(sample_shape) == 1
+            nsamples = sample_shape[0]
         samples = []
         for k in range(nsamples):
-            if k % 10 == 0:
-                sample = self._struct(MultiSampledSemiring).marginals(
-                    self.log_potentials, lengths=self.lengths
-                )
+            if k % batch_size == 0:
+                sample = self._struct(MultiSampledSemiring).marginals(self.log_potentials, lengths=self.lengths)
                 sample = sample.detach()
-            tmp_sample = MultiSampledSemiring.to_discrete(sample, (k % 10) + 1)
+            tmp_sample = MultiSampledSemiring.to_discrete(sample, (k % batch_size) + 1)
             samples.append(tmp_sample)
         return torch.stack(samples)
 
@@ -301,9 +313,7 @@ def __init__(self, log_potentials, local=False, lengths=None, max_gap=None):
         super().__init__(log_potentials, lengths)
 
     def _struct(self, sr=None):
-        return self.struct(
-            sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap
-        )
+        return self.struct(sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap)
 
 
 class HMM(StructDistribution):
@@ -411,6 +421,32 @@ class TreeCRF(StructDistribution):
     struct = CKY_CRF
 
 
+class FullTreeCRF(StructDistribution):
+    r"""
+    Represents a 1st-order span parser with NT nonterminals. Implemented using a
+    fast CKY algorithm.
+
+    For a description see:
+
+    * Inside-Outside Algorithm, by Michael Collins
+
+    Event shape is of the form:
+
+    Parameters:
+        log_potentials (tensor) : event_shape (*N x N x N x NT x NT x NT*), e.g.
+            :math:`\phi(i, j, k, A_i^j \rightarrow B_i^k C_{k+1}^j)`
+        lengths (long tensor) : batch shape integers for length masking.
+
+    Implementation uses width-batched, forward-pass only
+
+    * Parallel Time: :math:`O(N)` parallel merges.
+    * Forward Memory: :math:`O(N^3 NT^3)`
+
+    Compact representation: *N x N x N x NT x NT x NT* long tensor (Same)
+    """
+    struct = Full_CKY_CRF
+
+
 class SentCFG(StructDistribution):
     """
     Represents a full generative context-free grammar with
@@ -440,9 +476,7 @@ def __init__(self, log_potentials, lengths=None):
             event_shape = log_potentials[0].shape[1:]
         self.log_potentials = log_potentials
         self.lengths = lengths
-        super(StructDistribution, self).__init__(
-            batch_shape=batch_shape, event_shape=event_shape
-        )
+        super(StructDistribution, self).__init__(batch_shape=batch_shape, event_shape=event_shape)
 
 
 class NonProjectiveDependencyCRF(StructDistribution):
diff --git a/torch_struct/full_cky_crf.py b/torch_struct/full_cky_crf.py
new file mode 100644
index 00000000..d42a5105
--- /dev/null
+++ b/torch_struct/full_cky_crf.py
@@ -0,0 +1,115 @@
+import torch
+from .helpers import _Struct
+
+A, B = 0, 1
+
+
+class Full_CKY_CRF(_Struct):
+    def _check_potentials(self, edge, lengths=None):
+        batch, N, N1, N2, NT, NT1, NT2 = self._get_dimension(edge)
+        assert (
+            N == N1 == N2 and NT == NT1 == NT2
+        ), f"Want N:{N} == N1:{N1} == N2:{N2} and NT:{NT} == NT1:{NT1} == NT2:{NT2}"
+        edge = self.semiring.convert(edge)
+        semiring_shape = edge.shape[:-7]
+        if lengths is None:
+            lengths = torch.LongTensor([N] * batch).to(edge.device)
+
+        return edge, semiring_shape, batch, N, NT, lengths
+
+    def logpartition(self, scores, lengths=None, force_grad=False, cache=True):
+        sr = self.semiring
+
+        # Scores.shape = *sshape, B, N, N, N, NT, NT, NT
+        # w/ semantics [ *semiring stuff, b, i, j, k, A, B, C]
+        # where b is batch index, i is left endpoint, j is right endpoint, k is splitpoint, with rule A -> B C
+        scores, sshape, batch, N, NT, lengths = self._check_potentials(scores, lengths)
+        sshape, sdims = list(sshape), list(range(len(sshape)))  # usually [0]
+        S, b = len(sdims), batch
+
+        # Initialize data structs
+        LEFT, RIGHT = 0, 1
+        L_DIM, R_DIM = S + 1, S + 2  # one and two to the right of the batch dim
+
+        # Initialize the base cases with scores from diagonal i=j=k, A=B=C
+        term_scores = (
+            scores.diagonal(0, L_DIM, R_DIM)  # diag i,j now at dim -1
+            .diagonal(0, L_DIM, -1)  # diag of k with that gives i=j=k, now at dim -1
+            .diagonal(0, -4, -3)  # diag of A, B, now at dim -1, ijk moves to -2
+            .diagonal(0, -3, -1)  # diag of C with that gives A=B=C
+        )
+        assert term_scores.shape[S + 1 :] == (
+            N,
+            NT,
+        ), f"{term_scores.shape[S + 1 :]} == {(N, NT)}"
+        alpha_left = term_scores
+        alpha_right = term_scores
+        alphas = [[alpha_left], [alpha_right]]
+
+        # Run vectorized inside alg
+        for w in range(1, N):
+            # Scores
+            # What we want is a tensor with:
+            #  shape: *sshape, batch, (N-w), NT, w, NT, NT
+            #  w/ semantics: [...batch, (i,j=i+w), A, k, B, C]
+            #  where (i,j=i+w) means the diagonal of trees nodes with width w
+            # Shape: *sshape, batch, N, NT, NT, NT, (N-w) w/ semantics [ ...batch, k, A, B, C, (i,j=i+w)]
+            score = scores.diagonal(w, L_DIM, R_DIM)  # get diagonal scores
+
+            score = score.permute(
+                sdims + [-6, -1, -4, -5, -3, -2]
+            )  # move diag (-1) dim and head NT (-4) dim to front
+            score = score[..., :w, :, :]  # remove illegal splitpoints
+            assert score.shape[S:] == (
+                batch,
+                N - w,
+                NT,
+                w,
+                NT,
+                NT,
+            ), f"{score.shape[S:]} == {(b, N-w, NT, w, NT, NT)}"
+
+            # Sums of left subtrees
+            # Shape: *sshape, batch, (N-w), w, NT
+            # where L[..., i, d, B] is the sum of subtrees up to (i,j=(i+d),B)
+            left = slice(None, N - w)  # left indices
+            L = torch.stack(alphas[LEFT][:w], dim=-2)[..., left, :, :]
+
+            # Sums of right subtrees
+            # Shape: *sshape, batch, (N-w), w, NT
+            # where R[..., h, d, C] is the sum of subtrees up to (i=(N-h-d),j=(N-h),C)
+            right = slice(w, None)  # right indices
+            R = torch.stack(list(reversed(alphas[RIGHT][:w])), dim=-2)[..., right, :, :]
+
+            # Broadcast them both to match missing dims in score
+            # Left B is duplicated for all head and right symbols A C
+            L_bcast = L.reshape(list(sshape) + [b, N - w, 1, w, NT, 1])
+
+            # Right C is duplicated for all head and left symbols A B
+            R_bcast = R.reshape(list(sshape) + [b, N - w, 1, w, 1, NT])
+
+            # Now multiply all the scores and sum over k, B, C dimensions (the last three dims)
+            assert sr.times(score, L_bcast, R_bcast).shape == tuple(
+                list(sshape) + [b, N - w, NT, w, NT, NT]
+            )
+            # sum_prod_w = sr.sum(sr.sum(sr.sum(sr.times(score, L_bcast, R_bcast))))
+            sum_prod_w = sr.sum(
+                sr.times(score, L_bcast, R_bcast).reshape(*score.shape[:-3], -1)
+            )
+            assert sum_prod_w.shape[S:] == (
+                b,
+                N - w,
+                NT,
+            ), f"{sum_prod_w.shape[S:]} == {(b, N - w, NT)}"
+
+            pad = sr.zero_(torch.ones(sshape + [b, w, NT]).to(sum_prod_w.device))
+            sum_prod_w_left = torch.cat([sum_prod_w, pad], dim=-2)
+            sum_prod_w_right = torch.cat([pad, sum_prod_w], dim=-2)
+            alphas[LEFT].append(sum_prod_w_left)
+            alphas[RIGHT].append(sum_prod_w_right)
+
+        final = sr.sum(torch.stack(alphas[LEFT], dim=-2))[
+            ..., 0, :
+        ]  # sum out root symbol
+        log_Z = final[:, torch.arange(batch), lengths - 1]
+        return log_Z, [scores]
diff --git a/torch_struct/helpers.py b/torch_struct/helpers.py
index 3b7c0a1a..58116a36 100644
--- a/torch_struct/helpers.py
+++ b/torch_struct/helpers.py
@@ -6,11 +6,7 @@ class Chart:
     def __init__(self, size, potentials, semiring):
         self.data = semiring.zero_(
-            torch.zeros(
-                *((semiring.size(),) + size),
-                dtype=potentials.dtype,
-                device=potentials.device
-            )
+            torch.zeros(*((semiring.size(),) + size), dtype=potentials.dtype, device=potentials.device)
         )
         self.grad = self.data.detach().clone().fill_(0.0)
 
@@ -24,15 +20,43 @@ def __setitem__(self, ind, new):
 
 
 class _Struct:
+    """`_Struct` is the base class used to represent the graphical structure of a model.
+
+    Subclasses should implement a `logpartition` method which computes the partition function (under the standard `_BaseSemiring`).
+    Different `StructDistribution` classes instantiate the corresponding `_Struct` subclass.
+    """
+
     def __init__(self, semiring=LogSemiring):
         self.semiring = semiring
 
+    def logpartition(self, scores, lengths=None, force_grad=False):
+        """Implement the computation of the log partition constant logZ (if self.semiring == `_BaseSemiring`).
+
+        Parameters:
+            scores (torch.FloatTensor) : log potential scores for each factor of the model. Shape (* x batch size x *event_shape )
+            lengths (torch.LongTensor) : = None, lengths of batch padded examples. Shape = ( * x batch size )
+            force_grad: bool = False
+
+        Returns:
+            v (torch.Tensor) : the resulting output of the dynamic program
+            logpotentials (List[torch.Tensor]): the log edge potentials of the model.
+                When `scores` is already in a log_potential format for the distribution (typical), this will be
+                [scores], as in `Alignment`, `LinearChain`, `SemiMarkov`, `CKY_CRF`.
+                An exceptional case is the `CKY` struct, which takes log potential parameters from production rules
+                for a PCFG, which are by definition independent of position in the sequence.
+
+        # noqa: DAR401, DAR202
+        """
+        raise NotImplementedError()
+
     def score(self, potentials, parts, batch_dims=[0]):
-        score = torch.mul(potentials, parts)
+        """Score for entire structure is product of potentials for all activated "parts"."""
+        score = torch.mul(potentials, parts)  # mask potentials by activated "parts"
         batch = tuple((score.shape[b] for b in batch_dims))
         return self.semiring.prod(score.view(batch + (-1,)))
 
     def _bin_length(self, length):
+        """Find least upper bound for lengths that is a power of 2. Used in parallel scans."""
         log_N = int(math.ceil(math.log(length, 2)))
         bin_N = int(math.pow(2, log_N))
         return log_N, bin_N
@@ -53,11 +77,7 @@ def _make_chart(self, N, size, potentials, force_grad=False):
         return [
             (
                 self.semiring.zero_(
-                    torch.zeros(
-                        *((self.semiring.size(),) + size),
-                        dtype=potentials.dtype,
-                        device=potentials.device
-                    )
+                    torch.zeros(*((self.semiring.size(),) + size), dtype=potentials.dtype, device=potentials.device)
                 ).requires_grad_(force_grad and not potentials.requires_grad)
             )
             for _ in range(N)
@@ -92,28 +112,27 @@ def marginals(self, logpotentials, lengths=None, _raw=False):
             marginals: b x (N-1) x C x C table
 
         """
-        v, edges = self.logpartition(logpotentials, lengths=lengths, force_grad=True)
-        if _raw:
-            all_m = []
-            for k in range(v.shape[0]):
-                obj = v[k].sum(dim=0)
-
-                marg = torch.autograd.grad(
-                    obj,
-                    edges,
-                    create_graph=True,
-                    only_inputs=True,
-                    allow_unused=False,
-                )
-                all_m.append(self.semiring.unconvert(self._arrange_marginals(marg)))
-            return torch.stack(all_m, dim=0)
-        else:
-            obj = self.semiring.unconvert(v).sum(dim=0)
-            marg = torch.autograd.grad(
-                obj, edges, create_graph=True, only_inputs=True, allow_unused=False
-            )
-            a_m = self._arrange_marginals(marg)
-            return self.semiring.unconvert(a_m)
+        with torch.autograd.enable_grad():  # in case input potentials don't have grads enabled.
+            v, edges = self.logpartition(logpotentials, lengths=lengths, force_grad=True)
+            if _raw:
+                all_m = []
+                for k in range(v.shape[0]):
+                    obj = v[k].sum(dim=0)
+
+                    marg = torch.autograd.grad(
+                        obj,
+                        edges,
+                        create_graph=True,
+                        only_inputs=True,
+                        allow_unused=False,
+                    )
+                    all_m.append(self.semiring.unconvert(self._arrange_marginals(marg)))
+                return torch.stack(all_m, dim=0)
+            else:
+                obj = self.semiring.unconvert(v).sum(dim=0)
+                marg = torch.autograd.grad(obj, edges, create_graph=True, only_inputs=True, allow_unused=False)
+                a_m = self._arrange_marginals(marg)
+                return self.semiring.unconvert(a_m)
 
     @staticmethod
     def to_parts(spans, extra, lengths=None):
diff --git a/torch_struct/linearchain.py b/torch_struct/linearchain.py
index 593b2404..31547640 100644
--- a/torch_struct/linearchain.py
+++ b/torch_struct/linearchain.py
@@ -122,7 +122,7 @@ def from_parts(edge):
         batch, N_1, C, _ = edge.shape
         N = N_1 + 1
         labels = torch.zeros(batch, N).long()
-        on = edge.nonzero()
+        on = edge.nonzero(as_tuple=False)
         for i in range(on.shape[0]):
             if on[i][1] == 0:
                 labels[on[i][0], on[i][1]] = on[i][3]
diff --git a/torch_struct/semirings/sample.py b/torch_struct/semirings/sample.py
index 09ec189c..8628da0a 100644
--- a/torch_struct/semirings/sample.py
+++ b/torch_struct/semirings/sample.py
@@ -215,6 +215,8 @@ class MultiSampledSemiring(_BaseLog):
 
     "Gradients" give up to 16 samples with replacement.
""" + batch_size = 10 + @staticmethod def sum(xs, dim=-1): return _MultiSampledLogSumExp.apply(xs, dim) diff --git a/torch_struct/semirings/semirings.py b/torch_struct/semirings/semirings.py index bb7b9ec1..9d7dd525 100644 --- a/torch_struct/semirings/semirings.py +++ b/torch_struct/semirings/semirings.py @@ -20,6 +20,9 @@ def matmul(cls, a, b): return c +INF = 1e5 # numerically stable large value + + class Semiring: """ Base semiring class. @@ -28,6 +31,10 @@ class Semiring: * Semiring parsing :cite:`goodman1999semiring` + Attributes: + * zero: the additive identity, subclasses should override + * one: the multiplicative identity, subclasses should override + """ @classmethod @@ -47,6 +54,11 @@ def dot(cls, a, b): b = b.unsqueeze(-1) return cls.matmul(a, b).squeeze(-1).squeeze(-1) + @classmethod + def mul(cls, a, b): + "Multiply a and b under the semirings" + raise NotImplementedError() + @classmethod def times(cls, *ls): "Multiply a list of tensors together" @@ -65,20 +77,20 @@ def unconvert(cls, potentials): "Unconvert from semiring by removing extra first dimension." return potentials.squeeze(0) - @staticmethod - def zero_(xs): + @classmethod + def zero_(cls, xs): "Fill *ssize x ...* tensor with additive identity." - raise NotImplementedError() + return xs.fill_(cls.zero) @classmethod def zero_mask_(cls, xs, mask): "Fill *ssize x ...* tensor with additive identity." xs.masked_fill_(mask.unsqueeze(0), cls.zero) - @staticmethod - def one_(xs): + @classmethod + def one_(cls, xs): "Fill *ssize x ...* tensor with multiplicative identity." - raise NotImplementedError() + return xs.fill_(cls.one) @staticmethod def sum(xs, dim=-1): @@ -92,6 +104,7 @@ def plus(cls, a, b): class _Base(Semiring): zero = 0 + one = 1 @staticmethod def mul(a, b): @@ -101,17 +114,10 @@ def mul(a, b): def prod(a, dim=-1): return torch.prod(a, dim=dim) - @staticmethod - def zero_(xs): - return xs.fill_(0) - - @staticmethod - def one_(xs): - return xs.fill_(1) - class _BaseLog(Semiring): - zero = -1e9 + zero = -INF + one = 0 @staticmethod def sum(xs, dim=-1): @@ -121,14 +127,6 @@ def sum(xs, dim=-1): def mul(a, b): return a + b - @staticmethod - def zero_(xs): - return xs.fill_(-1e5) - - @staticmethod - def one_(xs): - return xs.fill_(0.0) - @staticmethod def prod(a, dim=-1): return torch.sum(a, dim=dim) @@ -277,7 +275,8 @@ class KLDivergenceSemiring(Semiring): """ - zero = 0 + zero = (-INF, -INF, 0) + one = (0, 0, 0) @staticmethod def size(): @@ -308,9 +307,7 @@ def sum(xs, dim=-1): ( part_p, part_q, - torch.sum( - xs[2].mul(sm_p) - log_sm_q.mul(sm_p) + log_sm_p.mul(sm_p), dim=d - ), + torch.sum(xs[2].mul(sm_p) - log_sm_q.mul(sm_p) + log_sm_p.mul(sm_p), dim=d), ) ) @@ -325,22 +322,22 @@ def prod(cls, xs, dim=-1): @classmethod def zero_mask_(cls, xs, mask): "Fill *ssize x ...* tensor with additive identity." 
-        xs[0].masked_fill_(mask, -1e5)
-        xs[1].masked_fill_(mask, -1e5)
-        xs[2].masked_fill_(mask, 0)
+        xs[0].masked_fill_(mask, cls.zero[0])
+        xs[1].masked_fill_(mask, cls.zero[1])
+        xs[2].masked_fill_(mask, cls.zero[2])
 
-    @staticmethod
-    def zero_(xs):
-        xs[0].fill_(-1e5)
-        xs[1].fill_(-1e5)
-        xs[2].fill_(0)
+    @classmethod
+    def zero_(cls, xs):
+        xs[0].fill_(cls.zero[0])
+        xs[1].fill_(cls.zero[1])
+        xs[2].fill_(cls.zero[2])
         return xs
 
-    @staticmethod
-    def one_(xs):
-        xs[0].fill_(0)
-        xs[1].fill_(0)
-        xs[2].fill_(0)
+    @classmethod
+    def one_(cls, xs):
+        xs[0].fill_(cls.one[0])
+        xs[1].fill_(cls.one[1])
+        xs[2].fill_(cls.one[2])
         return xs
 
 
@@ -357,7 +354,8 @@ class CrossEntropySemiring(Semiring):
     * Sample Selection for Statistical Grammar Induction :cite:`hwa2000samplesf`
     """
 
-    zero = 0
+    zero = (-INF, -INF, 0)
+    one = (0, 0, 0)
 
     @staticmethod
     def size():
@@ -384,9 +382,7 @@ def sum(xs, dim=-1):
         log_sm_p = xs[0] - part_p.unsqueeze(d)
         log_sm_q = xs[1] - part_q.unsqueeze(d)
         sm_p = log_sm_p.exp()
-        return torch.stack(
-            (part_p, part_q, torch.sum(xs[2].mul(sm_p) - log_sm_q.mul(sm_p), dim=d))
-        )
+        return torch.stack((part_p, part_q, torch.sum(xs[2].mul(sm_p) - log_sm_q.mul(sm_p), dim=d)))
 
     @staticmethod
     def mul(a, b):
@@ -399,22 +395,22 @@ def prod(cls, xs, dim=-1):
     @classmethod
     def zero_mask_(cls, xs, mask):
         "Fill *ssize x ...* tensor with additive identity."
-        xs[0].masked_fill_(mask, -1e5)
-        xs[1].masked_fill_(mask, -1e5)
-        xs[2].masked_fill_(mask, 0)
+        xs[0].masked_fill_(mask, cls.zero[0])
+        xs[1].masked_fill_(mask, cls.zero[1])
+        xs[2].masked_fill_(mask, cls.zero[2])
 
-    @staticmethod
-    def zero_(xs):
-        xs[0].fill_(-1e5)
-        xs[1].fill_(-1e5)
-        xs[2].fill_(0)
+    @classmethod
+    def zero_(cls, xs):
+        xs[0].fill_(cls.zero[0])
+        xs[1].fill_(cls.zero[1])
+        xs[2].fill_(cls.zero[2])
         return xs
 
-    @staticmethod
-    def one_(xs):
-        xs[0].fill_(0)
-        xs[1].fill_(0)
-        xs[2].fill_(0)
+    @classmethod
+    def one_(cls, xs):
+        xs[0].fill_(cls.one[0])
+        xs[1].fill_(cls.one[1])
+        xs[2].fill_(cls.one[2])
         return xs
 
 
@@ -431,7 +427,8 @@ class EntropySemiring(Semiring):
     * Sample Selection for Statistical Grammar Induction :cite:`hwa2000samplesf`
     """
 
-    zero = 0
+    zero = (-INF, 0)
+    one = (0, 0)
 
     @staticmethod
     def size():
@@ -468,19 +465,19 @@ def prod(cls, xs, dim=-1):
     @classmethod
     def zero_mask_(cls, xs, mask):
         "Fill *ssize x ...* tensor with additive identity."
-        xs[0].masked_fill_(mask, -1e5)
-        xs[1].masked_fill_(mask, 0)
+        xs[0].masked_fill_(mask, cls.zero[0])
+        xs[1].masked_fill_(mask, cls.zero[1])
 
-    @staticmethod
-    def zero_(xs):
-        xs[0].fill_(-1e5)
-        xs[1].fill_(0)
+    @classmethod
+    def zero_(cls, xs):
+        xs[0].fill_(cls.zero[0])
+        xs[1].fill_(cls.zero[1])
         return xs
 
-    @staticmethod
-    def one_(xs):
-        xs[0].fill_(0)
-        xs[1].fill_(0)
+    @classmethod
+    def one_(cls, xs):
+        xs[0].fill_(cls.one[0])
+        xs[1].fill_(cls.one[1])
         return xs
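Usage sketch: a minimal example of how the `expected_value` and integer-count `sample` calls
added above might be used on a `LinearChainCRF`. The shapes and variable names below are
illustrative assumptions, not taken from the repository's tests.

    import torch
    from torch_struct import LinearChainCRF

    batch, N, C, V = 2, 6, 3, 4
    log_potentials = torch.randn(batch, N - 1, C, C)
    dist = LinearChainCRF(log_potentials)

    # f assigns a V-dimensional value to every part (edge) of the structure.
    values = torch.rand(batch, N - 1, C, C, V)
    e = dist.expected_value(values)  # shape: (batch, V)

    # sample() now also accepts a plain int; draws are computed internally in
    # groups of MultiSampledSemiring.batch_size.
    samples = dist.sample(7)  # shape: (7, batch, N - 1, C, C)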