Skip to content

Commit adf6fff

Browse files
ferrinemichaelosthege
authored andcommitted
Refactor Minibatch and stop using MRG_RandomStream in VI
Closes #4523 Closes #6277
1 parent 43d5699 commit adf6fff

File tree

13 files changed

+193
-427
lines changed

13 files changed

+193
-427
lines changed

pymc/data.py

Lines changed: 59 additions & 333 deletions
Large diffs are not rendered by default.

pymc/distributions/logprob.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@ def _get_scaling(total_size: TOTAL_SIZE, shape, ndim: int) -> TensorVariable:
102102

103103
def _check_no_rvs(logp_terms: Sequence[TensorVariable]):
104104
# Raise if there are unexpected RandomVariables in the logp graph
105-
# Only SimulatorRVs are allowed
105+
# Only SimulatorRVs MinibatchIndexRVs are allowed
106+
from pymc.data import MinibatchIndexRV
106107
from pymc.distributions.simulator import SimulatorRV
107108

108109
unexpected_rv_nodes = [
@@ -111,7 +112,7 @@ def _check_no_rvs(logp_terms: Sequence[TensorVariable]):
111112
if (
112113
node.owner
113114
and isinstance(node.owner.op, RandomVariable)
114-
and not isinstance(node.owner.op, SimulatorRV)
115+
and not isinstance(node.owner.op, (SimulatorRV, MinibatchIndexRV))
115116
)
116117
]
117118
if unexpected_rv_nodes:

pymc/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
from pytensor.tensor.var import TensorConstant, TensorVariable
5151

5252
from pymc.blocking import DictToArrayBijection, RaveledVars
53-
from pymc.data import GenTensorVariable, Minibatch
53+
from pymc.data import GenTensorVariable, is_minibatch
5454
from pymc.distributions.logprob import _joint_logp
5555
from pymc.distributions.transforms import _default_transform
5656
from pymc.exceptions import (
@@ -1329,14 +1329,15 @@ def register_rv(
13291329
else:
13301330
if (
13311331
isinstance(observed, Variable)
1332-
and not isinstance(observed, (GenTensorVariable, Minibatch))
1332+
and not isinstance(observed, GenTensorVariable)
13331333
and observed.owner is not None
13341334
# The only PyTensor operation we allow on observed data is type casting
13351335
# Although we could allow for any graph that does not depend on other RVs
13361336
and not (
13371337
isinstance(observed.owner.op, Elemwise)
13381338
and isinstance(observed.owner.op.scalar_op, Cast)
13391339
)
1340+
and not is_minibatch(observed)
13401341
):
13411342
raise TypeError(
13421343
"Variables that depend on other nodes cannot be used for observed data."

pymc/pytensorf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@
4747
)
4848
from pytensor.graph.fg import FunctionGraph
4949
from pytensor.graph.op import Op
50-
from pytensor.sandbox.rng_mrg import MRG_RandomStream as RandomStream
5150
from pytensor.scalar.basic import Cast
5251
from pytensor.tensor.basic import _as_tensor_variable
5352
from pytensor.tensor.elemwise import Elemwise
53+
from pytensor.tensor.random import RandomStream
5454
from pytensor.tensor.random.op import RandomVariable
5555
from pytensor.tensor.random.var import (
5656
RandomGeneratorSharedVariable,

pymc/tests/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from pytensor.gradient import verify_grad as at_verify_grad
2727
from pytensor.graph import ancestors
2828
from pytensor.graph.rewriting.basic import in2out
29-
from pytensor.sandbox.rng_mrg import MRG_RandomStream as RandomStream
29+
from pytensor.tensor.random import RandomStream
3030
from pytensor.tensor.random.op import RandomVariable
3131

3232
import pymc as pm

pymc/tests/test_data.py

Lines changed: 28 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
import pymc as pm
2929

30+
from pymc.data import is_minibatch
3031
from pymc.pytensorf import GeneratorOp, floatX
3132
from pymc.tests.helpers import SeededTest, select_by_precision
3233

@@ -696,15 +697,10 @@ def test_common_errors(self):
696697

697698
def test_mixed1(self):
698699
with pm.Model():
699-
data = np.random.rand(10, 20, 30, 40, 50)
700-
mb = pm.Minibatch(data, [2, None, 20, Ellipsis, 10])
701-
pm.Normal("n", observed=mb, total_size=(10, None, 30, Ellipsis, 50))
702-
703-
def test_mixed2(self):
704-
with pm.Model():
705-
data = np.random.rand(10, 20, 30, 40, 50)
706-
mb = pm.Minibatch(data, [2, None, 20])
707-
pm.Normal("n", observed=mb, total_size=(10, None, 30))
700+
data = np.random.rand(10, 20)
701+
mb = pm.Minibatch(data, batch_size=5)
702+
v = pm.Normal("n", observed=mb, total_size=10)
703+
assert pm.logp(v, 1) is not None, "Check index is allowed in graph"
708704

709705
def test_free_rv(self):
710706
with pm.Model() as model4:
@@ -719,51 +715,28 @@ def test_free_rv(self):
719715

720716
@pytest.mark.usefixtures("strict_float32")
721717
class TestMinibatch:
722-
data = np.random.rand(30, 10, 40, 10, 50)
718+
data = np.random.rand(30, 10)
723719

724720
def test_1d(self):
725-
mb = pm.Minibatch(self.data, 20)
726-
assert mb.eval().shape == (20, 10, 40, 10, 50)
727-
728-
def test_2d(self):
729-
mb = pm.Minibatch(self.data, [(10, 42), (4, 42)])
730-
assert mb.eval().shape == (10, 4, 40, 10, 50)
731-
732-
@pytest.mark.parametrize(
733-
"batch_size, expected",
734-
[
735-
([(10, 42), None, (4, 42)], (10, 10, 4, 10, 50)),
736-
([(10, 42), Ellipsis, (4, 42)], (10, 10, 40, 10, 4)),
737-
([(10, 42), None, Ellipsis, (4, 42)], (10, 10, 40, 10, 4)),
738-
([10, None, Ellipsis, (4, 42)], (10, 10, 40, 10, 4)),
739-
],
740-
)
741-
def test_special_batch_size(self, batch_size, expected):
742-
mb = pm.Minibatch(self.data, batch_size)
743-
assert mb.eval().shape == expected
744-
745-
def test_cloning_available(self):
746-
gop = pm.Minibatch(np.arange(100), 1)
747-
res = gop**2
748-
shared = pytensor.shared(np.array([10]))
749-
res1 = pytensor.clone_replace(res, {gop: shared})
750-
f = pytensor.function([], res1)
751-
assert f() == np.array([100])
752-
753-
def test_align(self):
754-
m = pm.Minibatch(np.arange(1000), 1, random_seed=1)
755-
n = pm.Minibatch(np.arange(1000), 1, random_seed=1)
756-
f = pytensor.function([], [m, n])
757-
n.eval() # not aligned
758-
a, b = zip(*(f() for _ in range(1000)))
759-
assert a != b
760-
pm.align_minibatches()
761-
a, b = zip(*(f() for _ in range(1000)))
762-
assert a == b
763-
n.eval() # not aligned
764-
pm.align_minibatches([m])
765-
a, b = zip(*(f() for _ in range(1000)))
766-
assert a != b
767-
pm.align_minibatches([m, n])
768-
a, b = zip(*(f() for _ in range(1000)))
769-
assert a == b
721+
mb = pm.Minibatch(self.data, batch_size=20)
722+
assert is_minibatch(mb)
723+
assert mb.eval().shape == (20, 10)
724+
725+
def test_allowed(self):
726+
mb = pm.Minibatch(at.as_tensor(self.data).astype(int), batch_size=20)
727+
assert is_minibatch(mb)
728+
729+
def test_not_allowed(self):
730+
with pytest.raises(ValueError, match="not valid for Minibatch"):
731+
mb = pm.Minibatch(at.as_tensor(self.data) * 2, batch_size=20)
732+
733+
def test_not_allowed2(self):
734+
with pytest.raises(ValueError, match="not valid for Minibatch"):
735+
mb = pm.Minibatch(self.data, at.as_tensor(self.data) * 2, batch_size=20)
736+
737+
def test_assert(self):
738+
with pytest.raises(
739+
AssertionError, match=r"All variables shape\[0\] in Minibatch should be equal"
740+
):
741+
d1, d2 = pm.Minibatch(self.data, self.data[::2], batch_size=20)
742+
d1.eval()

pymc/tests/variational/test_approximations.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,3 +168,15 @@ def test_elbo_beta_kl(aux_total_size):
168168
np.testing.assert_allclose(
169169
elbo_via_total_size_scaled.eval(), elbo_via_beta_kl.eval(), rtol=0, atol=1e-1
170170
)
171+
172+
173+
def test_seeding_advi_fit():
174+
with pm.Model():
175+
x = pm.Normal("x", 0, 10, initval="prior")
176+
approx1 = pm.fit(
177+
random_seed=42, n=10, method="advi", obj_optimizer=pm.adagrad_window, progressbar=False
178+
)
179+
approx2 = pm.fit(
180+
random_seed=42, n=10, method="advi", obj_optimizer=pm.adagrad_window, progressbar=False
181+
)
182+
np.testing.assert_allclose(approx1.mean.eval(), approx2.mean.eval())

pymc/tests/variational/test_inference.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import io
1616
import operator
1717

18+
import cloudpickle
1819
import numpy as np
1920
import pytensor
2021
import pytensor.tensor as at
@@ -26,6 +27,7 @@
2627

2728
from pymc.pytensorf import intX
2829
from pymc.variational.inference import ADVI, ASVGD, SVGD, FullRankADVI
30+
from pymc.variational.opvi import NotImplementedInference
2931

3032
pytestmark = pytest.mark.usefixtures("strict_float32", "seeded_test")
3133

@@ -60,7 +62,7 @@ def simple_model_data(use_minibatch):
6062
d = n / sigma**2 + 1 / sigma0**2
6163
mu_post = (n * np.mean(data) / sigma**2 + mu0 / sigma0**2) / d
6264
if use_minibatch:
63-
data = pm.Minibatch(data)
65+
data = pm.Minibatch(data, batch_size=128)
6466
return dict(
6567
n=n,
6668
data=data,
@@ -118,7 +120,7 @@ def init_(**kw):
118120
@pytest.fixture(scope="function")
119121
def inference(inference_spec, simple_model):
120122
with simple_model:
121-
return inference_spec()
123+
return inference_spec(random_seed=42)
122124

123125

124126
@pytest.fixture(scope="function")
@@ -129,7 +131,7 @@ def fit_kwargs(inference, use_minibatch):
129131
obj_optimizer=pm.adagrad_window(learning_rate=0.01, n_win=50), n=12000
130132
),
131133
(FullRankADVI, "full"): dict(
132-
obj_optimizer=pm.adagrad_window(learning_rate=0.01, n_win=50), n=6000
134+
obj_optimizer=pm.adagrad_window(learning_rate=0.015, n_win=50), n=6000
133135
),
134136
(FullRankADVI, "mini"): dict(
135137
obj_optimizer=pm.adagrad_window(learning_rate=0.007, n_win=50), n=12000
@@ -149,6 +151,8 @@ def fit_kwargs(inference, use_minibatch):
149151
inference.approx.scale_cost_to_minibatch = False
150152
else:
151153
key = "full"
154+
if (type(inference), key) in {(SVGD, "mini"), (ASVGD, "mini")}:
155+
pytest.skip("Not Implemented Inference")
152156
return _select[(type(inference), key)]
153157

154158

@@ -179,7 +183,10 @@ def test_fit_start(inference_spec, simple_model):
179183

180184
with simple_model:
181185
inference = inference_spec(**kw)
182-
trace = inference.fit(n=0).sample(10000)
186+
try:
187+
trace = inference.fit(n=0).sample(10000)
188+
except NotImplementedInference as e:
189+
pytest.skip(str(e))
183190
np.testing.assert_allclose(np.mean(trace.posterior["mu"]), mu_init, rtol=0.05)
184191
if has_start_sigma:
185192
np.testing.assert_allclose(np.std(trace.posterior["mu"]), mu_sigma_init, rtol=0.05)
@@ -218,6 +225,8 @@ def test_fit_fn_text(method, kwargs, error):
218225

219226

220227
def test_profile(inference):
228+
if type(inference) in {SVGD, ASVGD}:
229+
pytest.skip("Not Implemented Inference")
221230
inference.run_profiling(n=100).summary()
222231

223232

@@ -239,8 +248,7 @@ def binomial_model_inference(binomial_model, inference_spec):
239248

240249
@pytest.mark.xfail("pytensor.config.warn_float64 == 'raise'", reason="too strict float32")
241250
def test_replacements(binomial_model_inference):
242-
d = at.bscalar()
243-
d.tag.test_value = 1
251+
d = pytensor.shared(1)
244252
approx = binomial_model_inference.approx
245253
p = approx.model.p
246254
p_t = p**3
@@ -252,7 +260,7 @@ def test_replacements(binomial_model_inference):
252260
), "p should be replaced"
253261
if pytensor.config.compute_test_value != "off":
254262
assert p_s.tag.test_value.shape == p_t.tag.test_value.shape
255-
sampled = [p_s.eval() for _ in range(100)]
263+
sampled = [pm.draw(p_s) for _ in range(100)]
256264
assert any(map(operator.ne, sampled[1:], sampled[:-1])) # stochastic
257265
p_z = approx.sample_node(p_t, deterministic=False, size=10)
258266
assert p_z.shape.eval() == (10,)
@@ -264,15 +272,17 @@ def test_replacements(binomial_model_inference):
264272

265273
try:
266274
p_d = approx.sample_node(p_t, deterministic=True)
267-
sampled = [p_d.eval() for _ in range(100)]
275+
sampled = [pm.draw(p_d) for _ in range(100)]
268276
assert all(map(operator.eq, sampled[1:], sampled[:-1])) # deterministic
269277
except opvi.NotImplementedInference:
270278
pass
271279

272280
p_r = approx.sample_node(p_t, deterministic=d)
273-
sampled = [p_r.eval({d: 1}) for _ in range(100)]
281+
d.set_value(1)
282+
sampled = [pm.draw(p_r) for _ in range(100)]
274283
assert all(map(operator.eq, sampled[1:], sampled[:-1])) # deterministic
275-
sampled = [p_r.eval({d: 0}) for _ in range(100)]
284+
d.set_value(0)
285+
sampled = [pm.draw(p_r) for _ in range(100)]
276286
assert any(map(operator.ne, sampled[1:], sampled[:-1])) # stochastic
277287

278288

@@ -325,8 +335,6 @@ def test_var_replacement():
325335

326336

327337
def test_clear_cache():
328-
import cloudpickle
329-
330338
with pm.Model():
331339
pm.Normal("n", 0, 1)
332340
inference = ADVI()

pymc/variational/approximations.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
1516
import numpy as np
1617
import pytensor
1718

@@ -24,11 +25,13 @@
2425

2526
from pymc.blocking import DictToArrayBijection
2627
from pymc.distributions.dist_math import rho2sigma
28+
from pymc.pytensorf import makeiter
2729
from pymc.variational import opvi
2830
from pymc.variational.opvi import (
2931
Approximation,
3032
Group,
3133
NotImplementedInference,
34+
_known_scan_ignored_inputs,
3235
node_property,
3336
)
3437

@@ -248,9 +251,12 @@ def randidx(self, size=None):
248251
pass
249252
else:
250253
size = tuple(np.atleast_1d(size))
251-
return self._rng.uniform(
252-
size=size, low=pm.floatX(0), high=pm.floatX(self.histogram.shape[0]) - pm.floatX(1e-16)
253-
).astype("int32")
254+
return at.random.integers(
255+
size=size,
256+
low=0,
257+
high=self.histogram.shape[0],
258+
rng=pytensor.shared(np.random.default_rng()),
259+
)
254260

255261
def _new_initial(self, size, deterministic, more_replacements=None):
256262
pytensor_condition_is_here = isinstance(deterministic, Variable)
@@ -383,8 +389,10 @@ def evaluate_over_trace(self, node):
383389
"""
384390
node = self.to_flat_input(node)
385391

386-
def sample(post, node):
392+
def sample(post, *_):
387393
return pytensor.clone_replace(node, {self.input: post})
388394

389-
nodes, _ = pytensor.scan(sample, self.histogram, non_sequences=[node])
395+
nodes, _ = pytensor.scan(
396+
sample, self.histogram, non_sequences=_known_scan_ignored_inputs(makeiter(node))
397+
)
390398
return nodes

pymc/variational/inference.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,9 @@ def __init__(self, approx=None, estimator=KSD, kernel=test_functions.rbf, **kwar
631631
"is often **underestimated** when using temperature = 1."
632632
)
633633
if approx is None:
634-
approx = FullRank(model=kwargs.pop("model", None))
634+
approx = FullRank(
635+
model=kwargs.pop("model", None), random_seed=kwargs.pop("random_seed", None)
636+
)
635637
super().__init__(estimator=estimator, approx=approx, kernel=kernel, **kwargs)
636638

637639
def fit(

pymc/variational/operators.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@
2020
import pymc as pm
2121

2222
from pymc.variational import opvi
23-
from pymc.variational.opvi import ObjectiveFunction, Operator
23+
from pymc.variational.opvi import (
24+
NotImplementedInference,
25+
ObjectiveFunction,
26+
Operator,
27+
_known_scan_ignored_inputs,
28+
)
2429
from pymc.variational.stein import Stein
2530

2631
__all__ = ["KL", "KSD"]
@@ -136,6 +141,10 @@ def __init__(self, approx, temperature=1):
136141

137142
def apply(self, f):
138143
# f: kernel function for KSD f(histogram) -> (k(x,.), \nabla_x k(x,.))
144+
if _known_scan_ignored_inputs([self.approx.model.logp()]):
145+
raise NotImplementedInference(
146+
"SVGD does not currently support Minibatch or Simulator RV"
147+
)
139148
stein = Stein(
140149
approx=self.approx,
141150
kernel=f,

0 commit comments

Comments
 (0)