diff --git a/pymc/model/transform/basic.py b/pymc/model/transform/basic.py index 877814cd6..558275d97 100644 --- a/pymc/model/transform/basic.py +++ b/pymc/model/transform/basic.py @@ -14,17 +14,21 @@ from collections.abc import Sequence from pytensor import Variable, clone_replace +from pytensor.compile import SharedVariable from pytensor.graph import ancestors from pytensor.graph.fg import FunctionGraph -from pymc.data import MinibatchOp +from pymc.data import Minibatch, MinibatchOp from pymc.model.core import Model from pymc.model.fgraph import ( ModelObservedRV, ModelVar, + extract_dims, fgraph_from_model, model_from_fgraph, + model_observed_rv, ) +from pymc.pytensorf import toposort_replace ModelVariable = Variable | str @@ -62,6 +66,47 @@ def parse_vars(model: Model, vars: ModelVariable | Sequence[ModelVariable]) -> l return [model[var] if isinstance(var, str) else var for var in vars_seq] +def model_to_minibatch(model: Model, batch_size: int) -> Model: + """Replace all Data containers with pm.Minibatch, and add total_size to all observed RVs.""" + from pymc.variational.minibatch_rv import create_minibatch_rv + + fgraph, memo = fgraph_from_model(model, inlined_views=True) + + # obs_rvs, data_vars = model.rvs_to_values.items() + + data_vars = [ + memo[datum].owner.inputs[0] + for datum in (model.named_vars[datum_name] for datum_name in model.named_vars) + if isinstance(datum, SharedVariable) + ] + + minibatch_vars = Minibatch(*data_vars, batch_size=batch_size) + replacements = {datum: minibatch_vars[i] for i, datum in enumerate(data_vars)} + assert 0 + # Add total_size to all observed RVs + total_size = data_vars[0].get_value().shape[0] + for obs_var in model.observed_RVs: + model_var = memo[obs_var] + var = model_var.owner.inputs[0] + var.name = model_var.name + dims = extract_dims(model_var) + + new_rv = create_minibatch_rv(var, total_size=total_size) + new_rv.name = var.name + + replacements[model_var] = model_observed_rv(new_rv, model.rvs_to_values[obs_var], *dims) + + # old_outs, old_coords, old_dim_lengths = fgraph.outputs, fgraph._coords, fgraph._dim_lengths + toposort_replace(fgraph, tuple(replacements.items())) + # new_outs = clone_replace(old_outs, replacements, rebuild_strict=False) # type: ignore[arg-type] + + # fgraph = FunctionGraph(outputs=new_outs, clone=False) + # fgraph._coords = old_coords # type: ignore[attr-defined] + # fgraph._dim_lengths = old_dim_lengths # type: ignore[attr-defined] + + return model_from_fgraph(fgraph, mutate_fgraph=True) + + def remove_minibatched_nodes(model: Model) -> Model: """Remove all uses of pm.Minibatch in the Model.""" fgraph, _ = fgraph_from_model(model) diff --git a/tests/model/transform/test_basic.py b/tests/model/transform/test_basic.py index 856fbf0b2..c3b33730d 100644 --- a/tests/model/transform/test_basic.py +++ b/tests/model/transform/test_basic.py @@ -15,7 +15,11 @@ import pymc as pm -from pymc.model.transform.basic import prune_vars_detached_from_observed, remove_minibatched_nodes +from pymc.model.transform.basic import ( + model_to_minibatch, + prune_vars_detached_from_observed, + remove_minibatched_nodes, +) def test_prune_vars_detached_from_observed(): @@ -34,6 +38,28 @@ def test_prune_vars_detached_from_observed(): assert set(pruned_m.named_vars.keys()) == {"obs_data", "a0", "a1", "a2", "obs"} +def test_model_to_minibatch(): + data_size = 100 + n_features = 4 + + obs_data = np.zeros((data_size,)) + X_data = np.random.normal(size=(data_size, n_features)) + + with pm.Model(coords={"feature": range(n_features), "data_dim": range(data_size)}) as m1: + obs_data = pm.Data("obs_data", obs_data, dims=["data_dim"]) + X_data = pm.Data("X_data", X_data, dims=["data_dim", "feature"]) + beta = pm.Normal("beta", dims="feature") + + mu = X_data @ beta + + y = pm.Normal("y", mu=mu, sigma=1, observed=obs_data, dims="data_dim") + + m2 = model_to_minibatch(m1, batch_size=10) + m2["y"].dprint() + + assert 0 + + def test_remove_minibatches(): data_size = 100 data = np.zeros((data_size,))