Skip to content

Commit 7fde39a

Browse files
purna135Sayam753
andcommitted
allow alpha to take batched data for StickBreakingWeights
Co-authored-by: Sayam Kumar <[email protected]>
1 parent ad16bf4 commit 7fde39a

File tree

3 files changed

+42
-29
lines changed

3 files changed

+42
-29
lines changed

pymc/distributions/multivariate.py

+5-11
Original file line numberDiff line numberDiff line change
@@ -2192,9 +2192,6 @@ def make_node(self, rng, size, dtype, alpha, K):
21922192
alpha = at.as_tensor_variable(alpha)
21932193
K = at.as_tensor_variable(intX(K))
21942194

2195-
if alpha.ndim > 0:
2196-
raise ValueError("The concentration parameter needs to be a scalar.")
2197-
21982195
if K.ndim > 0:
21992196
raise ValueError("K must be a scalar.")
22002197

@@ -2205,20 +2202,17 @@ def _infer_shape(self, size, dist_params, param_shapes=None):
22052202

22062203
size = tuple(size)
22072204

2208-
return size + (K + 1,)
2205+
return size + tuple(alpha.shape) + (K + 1,)
22092206

22102207
@classmethod
22112208
def rng_fn(cls, rng, alpha, K, size):
22122209
if K < 0:
22132210
raise ValueError("K needs to be positive.")
22142211

2215-
if size is None:
2216-
size = (K,)
2217-
elif isinstance(size, int):
2218-
size = (size,) + (K,)
2219-
else:
2220-
size = tuple(size) + (K,)
2212+
distribution_shape = alpha.shape + (K,)
2213+
size = to_tuple(size) + distribution_shape
22212214

2215+
alpha = alpha[..., np.newaxis]
22222216
betas = rng.beta(1, alpha, size=size)
22232217

22242218
sticks = np.concatenate(
@@ -2262,7 +2256,7 @@ class StickBreakingWeights(SimplexContinuous):
22622256
22632257
Parameters
22642258
----------
2265-
alpha : tensor_like of float
2259+
alpha: float or array_like of floats
22662260
Concentration parameter (alpha > 0).
22672261
K : tensor_like of int
22682262
The number of "sticks" to break off from an initial one-unit stick. The length of the weight

pymc/tests/test_distributions.py

+18
Original file line numberDiff line numberDiff line change
@@ -2291,6 +2291,24 @@ def test_dirichlet_multinomial_vectorized(self, n, a, extra_size):
22912291
3,
22922292
np.array([1.29317672, 1.50126157]),
22932293
),
2294+
(
2295+
np.array([5, 4, 3, 2, 1]) / 15,
2296+
np.array([0.5, 1, 2], dtype="float64"),
2297+
4,
2298+
np.array([1.51263013, 2.93119375, 2.99573227]),
2299+
),
2300+
(
2301+
np.array([5, 4, 3, 2, 1]) / 15,
2302+
np.arange(1, 10, dtype="float64").reshape(3, 3),
2303+
4,
2304+
np.array(
2305+
[
2306+
[2.93119375, 2.99573227, 1.9095425],
2307+
[0.35222059, -1.4632554, -3.44201938],
2308+
[-5.53346686, -7.70739149, -9.94430955],
2309+
]
2310+
),
2311+
),
22942312
],
22952313
)
22962314
def test_stickbreakingweights_logp(self, value, alpha, K, logp):

pymc/tests/test_distributions_random.py

+19-18
Original file line numberDiff line numberDiff line change
@@ -1287,25 +1287,26 @@ class TestDirichletMultinomial_1D_n_2D_a(BaseTestDistributionRandom):
12871287

12881288

12891289
class TestStickBreakingWeights(BaseTestDistributionRandom):
1290-
pymc_dist = pm.StickBreakingWeights
1291-
pymc_dist_params = {"alpha": 2.0, "K": 19}
1292-
expected_rv_op_params = {"alpha": 2.0, "K": 19}
1293-
sizes_to_check = [None, 17, (5,), (11, 5), (3, 13, 5)]
1294-
sizes_expected = [
1295-
(20,),
1296-
(17, 20),
1297-
(
1298-
5,
1299-
20,
1300-
),
1301-
(11, 5, 20),
1302-
(3, 13, 5, 20),
1303-
]
1304-
checks_to_run = [
1305-
"check_pymc_params_match_rv_op",
1306-
"check_rv_size",
1307-
"check_basic_properties",
1290+
parameters = [
1291+
(np.array(3.5), 19),
1292+
(np.array([1, 2, 3], dtype="float64"), 17),
1293+
(np.arange(1, 10, dtype="float64").reshape(3, 3), 15),
1294+
(np.arange(1, 25, dtype="float64").reshape(2, 3, 4), 5),
13081295
]
1296+
for alpha, K in parameters:
1297+
pymc_dist = pm.StickBreakingWeights
1298+
pymc_dist_params = {"alpha": alpha, "K": K}
1299+
expected_rv_op_params = {"alpha": alpha, "K": K}
1300+
sizes_to_check = [None, 17, (5,), (11, 5), (3, 13, 5)]
1301+
sizes_expected = []
1302+
for size in sizes_to_check:
1303+
sizes_expected.append(to_tuple(size) + alpha.shape + (K + 1,))
1304+
1305+
checks_to_run = [
1306+
"check_pymc_params_match_rv_op",
1307+
"check_rv_size",
1308+
"check_basic_properties",
1309+
]
13091310

13101311
def check_basic_properties(self):
13111312
default_rng = aesara.shared(np.random.default_rng(1234))

0 commit comments

Comments
 (0)