|
15 | 15 | from typing import Any
|
16 | 16 |
|
17 | 17 | import arviz as az
|
18 |
| -import numcodecs |
19 | 18 | import numpy as np
|
20 | 19 | import xarray as xr
|
21 |
| -import zarr |
22 | 20 |
|
23 | 21 | from arviz.data.base import make_attrs
|
24 | 22 | from arviz.data.inference_data import WARMUP_TAG
|
25 |
| -from numcodecs.abc import Codec |
26 | 23 | from pytensor.tensor.variable import TensorVariable
|
27 | 24 |
|
28 | 25 | import pymc
|
|
44 | 41 | from pymc.util import UNSET, _UnsetType, get_default_varnames, is_transformed_name
|
45 | 42 |
|
46 | 43 | try:
|
| 44 | + import numcodecs |
| 45 | + import zarr |
| 46 | + |
| 47 | + from numcodecs.abc import Codec |
| 48 | + from zarr import Group |
47 | 49 | from zarr.storage import BaseStore, default_compressor
|
48 | 50 | from zarr.sync import Synchronizer
|
49 | 51 |
|
50 | 52 | _zarr_available = True
|
51 | 53 | except ImportError:
|
| 54 | + from typing import TYPE_CHECKING, TypeVar |
| 55 | + |
| 56 | + if not TYPE_CHECKING: |
| 57 | + Codec = TypeVar("Codec") |
| 58 | + Group = TypeVar("Group") |
| 59 | + BaseStore = TypeVar("BaseStore") |
| 60 | + Synchronizer = TypeVar("Synchronizer") |
52 | 61 | _zarr_available = False
|
53 | 62 |
|
54 | 63 |
|
@@ -243,7 +252,7 @@ def flush(self):
|
243 | 252 |
|
244 | 253 | def get_initial_fill_value_and_codec(
|
245 | 254 | dtype: Any,
|
246 |
| -) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, numcodecs.abc.Codec | None]: |
| 255 | +) -> tuple[FILL_VALUE_TYPE, np.typing.DTypeLike, Codec | None]: |
247 | 256 | _dtype = np.dtype(dtype)
|
248 | 257 | fill_value: FILL_VALUE_TYPE = None
|
249 | 258 | codec = None
|
@@ -366,27 +375,27 @@ def groups(self) -> list[str]:
|
366 | 375 | return [str(group_name) for group_name, _ in self.root.groups()]
|
367 | 376 |
|
368 | 377 | @property
|
369 |
| - def posterior(self) -> zarr.Group: |
| 378 | + def posterior(self) -> Group: |
370 | 379 | return self.root.posterior
|
371 | 380 |
|
372 | 381 | @property
|
373 |
| - def unconstrained_posterior(self) -> zarr.Group: |
| 382 | + def unconstrained_posterior(self) -> Group: |
374 | 383 | return self.root.unconstrained_posterior
|
375 | 384 |
|
376 | 385 | @property
|
377 |
| - def sample_stats(self) -> zarr.Group: |
| 386 | + def sample_stats(self) -> Group: |
378 | 387 | return self.root.sample_stats
|
379 | 388 |
|
380 | 389 | @property
|
381 |
| - def constant_data(self) -> zarr.Group: |
| 390 | + def constant_data(self) -> Group: |
382 | 391 | return self.root.constant_data
|
383 | 392 |
|
384 | 393 | @property
|
385 |
| - def observed_data(self) -> zarr.Group: |
| 394 | + def observed_data(self) -> Group: |
386 | 395 | return self.root.observed_data
|
387 | 396 |
|
388 | 397 | @property
|
389 |
| - def _sampling_state(self) -> zarr.Group: |
| 398 | + def _sampling_state(self) -> Group: |
390 | 399 | return self.root._sampling_state
|
391 | 400 |
|
392 | 401 | def init_trace(
|
@@ -646,12 +655,12 @@ def init_sampling_state_group(self, tune: int, chains: int):
|
646 | 655 |
|
647 | 656 | def init_group_with_empty(
|
648 | 657 | self,
|
649 |
| - group: zarr.Group, |
| 658 | + group: Group, |
650 | 659 | var_dtype_and_shape: dict[str, tuple[StatDtype, StatShape]],
|
651 | 660 | chains: int,
|
652 | 661 | draws: int,
|
653 | 662 | extra_var_attrs: dict | None = None,
|
654 |
| - ) -> zarr.Group: |
| 663 | + ) -> Group: |
655 | 664 | group_coords: dict[str, Any] = {"chain": range(chains), "draw": range(draws)}
|
656 | 665 | for name, (_dtype, shape) in var_dtype_and_shape.items():
|
657 | 666 | fill_value, dtype, object_codec = get_initial_fill_value_and_codec(_dtype)
|
@@ -689,8 +698,8 @@ def init_group_with_empty(
|
689 | 698 | array.attrs.update({"_ARRAY_DIMENSIONS": [dim]})
|
690 | 699 | return group
|
691 | 700 |
|
692 |
| - def create_group(self, name: str, data_dict: dict[str, np.ndarray]) -> zarr.Group | None: |
693 |
| - group: zarr.Group | None = None |
| 701 | + def create_group(self, name: str, data_dict: dict[str, np.ndarray]) -> Group | None: |
| 702 | + group: Group | None = None |
694 | 703 | if data_dict:
|
695 | 704 | group_coords = {}
|
696 | 705 | group = self.root.create_group(name=name, overwrite=True)
|
|
0 commit comments