
Commit e7c3f7f

sdenton4 authored and copybara-github committed
Bernoulli RBF for call density estimation.
PiperOrigin-RevId: 724507934
1 parent 1a49cfb commit e7c3f7f

File tree

2 files changed: +349 −0 lines changed


chirp/projects/bernoulli_rbf.py

Lines changed: 243 additions & 0 deletions
@@ -0,0 +1,243 @@
# coding=utf-8
# Copyright 2024 The Perch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Call density estimation using a bernoulli kernel."""

from flax import nnx
import jax
from jax import numpy as jnp


def log_sum_exp(xs: jnp.ndarray, axis=None):
  # Log-of-sum-of-exponentials admits a nice, more-stable form using the max
  # of the sequence.
  # https://mc-stan.org/docs/2_27/stan-users-guide/log-sum-of-exponentials.html
  max_x = jnp.max(xs, axis=axis, keepdims=True)
  max_sq = jnp.max(xs, axis=axis)
  sums_x = jnp.log(jnp.sum(jnp.exp(xs - max_x), axis=axis))
  return sums_x + max_sq


sq_norm = lambda x: jnp.sum(x * x, axis=1)
dots_ab = lambda x, y: jnp.dot(x, y.T)
dists_ab = lambda x, y: (
    -2 * dots_ab(x, y) + sq_norm(x)[:, jnp.newaxis] + sq_norm(y)[jnp.newaxis, :]
)

# Scaled distances.
dists_ab_s = lambda a, b, s: (
    dists_ab(a * s[jnp.newaxis, :], b * s[jnp.newaxis, :])
)


def scaled_rbf_kernel(
    x: jnp.ndarray, y: jnp.ndarray, scale: jnp.ndarray, bias: float
):
  return dists_ab_s(x, y, scale) + bias


class BernoulliData(nnx.Variable):
  """Container for data and labels for a BernoulliProcessor."""

  # We declare this subclass so that the groundtruth data is not updated
  # during training.
  pass


class BernoulliRBF(nnx.Module):
  r"""Model P(+|x) ~ \beta(a(x), b(x)).

  Given some input data x, we want to estimate the number of virtual positive
  and negative observations to associate with x. These are used as parameters
  in a beta distribution, allowing us to have both an expected value for P(+|x)
  and a measure of certainty, according to the total weight a(x) + b(x).

  We combine two approaches for estimating a(x), b(x):
  First, a learned RBF kernel over the ground-truth observations acts as a KNN
  classifier, contributing positive and negative observations at arbitrary x
  according to learned similarity between x and the groundtruth.

  Second, we (optionally) directly predict a number of pos/neg observations
  a_f(x), b_f(x) from the features themselves. For example, if one of the
  features is a classifier score, this allows the model to directly use the
  classifier score as a prior, with some learned weight.
  """

  def __init__(
      self,
      data: jnp.ndarray,
      data_labels: jnp.ndarray,
      data_mean: float | None = 0.0,
      data_std: float | None = 1.0,
      learn_feature_weights: bool = False,
      *,
      rngs: nnx.Rngs
  ):
    key = rngs.params()
    num_features = data.shape[-1]
    self.scales_pos = nnx.Param(jax.random.uniform(key, (num_features,)))
    self.scales_neg = nnx.Param(jax.random.uniform(key, (num_features,)))
    self.weight_bias = nnx.Param(jnp.zeros([2]))
    if data_mean is None:
      self.data_mean = BernoulliData(jnp.mean(data, axis=0, keepdims=True))
    else:
      self.data_mean = BernoulliData(data_mean)
    if data_std is None:
      self.data_stds = BernoulliData(jnp.std(data, axis=0, keepdims=True))
    else:
      self.data_stds = BernoulliData(data_std)
    data_pos, data_neg = self.split_labeled_data(data, data_labels)
    self.data_pos = jax.lax.stop_gradient(
        BernoulliData(self._normalize(data_pos))
    )
    self.data_neg = jax.lax.stop_gradient(
        BernoulliData(self._normalize(data_neg))
    )
    self.data_labels = jax.lax.stop_gradient(BernoulliData(data_labels))
    self.learn_feature_weights = learn_feature_weights

    # Matrices for assigning pos/neg weight directly from features.
    self.feature_weights = nnx.Param(jax.random.uniform(key, (num_features, 2)))
    self.feature_bias = nnx.Param(jax.random.uniform(key, (2,)))

  @classmethod
  def split_labeled_data(cls, data: jnp.ndarray, data_labels: jnp.ndarray):
    pos_idxes = jnp.where(data_labels == 1)[0]
    neg_idxes = jnp.where(data_labels == 0)[0]
    data_pos = data[pos_idxes]
    data_neg = data[neg_idxes]
    return data_pos, data_neg

  def _normalize(self, x):
    return (x - self.data_mean.value) / self.data_stds.value

  def _log_counts(self, x: jnp.ndarray, normalize: bool = True):
    if normalize:
      x = self._normalize(x)
    pos_count = scaled_rbf_kernel(
        x, self.data_pos, self.scales_pos, self.weight_bias[0]
    )
    neg_count = scaled_rbf_kernel(
        x, self.data_neg, self.scales_neg, self.weight_bias[1]
    )

    if self.learn_feature_weights:
      feature_count = jnp.dot(x, self.feature_weights.value) + self.feature_bias
      pos_count = jnp.concat([pos_count, feature_count[:, :1]], axis=1)
      neg_count = jnp.concat([neg_count, feature_count[:, 1:]], axis=1)
    log_pos_count = log_sum_exp(-pos_count, axis=1)
    log_neg_count = log_sum_exp(-neg_count, axis=1)
    log_weight_count = log_sum_exp(
        jnp.concatenate([-pos_count, -neg_count], axis=1), axis=1
    )
    return log_pos_count, log_neg_count, log_weight_count

  def __call__(self, x: jnp.ndarray, normalize: bool = True):
    """Compute log(P(+|x)) and the total example weight of x."""
    log_pos_count, _, log_weight_count = self._log_counts(x, normalize)
    log_p_x = log_pos_count - log_weight_count
    return log_p_x, log_weight_count

  def sampled_counts(self, seed: int, x: jnp.ndarray, n_samples: int = 1024):
    """Create sampled positive counts from the learned distribution at x."""
    log_pos_count, log_neg_count, unused_log_wt = self._log_counts(x)
    pos_count = jnp.exp(log_pos_count)[:, jnp.newaxis]
    neg_count = jnp.exp(log_neg_count)[:, jnp.newaxis]

    k = jax.random.PRNGKey(seed)
    beta_samp = jax.random.beta(
        k, pos_count, neg_count, shape=[pos_count.shape[0], n_samples]
    )
    sample_counts = jnp.sum(
        jax.random.uniform(k, shape=beta_samp.shape) < beta_samp, axis=0
    )
    return sample_counts

  def gt_log_likelihood(self):
    """Total log likelihood of the GT data, given learned params."""
    # Counts for positive points.
    pos_pos_count = scaled_rbf_kernel(  # [N+, N+]
        self.data_pos, self.data_pos, self.scales_pos, self.weight_bias[0]
    )
    pos_neg_count = scaled_rbf_kernel(
        self.data_pos, self.data_neg, self.scales_neg, self.weight_bias[1]
    )

    # Counts for negative points.
    neg_neg_count = scaled_rbf_kernel(
        self.data_neg, self.data_neg, self.scales_neg, self.weight_bias[1]
    )
    neg_pos_count = scaled_rbf_kernel(
        self.data_neg, self.data_pos, self.scales_pos, self.weight_bias[0]
    )

    # Estimate pos/neg priors from raw features.
    if self.learn_feature_weights:
      pos_feature_count = (
          jnp.dot(self.data_pos.value, self.feature_weights.value)  # [N+, 2]
          + self.feature_bias.value
      )
      neg_feature_count = (
          jnp.dot(self.data_neg.value, self.feature_weights.value)  # [N-, 2]
          + self.feature_bias.value
      )
      # Add feature counts to the list of actual counts from data.
      pos_pos_count = jnp.concat(
          [pos_pos_count, pos_feature_count[:, :1]], axis=-1
      )
      pos_neg_count = jnp.concat(
          [pos_neg_count, pos_feature_count[:, 1:]], axis=-1
      )
      neg_pos_count = jnp.concat(
          [neg_pos_count, neg_feature_count[:, :1]], axis=-1
      )
      neg_neg_count = jnp.concat(
          [neg_neg_count, neg_feature_count[:, 1:]], axis=-1
      )

    pos_log_prob = log_sum_exp(-pos_pos_count, axis=1) - log_sum_exp(
        jnp.concatenate([-pos_pos_count, -pos_neg_count], axis=1), axis=1
    )
    neg_log_prob = log_sum_exp(-neg_neg_count, axis=1) - log_sum_exp(
        jnp.concatenate([-neg_neg_count, -neg_pos_count], axis=1), axis=1
    )

    pos_log_prob = jnp.mean(pos_log_prob)
    neg_log_prob = jnp.mean(neg_log_prob)
    return pos_log_prob + neg_log_prob

  def matching_loss(self):
    """Difference between observed log P(+) and estimated log P(+)."""
    data_log_p_x, _ = self(
        jnp.concatenate([self.data_pos, self.data_neg], axis=0), normalize=False
    )
    target_log_p_x = jnp.log(self.data_pos.shape[0]) - jnp.log(
        self.data_pos.shape[0] + self.data_neg.shape[0]
    )
    return jnp.abs(data_log_p_x.mean() - target_log_p_x)


@nnx.jit
def train_step(
    model: BernoulliRBF, optimizer: nnx.optimizer.Optimizer, mu: float
) -> float:
  def loss_fn(model: BernoulliRBF):
    gt_log_likelihood_loss = -model.gt_log_likelihood()
    matching_loss = model.matching_loss()
    return gt_log_likelihood_loss + mu * matching_loss

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)
  return loss
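
Editor's note: as a compact restatement of what `_log_counts` and `__call__` compute when `learn_feature_weights` is off (notation is mine, not part of the commit), with x and the stored examples normalized by the data mean/std, s_+ and s_- the learned scales, and b_+, b_- the entries of `weight_bias`:

    a(x) = \sum_i \exp\big(-\lVert s_+ \odot (x - x_i^+)\rVert^2 - b_+\big)
    b(x) = \sum_j \exp\big(-\lVert s_- \odot (x - x_j^-)\rVert^2 - b_-\big)
    \log P(+ \mid x) = \log a(x) - \log\big(a(x) + b(x)\big), \quad \mathbb{E}[P(+ \mid x)] = a(x) / (a(x) + b(x))

The second value returned by `__call__`, log(a(x) + b(x)), is the log of the total pseudo-observation weight, i.e. the certainty measure described in the class docstring.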

chirp/projects/bernoulli_rbf_test.py

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
# coding=utf-8
# Copyright 2024 The Perch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for Bernoulli RBF."""

from chirp.projects import bernoulli_rbf
from flax import nnx
from jax import numpy as jnp
import numpy as np
import optax

from absl.testing import absltest


class BernoulliRbfTest(absltest.TestCase):

  def test_kernel_numerics(self):
    xs = jnp.array([[0, 1], [2, 3]])
    ys = jnp.array([[1, 0], [1, 1], [2, 3]])

    with self.subTest('unit_scale'):
      scales = jnp.array([1.0, 1.0])
      got = bernoulli_rbf.scaled_rbf_kernel(xs, ys, scales, 0.0)
      expect = jnp.array([
          [2.0, 1.0, 8.0],
          [10.0, 5.0, 0.0],
      ])
      np.testing.assert_array_equal(got.shape, (2, 3))
      np.testing.assert_array_equal(got, expect)

    with self.subTest('scaled'):
      scales = jnp.array([2.0, 3.0])
      got = bernoulli_rbf.scaled_rbf_kernel(xs, ys, scales, 0.0)
      expect = jnp.array([
          [13.0, 4.0, 52.0],
          [85.0, 40.0, 0.0],
      ])
      np.testing.assert_array_equal(got.shape, (2, 3))
      np.testing.assert_array_equal(got, expect)

  def test_split_labeled_data(self):
    data = jnp.array([[0, 1], [2, 3], [4, 5], [6, 7]])
    labels = jnp.array([0, 1, 0, 1])
    pos, neg = bernoulli_rbf.BernoulliRBF.split_labeled_data(data, labels)
    np.testing.assert_array_equal(pos, jnp.array([[2, 3], [6, 7]]))
    np.testing.assert_array_equal(neg, jnp.array([[0, 1], [4, 5]]))

  def test_log_prob(self):
    data = jnp.array([[1, 0], [1, 1], [2, 3]])
    labels = jnp.array([0, 1, 0])
    model = bernoulli_rbf.BernoulliRBF(
        data,
        labels,
        rngs=nnx.Rngs(666),
        learn_feature_weights=False,
    )
    # Set unit scales and bias for simplicity.
    model.scales_pos = jnp.array([1.0, 1.0])
    model.scales_neg = jnp.array([1.0, 1.0])
    model.weight_bias = jnp.array([0.0, 0.0])

    got_log_prob, got_log_wt = model(jnp.array([[0, 1]]))
    # Squared distances to the data points are [2.0, 1.0, 8.0], and only
    # the second example is positive.
    # Then the positive example weight is [exp(-1)] and negative example
    # weights are [exp(-2), exp(-8)].
    # Then our predicted probability is:
    #   exp(-1) / (exp(-1) + exp(-2) + exp(-8)) ~= 0.7306.
    # The log positive count is log(exp(-1)) = -1, and the log total weight
    # is log(exp(-1) + exp(-2) + exp(-8)).
    expect_log_prob = -1.0 + -np.log((np.exp(-1) + np.exp(-2) + np.exp(-8)))
    np.testing.assert_allclose(got_log_prob, expect_log_prob, atol=1e-5)
    expect_log_wt = np.log(np.exp(-1) + np.exp(-2) + np.exp(-8))
    np.testing.assert_allclose(got_log_wt, expect_log_wt, atol=1e-5)

  def test_train_step(self):
    data = jnp.array([[1, 0], [1, 1], [2, 3]])
    labels = jnp.array([0, 1, 0])
    model = bernoulli_rbf.BernoulliRBF(
        data,
        labels,
        rngs=nnx.Rngs(666),
        data_mean=None,
        data_std=None,
        learn_feature_weights=False,
    )
    optimizer = nnx.Optimizer(model, optax.adam(1e-3))  # reference sharing
    loss = bernoulli_rbf.train_step(model, optimizer, mu=1.0)
    self.assertLess(loss, 2.0)


if __name__ == '__main__':
  absltest.main()
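
Editor's note: for reference, a minimal usage sketch of the new module, going a bit beyond the single optimizer step in test_train_step: fit the kernel scales with a short training loop, then query probabilities and draw sampled counts. The toy data, seed, loop length, and learning rate here are illustrative assumptions, not from the commit.

import jax
from flax import nnx
from jax import numpy as jnp
import optax

from chirp.projects import bernoulli_rbf

# Toy embeddings and 0/1 labels; any [N, num_features] array works.
data = jax.random.normal(jax.random.PRNGKey(0), (32, 4))
labels = jnp.arange(32) % 2

model = bernoulli_rbf.BernoulliRBF(
    data,
    labels,
    data_mean=None,  # None -> estimate mean/std from the data.
    data_std=None,
    learn_feature_weights=False,
    rngs=nnx.Rngs(42),
)
optimizer = nnx.Optimizer(model, optax.adam(1e-2))

# mu weights the matching_loss term against the GT log-likelihood term.
for _ in range(100):
  loss = bernoulli_rbf.train_step(model, optimizer, mu=1.0)

# Query log P(+|x) and the total pseudo-observation weight at some points.
log_p, log_weight = model(data[:5])
print(jnp.exp(log_p), jnp.exp(log_weight))

# Draw sampled positive counts from the fitted beta distributions.
counts = model.sampled_counts(seed=1, x=data, n_samples=256)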
