Skip to content

Commit bd519d4

Browse files
lucianopazricardoV94
authored andcommitted
Fix conditional import of zarr
1 parent 671d704 commit bd519d4

File tree

2 files changed

+28
-15
lines changed

2 files changed

+28
-15
lines changed

pymc/backends/zarr.py

+23-14
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,11 @@
1515
from typing import Any
1616

1717
import arviz as az
18-
import numcodecs
1918
import numpy as np
2019
import xarray as xr
21-
import zarr
2220

2321
from arviz.data.base import make_attrs
2422
from arviz.data.inference_data import WARMUP_TAG
25-
from numcodecs.abc import Codec
2623
from pytensor.tensor.variable import TensorVariable
2724

2825
import pymc
@@ -44,11 +41,23 @@
4441
from pymc.util import UNSET, _UnsetType, get_default_varnames, is_transformed_name
4542

4643
try:
44+
import numcodecs
45+
import zarr
46+
47+
from numcodecs.abc import Codec
48+
from zarr import Group
4749
from zarr.storage import BaseStore, default_compressor
4850
from zarr.sync import Synchronizer
4951

5052
_zarr_available = True
5153
except ImportError:
54+
from typing import TYPE_CHECKING, TypeVar
55+
56+
if not TYPE_CHECKING:
57+
Codec = TypeVar("Codec")
58+
Group = TypeVar("Group")
59+
BaseStore = TypeVar("BaseStore")
60+
Synchronizer = TypeVar("Synchronizer")
5261
_zarr_available = False
5362

5463

@@ -243,7 +252,7 @@ def flush(self):
243252

244253
def get_initial_fill_value_and_codec(
245254
dtype: Any,
246-
) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, numcodecs.abc.Codec | None]:
255+
) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, Codec | None]:
247256
_dtype = np.dtype(dtype)
248257
fill_value: FILL_VALUE_TYPE = None
249258
codec = None
@@ -366,27 +375,27 @@ def groups(self) -> list[str]:
366375
return [str(group_name) for group_name, _ in self.root.groups()]
367376

368377
@property
369-
def posterior(self) -> zarr.Group:
378+
def posterior(self) -> Group:
370379
return self.root.posterior
371380

372381
@property
373-
def unconstrained_posterior(self) -> zarr.Group:
382+
def unconstrained_posterior(self) -> Group:
374383
return self.root.unconstrained_posterior
375384

376385
@property
377-
def sample_stats(self) -> zarr.Group:
386+
def sample_stats(self) -> Group:
378387
return self.root.sample_stats
379388

380389
@property
381-
def constant_data(self) -> zarr.Group:
390+
def constant_data(self) -> Group:
382391
return self.root.constant_data
383392

384393
@property
385-
def observed_data(self) -> zarr.Group:
394+
def observed_data(self) -> Group:
386395
return self.root.observed_data
387396

388397
@property
389-
def _sampling_state(self) -> zarr.Group:
398+
def _sampling_state(self) -> Group:
390399
return self.root._sampling_state
391400

392401
def init_trace(
@@ -646,12 +655,12 @@ def init_sampling_state_group(self, tune: int, chains: int):
646655

647656
def init_group_with_empty(
648657
self,
649-
group: zarr.Group,
658+
group: Group,
650659
var_dtype_and_shape: dict[str, tuple[StatDtype, StatShape]],
651660
chains: int,
652661
draws: int,
653662
extra_var_attrs: dict | None = None,
654-
) -> zarr.Group:
663+
) -> Group:
655664
group_coords: dict[str, Any] = {"chain": range(chains), "draw": range(draws)}
656665
for name, (_dtype, shape) in var_dtype_and_shape.items():
657666
fill_value, dtype, object_codec = get_initial_fill_value_and_codec(_dtype)
@@ -689,8 +698,8 @@ def init_group_with_empty(
689698
array.attrs.update({"_ARRAY_DIMENSIONS": [dim]})
690699
return group
691700

692-
def create_group(self, name: str, data_dict: dict[str, np.ndarray]) -> zarr.Group | None:
693-
group: zarr.Group | None = None
701+
def create_group(self, name: str, data_dict: dict[str, np.ndarray]) -> Group | None:
702+
group: Group | None = None
694703
if data_dict:
695704
group_coords = {}
696705
group = self.root.create_group(name=name, overwrite=True)

pymc/sampling/mcmc.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
from rich.theme import Theme
4242
from threadpoolctl import threadpool_limits
4343
from typing_extensions import Protocol
44-
from zarr.storage import MemoryStore
4544

4645
import pymc as pm
4746

@@ -80,6 +79,11 @@
8079
)
8180
from pymc.vartypes import discrete_types
8281

82+
try:
83+
from zarr.storage import MemoryStore
84+
except ImportError:
85+
MemoryStore = type("MemoryStore", (), {})
86+
8387
sys.setrecursionlimit(10000)
8488

8589
__all__ = [

0 commit comments

Comments
 (0)