Skip to content

Commit 8796272

Browse files
Manage coordinates as tuples (#5061)
* Added regression test for issue #5043
* Automatically convert coord values given to the model to tuples
* Make them numpy arrays for the InferenceData conversion.

Closes #5043
1 parent 595e164 commit 8796272

File tree

4 files changed

+58
-9
lines changed

4 files changed

+58
-9
lines changed

pymc/backends/arviz.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -222,10 +222,12 @@ def arbitrary_element(dct: Dict[Any, np.ndarray]) -> np.ndarray:
222222
aelem = arbitrary_element(get_from)
223223
self.ndraws = aelem.shape[0]
224224

225-
self.coords = {} if coords is None else coords
226-
if hasattr(self.model, "coords"):
227-
self.coords = {**self.model.coords, **self.coords}
228-
self.coords = {key: value for key, value in self.coords.items() if value is not None}
225+
self.coords = {**self.model.coords, **(coords or {})}
226+
self.coords = {
227+
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
228+
for cname, cvals in self.coords.items()
229+
if cvals is not None
230+
}
229231

230232
self.dims = {} if dims is None else dims
231233
if hasattr(self.model, "RV_dims"):

pymc/model.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -871,7 +871,7 @@ def RV_dims(self) -> Dict[str, Tuple[Union[str, None], ...]]:
871871
return self._RV_dims
872872

873873
@property
874-
def coords(self) -> Dict[str, Union[Sequence, None]]:
874+
def coords(self) -> Dict[str, Union[Tuple, None]]:
875875
"""Coordinate values for model dimensions."""
876876
return self._coords
877877

@@ -1096,8 +1096,12 @@ def add_coord(
10961096
raise ValueError(
10971097
f"The `length` passed for the '{name}' coord must be an Aesara Variable or None."
10981098
)
1099+
if values is not None:
1100+
# Conversion to a tuple ensures that the coordinate values are immutable.
1101+
# Also, unlike numpy arrays, there's tuple.index(...) which is handy to work with.
1102+
values = tuple(values)
10991103
if name in self.coords:
1100-
if not values.equals(self.coords[name]):
1104+
if not np.array_equal(values, self.coords[name]):
11011105
raise ValueError(f"Duplicate and incompatible coordinate: {name}.")
11021106
else:
11031107
self._coords[name] = values

pymc/tests/test_data_container.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -287,12 +287,12 @@ def test_explicit_coords(self):
287287
pm.Data("observations", data, dims=("rows", "columns"))
288288

289289
assert "rows" in pmodel.coords
290-
assert pmodel.coords["rows"] == ["R1", "R2", "R3", "R4", "R5"]
290+
assert pmodel.coords["rows"] == ("R1", "R2", "R3", "R4", "R5")
291291
assert "rows" in pmodel.dim_lengths
292292
assert isinstance(pmodel.dim_lengths["rows"], ScalarSharedVariable)
293293
assert pmodel.dim_lengths["rows"].eval() == 5
294294
assert "columns" in pmodel.coords
295-
assert pmodel.coords["columns"] == ["C1", "C2", "C3", "C4", "C5", "C6", "C7"]
295+
assert pmodel.coords["columns"] == ("C1", "C2", "C3", "C4", "C5", "C6", "C7")
296296
assert pmodel.RV_dims == {"observations": ("rows", "columns")}
297297
assert "columns" in pmodel.dim_lengths
298298
assert isinstance(pmodel.dim_lengths["columns"], ScalarSharedVariable)

pymc/tests/test_idata_conversion.py

+44-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212

1313
import pymc as pm
1414

15-
from pymc.backends.arviz import predictions_to_inference_data, to_inference_data
15+
from pymc.backends.arviz import (
16+
InferenceDataConverter,
17+
predictions_to_inference_data,
18+
to_inference_data,
19+
)
1620

1721

1822
@pytest.fixture(scope="module")
@@ -598,6 +602,45 @@ def test_constant_data_coords_issue_5046(self):
598602
for dname, cvals in coords.items():
599603
np.testing.assert_array_equal(ds[dname].values, cvals)
600604

605+
def test_issue_5043_autoconvert_coord_values(self):
606+
coords = {
607+
"city": pd.Series(["Bonn", "Berlin"]),
608+
}
609+
with pm.Model(coords=coords) as pmodel:
610+
# The model tracks coord values as (immutable) tuples
611+
assert isinstance(pmodel.coords["city"], tuple)
612+
pm.Normal("x", dims="city")
613+
mtrace = pm.sample(
614+
return_inferencedata=False,
615+
compute_convergence_checks=False,
616+
step=pm.Metropolis(),
617+
cores=1,
618+
tune=7,
619+
draws=15,
620+
)
621+
# The converter must convert the coord values to numpy arrays
622+
# because tuples as coordinate values cause problems with xarray.
623+
converter = InferenceDataConverter(trace=mtrace)
624+
assert isinstance(converter.coords["city"], np.ndarray)
625+
converter.to_inference_data()
626+
627+
# We're not automatically converting things other than tuple,
628+
# so advanced use cases remain supported at the InferenceData level.
629+
# They just can't be used at the model construction stage.
630+
converter = InferenceDataConverter(
631+
trace=mtrace,
632+
coords={
633+
"city": pd.MultiIndex.from_tuples(
634+
[
635+
("Bonn", 53111),
636+
("Berlin", 10178),
637+
],
638+
names=["name", "zipcode"],
639+
)
640+
},
641+
)
642+
assert isinstance(converter.coords["city"], pd.MultiIndex)
643+
601644

602645
class TestPyMCWarmupHandling:
603646
@pytest.mark.parametrize("save_warmup", [False, True])

0 commit comments

Comments
 (0)