Skip to content

Commit c721232

Browse files
michaelosthege authored and twiecki committed
Introduce pm.Data(..., mutable) kwarg
By passing `mutable=False` to `pm.Data(...)`, one can create a `TensorConstant` instead of a `SharedVariable`. Data variables with a known, fixed shape can improve performance and compatibility in some situations. The `pm.ConstantData` and `pm.MutableData` wrappers are provided as alternative syntax. This is the basis for solving #4441.
1 parent a9b8afe commit c721232

File tree

3 files changed

+109
-57
lines changed

3 files changed

+109
-57
lines changed

pymc/data.py

+105-53
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,17 @@
1919
import urllib.request
2020

2121
from copy import copy
22-
from typing import Any, Dict, List, Optional, Sequence
22+
from typing import Any, Dict, List, Optional, Sequence, Union
2323

2424
import aesara
2525
import aesara.tensor as at
2626
import numpy as np
2727
import pandas as pd
2828

29+
from aesara.compile.sharedvalue import SharedVariable
2930
from aesara.graph.basic import Apply
3031
from aesara.tensor.type import TensorType
31-
from aesara.tensor.var import TensorVariable
32+
from aesara.tensor.var import TensorConstant, TensorVariable
3233

3334
import pymc as pm
3435

@@ -40,6 +41,8 @@
4041
"Minibatch",
4142
"align_minibatches",
4243
"Data",
44+
"ConstantData",
45+
"MutableData",
4346
]
4447
BASE_URL = "https://raw.githubusercontent.com/pymc-devs/pymc-examples/main/examples/data/{filename}"
4548

@@ -502,9 +505,64 @@ def determine_coords(model, value, dims: Optional[Sequence[str]] = None) -> Dict
502505
return coords
503506

504507

505-
class Data:
506-
"""Data container class that wraps :func:`aesara.shared` and lets
507-
the model be aware of its inputs and outputs.
508+
def ConstantData(
    name: str,
    value,
    *,
    dims: Optional[Sequence[str]] = None,
    export_index_as_coords=False,
    **kwargs,
) -> TensorConstant:
    """Register ``value`` with the model as a ``TensorConstant``.

    Shorthand for ``pm.Data(..., mutable=False)``: every argument is handed
    to :func:`Data` unchanged, with ``mutable`` pinned to ``False``.
    """
    # ``mutable`` is fixed by this wrapper; everything else is forwarded as given.
    forwarded = dict(
        dims=dims,
        export_index_as_coords=export_index_as_coords,
        mutable=False,
    )
    return Data(name, value, **forwarded, **kwargs)
528+
529+
530+
def MutableData(
    name: str,
    value,
    *,
    dims: Optional[Sequence[str]] = None,
    export_index_as_coords=False,
    **kwargs,
) -> SharedVariable:
    """Register ``value`` with the model as a ``SharedVariable``.

    Shorthand for ``pm.Data(..., mutable=True)``: every argument is handed
    to :func:`Data` unchanged, with ``mutable`` pinned to ``True``.
    """
    # ``mutable`` is fixed by this wrapper; everything else is forwarded as given.
    forwarded = dict(
        dims=dims,
        export_index_as_coords=export_index_as_coords,
        mutable=True,
    )
    return Data(name, value, **forwarded, **kwargs)
550+
551+
552+
def Data(
553+
name: str,
554+
value,
555+
*,
556+
dims: Optional[Sequence[str]] = None,
557+
export_index_as_coords=False,
558+
mutable: bool = True,
559+
**kwargs,
560+
) -> Union[SharedVariable, TensorConstant]:
561+
"""Data container that registers a data variable with the model.
562+
563+
Depending on the ``mutable`` setting (default: True), the variable
564+
is registered as a ``SharedVariable``, enabling it to be altered
565+
in value and shape, but NOT in dimensionality using ``pm.set_data()``.
508566
509567
Parameters
510568
----------
@@ -552,52 +610,46 @@ class Data:
552610
For more information, take a look at this example notebook
553611
https://docs.pymc.io/notebooks/data_container.html
554612
"""
613+
if isinstance(value, list):
614+
value = np.array(value)
555615

556-
def __new__(
557-
self,
558-
name: str,
559-
value,
560-
*,
561-
dims: Optional[Sequence[str]] = None,
562-
export_index_as_coords=False,
563-
**kwargs,
564-
):
565-
if isinstance(value, list):
566-
value = np.array(value)
567-
568-
# Add data container to the named variables of the model.
569-
try:
570-
model = pm.Model.get_context()
571-
except TypeError:
572-
raise TypeError(
573-
"No model on context stack, which is needed to instantiate a data container. "
574-
"Add variable inside a 'with model:' block."
575-
)
576-
name = model.name_for(name)
577-
578-
# `pandas_to_array` takes care of parameter `value` and
579-
# transforms it to something digestible for pymc
580-
shared_object = aesara.shared(pandas_to_array(value), name, **kwargs)
581-
582-
if isinstance(dims, str):
583-
dims = (dims,)
584-
if not (dims is None or len(dims) == shared_object.ndim):
585-
raise pm.exceptions.ShapeError(
586-
"Length of `dims` must match the dimensions of the dataset.",
587-
actual=len(dims),
588-
expected=shared_object.ndim,
589-
)
590-
591-
coords = determine_coords(model, value, dims)
592-
593-
if export_index_as_coords:
594-
model.add_coords(coords)
595-
elif dims:
596-
# Register new dimension lengths
597-
for d, dname in enumerate(dims):
598-
if not dname in model.dim_lengths:
599-
model.add_coord(dname, values=None, length=shared_object.shape[d])
600-
601-
model.add_random_variable(shared_object, dims=dims)
602-
603-
return shared_object
616+
# Add data container to the named variables of the model.
617+
try:
618+
model = pm.Model.get_context()
619+
except TypeError:
620+
raise TypeError(
621+
"No model on context stack, which is needed to instantiate a data container. "
622+
"Add variable inside a 'with model:' block."
623+
)
624+
name = model.name_for(name)
625+
626+
# `pandas_to_array` takes care of parameter `value` and
627+
# transforms it to something digestible for Aesara.
628+
arr = pandas_to_array(value)
629+
if mutable:
630+
x = aesara.shared(arr, name, **kwargs)
631+
else:
632+
x = at.as_tensor_variable(arr, name, **kwargs)
633+
634+
if isinstance(dims, str):
635+
dims = (dims,)
636+
if not (dims is None or len(dims) == x.ndim):
637+
raise pm.exceptions.ShapeError(
638+
"Length of `dims` must match the dimensions of the dataset.",
639+
actual=len(dims),
640+
expected=x.ndim,
641+
)
642+
643+
coords = determine_coords(model, value, dims)
644+
645+
if export_index_as_coords:
646+
model.add_coords(coords)
647+
elif dims:
648+
# Register new dimension lengths
649+
for d, dname in enumerate(dims):
650+
if not dname in model.dim_lengths:
651+
model.add_coord(dname, values=None, length=x.shape[d])
652+
653+
model.add_random_variable(x, dims=dims)
654+
655+
return x

pymc/model_graph.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from aesara.compile.sharedvalue import SharedVariable
2121
from aesara.graph.basic import walk
2222
from aesara.tensor.random.op import RandomVariable
23-
from aesara.tensor.var import TensorVariable
23+
from aesara.tensor.var import TensorConstant, TensorVariable
2424

2525
import pymc as pm
2626

@@ -133,7 +133,7 @@ def _make_node(self, var_name, graph, *, formatting: str = "plain"):
133133
shape = "octagon"
134134
style = "filled"
135135
label = f"{var_name}\n~\nPotential"
136-
elif isinstance(v, SharedVariable):
136+
elif isinstance(v, (SharedVariable, TensorConstant)):
137137
shape = "box"
138138
style = "rounded, filled"
139139
label = f"{var_name}\n~\nData"

pymc/tests/test_data_container.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ def test_sample(self):
7171

7272
def test_sample_posterior_predictive_after_set_data(self):
7373
with pm.Model() as model:
74-
x = pm.Data("x", [1.0, 2.0, 3.0])
75-
y = pm.Data("y", [1.0, 2.0, 3.0])
74+
x = pm.MutableData("x", [1.0, 2.0, 3.0])
75+
y = pm.ConstantData("y", [1.0, 2.0, 3.0])
7676
beta = pm.Normal("beta", 0, 10.0)
7777
pm.Normal("obs", beta * x, np.sqrt(1e-2), observed=y)
7878
trace = pm.sample(

0 commit comments

Comments
 (0)