|
| 1 | +# coding=utf-8 |
| 2 | +# Copyright 2024 The Perch Authors. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +"""Call density estimation using a bernoulli kernel.""" |
| 17 | + |
| 18 | +from flax import nnx |
| 19 | +import jax |
| 20 | +from jax import numpy as jnp |
| 21 | + |
| 22 | + |
| 23 | +def log_sum_exp(xs: jnp.ndarray, axis=None): |
| 24 | + # Log-of-sum-of-exponentials admits a nice, more-stable form using the max |
| 25 | + # of the sequence. |
| 26 | + # https://mc-stan.org/docs/2_27/stan-users-guide/log-sum-of-exponentials.html |
| 27 | + max_x = jnp.max(xs, axis=axis, keepdims=True) |
| 28 | + max_sq = jnp.max(xs, axis=axis) |
| 29 | + sums_x = jnp.log(jnp.sum(jnp.exp(xs - max_x), axis=axis)) |
| 30 | + return sums_x + max_sq |
| 31 | + |
| 32 | + |
| 33 | +sq_norm = lambda x: jnp.sum(x * x, axis=1) |
| 34 | +dots_ab = lambda x, y: jnp.dot(x, y.T) |
| 35 | +dists_ab = lambda x, y: ( |
| 36 | + -2 * dots_ab(x, y) + sq_norm(x)[:, jnp.newaxis] + sq_norm(y)[jnp.newaxis, :] |
| 37 | +) |
| 38 | + |
| 39 | +# Scaled distances. |
| 40 | +dists_ab_s = lambda a, b, s: ( |
| 41 | + dists_ab(a * s[jnp.newaxis, :], b * s[jnp.newaxis, :]) |
| 42 | +) |
| 43 | + |
| 44 | + |
| 45 | +def scaled_rbf_kernel( |
| 46 | + x: jnp.ndarray, y: jnp.ndarray, scale: jnp.ndarray, bias: float |
| 47 | +): |
| 48 | + return dists_ab_s(x, y, scale) + bias |
| 49 | + |
| 50 | + |
| 51 | +class BernoulliData(nnx.Variable): |
| 52 | + """Container for data and labels for a BernoulliProcessor.""" |
| 53 | + |
| 54 | + # We declare this subclass so that the groundtruth data is not updated |
| 55 | + # during training. |
| 56 | + pass |
| 57 | + |
| 58 | + |
| 59 | +class BernoulliRBF(nnx.Module): |
| 60 | + r"""Model P(+|x) ~ \beta(a(x), b(x)). |
| 61 | +
|
| 62 | + Given some input data x, we want to estimate the number of virtual positive |
| 63 | + and negative observations to associate with x. These are used as parameters in |
| 64 | + a beta distribution, allowing us to have both an expected value for P(+|x) |
| 65 | + and a measure of certainty, according to the total weight a(x) + b(x). |
| 66 | +
|
| 67 | + We combine two approaches for estimating a(x), b(x): |
| 68 | + First, a learned RBF kernel over the ground-truth observations acts as a KNN |
| 69 | + classifier, contributing positive and negative observations at arbitrary x |
| 70 | + according to learned similarity between x and the groundtruth. |
| 71 | +
|
| 72 | + Second, we (optionally) directly predict a number of pos/neg observations |
| 73 | + a_f(x), b_f(x) from the features themselves. For example, if one of the |
| 74 | + features is a classifier score, this allows the model to directly use the |
| 75 | + classifier score as a prior, with some learned weight. |
| 76 | + """ |
| 77 | + |
| 78 | + def __init__( |
| 79 | + self, |
| 80 | + data: jnp.ndarray, |
| 81 | + data_labels: jnp.ndarray, |
| 82 | + data_mean: float | None = 0.0, |
| 83 | + data_std: float | None = 1.0, |
| 84 | + learn_feature_weights: bool = False, |
| 85 | + *, |
| 86 | + rngs: nnx.Rngs |
| 87 | + ): |
| 88 | + key = rngs.params() |
| 89 | + num_features = data.shape[-1] |
| 90 | + self.scales_pos = nnx.Param(jax.random.uniform(key, (num_features,))) |
| 91 | + self.scales_neg = nnx.Param(jax.random.uniform(key, (num_features,))) |
| 92 | + self.weight_bias = nnx.Param(jnp.zeros([2])) |
| 93 | + if data_mean is None: |
| 94 | + self.data_mean = BernoulliData(jnp.mean(data, axis=0, keepdims=True)) |
| 95 | + else: |
| 96 | + self.data_mean = BernoulliData(data_mean) |
| 97 | + if data_std is None: |
| 98 | + self.data_stds = BernoulliData(jnp.std(data, axis=0, keepdims=True)) |
| 99 | + else: |
| 100 | + self.data_stds = BernoulliData(data_std) |
| 101 | + data_pos, data_neg = self.split_labeled_data(data, data_labels) |
| 102 | + self.data_pos = jax.lax.stop_gradient( |
| 103 | + BernoulliData(self._normalize(data_pos)) |
| 104 | + ) |
| 105 | + self.data_neg = jax.lax.stop_gradient( |
| 106 | + BernoulliData(self._normalize(data_neg)) |
| 107 | + ) |
| 108 | + self.data_labels = jax.lax.stop_gradient(BernoulliData(data_labels)) |
| 109 | + self.learn_feature_weights = learn_feature_weights |
| 110 | + |
| 111 | + # Matrices for assigning pos/neg weight directly from features. |
| 112 | + self.feature_weights = nnx.Param(jax.random.uniform(key, (num_features, 2))) |
| 113 | + self.feature_bias = nnx.Param(jax.random.uniform(key, (2,))) |
| 114 | + |
| 115 | + @classmethod |
| 116 | + def split_labeled_data(cls, data: jnp.ndarray, data_labels: jnp.ndarray): |
| 117 | + pos_idxes = jnp.where(data_labels == 1)[0] |
| 118 | + neg_idxes = jnp.where(data_labels == 0)[0] |
| 119 | + data_pos = data[pos_idxes] |
| 120 | + data_neg = data[neg_idxes] |
| 121 | + return data_pos, data_neg |
| 122 | + |
| 123 | + def _normalize(self, x): |
| 124 | + return (x - self.data_mean.value) / self.data_stds.value |
| 125 | + |
| 126 | + def _log_counts(self, x: jnp.ndarray, normalize: bool = True): |
| 127 | + if normalize: |
| 128 | + x = self._normalize(x) |
| 129 | + pos_count = scaled_rbf_kernel( |
| 130 | + x, self.data_pos, self.scales_pos, self.weight_bias[0] |
| 131 | + ) |
| 132 | + neg_count = scaled_rbf_kernel( |
| 133 | + x, self.data_neg, self.scales_neg, self.weight_bias[1] |
| 134 | + ) |
| 135 | + |
| 136 | + if self.learn_feature_weights: |
| 137 | + feature_count = jnp.dot(x, self.feature_weights.value) + self.feature_bias |
| 138 | + pos_count = jnp.concat([pos_count, feature_count[:, :1]], axis=1) |
| 139 | + neg_count = jnp.concat([neg_count, feature_count[:, 1:]], axis=1) |
| 140 | + log_pos_count = log_sum_exp(-pos_count, axis=1) |
| 141 | + log_neg_count = log_sum_exp(-neg_count, axis=1) |
| 142 | + log_weight_count = log_sum_exp( |
| 143 | + jnp.concatenate([-pos_count, -neg_count], axis=1), axis=1 |
| 144 | + ) |
| 145 | + return log_pos_count, log_neg_count, log_weight_count |
| 146 | + |
| 147 | + def __call__(self, x: jnp.ndarray, normalize: bool = True): |
| 148 | + """Compute log(P(+|x)) and the total example weight of x.""" |
| 149 | + log_pos_count, _, log_weight_count = self._log_counts(x, normalize) |
| 150 | + log_p_x = log_pos_count - log_weight_count |
| 151 | + return log_p_x, log_weight_count |
| 152 | + |
| 153 | + def sampled_counts(self, seed: int, x: jnp.ndarray, n_samples: int = 1024): |
| 154 | + """Create sampled positive counts from the learned distribution at x.""" |
| 155 | + log_pos_count, log_neg_count, unused_log_wt = self._log_counts(x) |
| 156 | + pos_count = jnp.exp(log_pos_count)[:, jnp.newaxis] |
| 157 | + neg_count = jnp.exp(log_neg_count)[:, jnp.newaxis] |
| 158 | + |
| 159 | + k = jax.random.PRNGKey(seed) |
| 160 | + beta_samp = jax.random.beta( |
| 161 | + k, pos_count, neg_count, shape=[pos_count.shape[0], n_samples] |
| 162 | + ) |
| 163 | + sample_counts = jnp.sum( |
| 164 | + jax.random.uniform(k, shape=beta_samp.shape) < beta_samp, axis=0 |
| 165 | + ) |
| 166 | + return sample_counts |
| 167 | + |
| 168 | + def gt_log_likelihood(self): |
| 169 | + """Total log likelihood of the GT data, given learned params.""" |
| 170 | + # Counts for positive points. |
| 171 | + pos_pos_count = scaled_rbf_kernel( # [N+, N+] |
| 172 | + self.data_pos, self.data_pos, self.scales_pos, self.weight_bias[0] |
| 173 | + ) |
| 174 | + pos_neg_count = scaled_rbf_kernel( |
| 175 | + self.data_pos, self.data_neg, self.scales_neg, self.weight_bias[1] |
| 176 | + ) |
| 177 | + |
| 178 | + # Counts for negative points. |
| 179 | + neg_neg_count = scaled_rbf_kernel( |
| 180 | + self.data_neg, self.data_neg, self.scales_neg, self.weight_bias[1] |
| 181 | + ) |
| 182 | + neg_pos_count = scaled_rbf_kernel( |
| 183 | + self.data_neg, self.data_pos, self.scales_pos, self.weight_bias[0] |
| 184 | + ) |
| 185 | + |
| 186 | + # Estimate pos/neg priors from raw features. |
| 187 | + if self.learn_feature_weights: |
| 188 | + pos_feature_count = ( |
| 189 | + jnp.dot(self.data_pos.value, self.feature_weights.value) # [N+, 2] |
| 190 | + + self.feature_bias.value |
| 191 | + ) |
| 192 | + neg_feature_count = ( |
| 193 | + jnp.dot(self.data_neg.value, self.feature_weights.value) # [N-, 2] |
| 194 | + + self.feature_bias.value |
| 195 | + ) |
| 196 | + # Add feature counts to the list of actual counts from data. |
| 197 | + pos_pos_count = jnp.concat( |
| 198 | + [pos_pos_count, pos_feature_count[:, :1]], axis=-1 |
| 199 | + ) |
| 200 | + pos_neg_count = jnp.concat( |
| 201 | + [pos_neg_count, pos_feature_count[:, 1:]], axis=-1 |
| 202 | + ) |
| 203 | + neg_pos_count = jnp.concat( |
| 204 | + [neg_pos_count, neg_feature_count[:, :1]], axis=-1 |
| 205 | + ) |
| 206 | + neg_neg_count = jnp.concat( |
| 207 | + [neg_neg_count, neg_feature_count[:, 1:]], axis=-1 |
| 208 | + ) |
| 209 | + |
| 210 | + pos_log_prob = log_sum_exp(-pos_pos_count, axis=1) - log_sum_exp( |
| 211 | + jnp.concatenate([-pos_pos_count, -pos_neg_count], axis=1), axis=1 |
| 212 | + ) |
| 213 | + neg_log_prob = log_sum_exp(-neg_neg_count, axis=1) - log_sum_exp( |
| 214 | + jnp.concatenate([-neg_neg_count, -neg_pos_count], axis=1), axis=1 |
| 215 | + ) |
| 216 | + |
| 217 | + pos_log_prob = jnp.mean(pos_log_prob) |
| 218 | + neg_log_prob = jnp.mean(neg_log_prob) |
| 219 | + return pos_log_prob + neg_log_prob |
| 220 | + |
| 221 | + def matching_loss(self): |
| 222 | + """Difference between observed log P(+) and estimated log P(+).""" |
| 223 | + data_log_p_x, _ = self( |
| 224 | + jnp.concatenate([self.data_pos, self.data_neg], axis=0), normalize=False |
| 225 | + ) |
| 226 | + target_log_p_x = jnp.log(self.data_pos.shape[0]) - jnp.log( |
| 227 | + self.data_pos.shape[0] + self.data_neg.shape[0] |
| 228 | + ) |
| 229 | + return jnp.abs(data_log_p_x.mean() - target_log_p_x) |
| 230 | + |
| 231 | + |
| 232 | +@nnx.jit |
| 233 | +def train_step( |
| 234 | + model: BernoulliRBF, optimizer: nnx.optimizer.Optimizer, mu: float |
| 235 | +) -> float: |
| 236 | + def loss_fn(model: BernoulliRBF): |
| 237 | + gt_log_likelihood_loss = -model.gt_log_likelihood() |
| 238 | + matching_loss = model.matching_loss() |
| 239 | + return gt_log_likelihood_loss + mu * matching_loss |
| 240 | + |
| 241 | + loss, grads = nnx.value_and_grad(loss_fn)(model) |
| 242 | + optimizer.update(grads) |
| 243 | + return loss |
0 commit comments