diff --git a/pymc/backends/zarr.py b/pymc/backends/zarr.py index e9aba5fe0d..b9c1e49ea3 100644 --- a/pymc/backends/zarr.py +++ b/pymc/backends/zarr.py @@ -15,14 +15,11 @@ from typing import Any import arviz as az -import numcodecs import numpy as np import xarray as xr -import zarr from arviz.data.base import make_attrs from arviz.data.inference_data import WARMUP_TAG -from numcodecs.abc import Codec from pytensor.tensor.variable import TensorVariable import pymc @@ -44,11 +41,23 @@ from pymc.util import UNSET, _UnsetType, get_default_varnames, is_transformed_name try: + import numcodecs + import zarr + + from numcodecs.abc import Codec + from zarr import Group from zarr.storage import BaseStore, default_compressor from zarr.sync import Synchronizer _zarr_available = True except ImportError: + from typing import TYPE_CHECKING, TypeVar + + if not TYPE_CHECKING: + Codec = TypeVar("Codec") + Group = TypeVar("Group") + BaseStore = TypeVar("BaseStore") + Synchronizer = TypeVar("Synchronizer") _zarr_available = False @@ -243,7 +252,7 @@ def flush(self): def get_initial_fill_value_and_codec( dtype: Any, -) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, numcodecs.abc.Codec | None]: +) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, Codec | None]: _dtype = np.dtype(dtype) fill_value: FILL_VALUE_TYPE = None codec = None @@ -366,27 +375,27 @@ def groups(self) -> list[str]: return [str(group_name) for group_name, _ in self.root.groups()] @property - def posterior(self) -> zarr.Group: + def posterior(self) -> Group: return self.root.posterior @property - def unconstrained_posterior(self) -> zarr.Group: + def unconstrained_posterior(self) -> Group: return self.root.unconstrained_posterior @property - def sample_stats(self) -> zarr.Group: + def sample_stats(self) -> Group: return self.root.sample_stats @property - def constant_data(self) -> zarr.Group: + def constant_data(self) -> Group: return self.root.constant_data @property - def observed_data(self) -> zarr.Group: + def observed_data(self) -> Group: return self.root.observed_data @property - def _sampling_state(self) -> zarr.Group: + def _sampling_state(self) -> Group: return self.root._sampling_state def init_trace( @@ -646,12 +655,12 @@ def init_sampling_state_group(self, tune: int, chains: int): def init_group_with_empty( self, - group: zarr.Group, + group: Group, var_dtype_and_shape: dict[str, tuple[StatDtype, StatShape]], chains: int, draws: int, extra_var_attrs: dict | None = None, - ) -> zarr.Group: + ) -> Group: group_coords: dict[str, Any] = {"chain": range(chains), "draw": range(draws)} for name, (_dtype, shape) in var_dtype_and_shape.items(): fill_value, dtype, object_codec = get_initial_fill_value_and_codec(_dtype) @@ -689,8 +698,8 @@ def init_group_with_empty( array.attrs.update({"_ARRAY_DIMENSIONS": [dim]}) return group - def create_group(self, name: str, data_dict: dict[str, np.ndarray]) -> zarr.Group | None: - group: zarr.Group | None = None + def create_group(self, name: str, data_dict: dict[str, np.ndarray]) -> Group | None: + group: Group | None = None if data_dict: group_coords = {} group = self.root.create_group(name=name, overwrite=True) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index ca91325ff1..7cbb6df26e 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -41,7 +41,6 @@ from rich.theme import Theme from threadpoolctl import threadpool_limits from typing_extensions import Protocol -from zarr.storage import MemoryStore import pymc as pm @@ -80,6 +79,11 @@ ) from pymc.vartypes import discrete_types +try: + from zarr.storage import MemoryStore +except ImportError: + MemoryStore = type("MemoryStore", (), {}) + sys.setrecursionlimit(10000) __all__ = [