179 changes: 163 additions & 16 deletions icechunk-python/python/icechunk/xarray.py
@@ -9,9 +9,10 @@
import xarray as xr
import zarr
from icechunk import IcechunkStore, Session
from icechunk.distributed import merge_sessions
from icechunk.session import ForkSession
from icechunk.vendor.xarray import _choose_default_mode
from xarray import DataArray, Dataset
from xarray import DataArray, Dataset, DataTree
from xarray.backends.common import ArrayWriter
from xarray.backends.zarr import ZarrStore

@@ -23,8 +24,13 @@

try:
has_dask = importlib.util.find_spec("dask") is not None
if has_dask:
from dask.highlevelgraph import HighLevelGraph
else:
HighLevelGraph = None
except ImportError:
has_dask = False
HighLevelGraph = None

if Version(xr.__version__) < Version("2024.10.0"):
raise ValueError(
@@ -44,11 +50,39 @@ def is_dask_collection(x: Any) -> bool:
if has_dask:
import dask

return dask.base.is_dask_collection(x)
if isinstance(x, DataTree):
return bool(datatree_dask_graph(x))
else:
return dask.base.is_dask_collection(x)
else:
return False


def datatree_dask_graph(dt: DataTree) -> "HighLevelGraph | None": # type: ignore[name-defined]
# copied from `Dataset.__dask_graph__()`.
# Should ideally be upstreamed into xarray as part of making DataTree a true dask collection - see https://github.com/pydata/xarray/issues/9355.

all_variables = {
f"{path}/{var_name}" if path != "." else var_name: variable
for path, node in dt.subtree_with_keys
for var_name, variable in node.variables.items()
}

graphs = {k: v.__dask_graph__() for k, v in all_variables.items()}
graphs = {k: v for k, v in graphs.items() if v is not None}
if not graphs:
return None
else:
try:
from dask.highlevelgraph import HighLevelGraph

return HighLevelGraph.merge(*graphs.values())
except ImportError:
from dask import sharedict

return sharedict.merge(*graphs.values())
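
# Usage sketch (editorial example, not part of this diff; assumes dask is
# installed). A non-None return means at least one variable anywhere in the
# tree is dask-backed, which is how `is_dask_collection` decides to route a
# DataTree through the lazy write path:
#
#     import dask.array as da
#     import xarray as xr
#
#     tree = xr.DataTree.from_dict(
#         {
#             "/": xr.Dataset({"eager": ("x", [1, 2, 3])}),
#             "/child": xr.Dataset({"lazy": ("x", da.zeros(3, chunks=2))}),
#         }
#     )
#     assert datatree_dask_graph(tree) is not None
#     assert is_dask_collection(tree)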


class LazyArrayWriter(ArrayWriter):
def __init__(self) -> None:
super().__init__() # type: ignore[no-untyped-call]
@@ -188,6 +222,48 @@ def write_lazy(
return session_merge_reduction(stored_arrays, split_every=split_every)


def write_ds(
ds,
store,
safe_chunks,
group,
mode,
append_dim,
region,
encoding,
chunkmanager_store_kwargs,
) -> ForkSession | None:
writer = _XarrayDatasetWriter(ds, store=store, safe_chunks=safe_chunks)
writer._open_group(group=group, mode=mode, append_dim=append_dim, region=region)

# write metadata
writer.write_metadata(encoding)
# write in-memory arrays
writer.write_eager()
# eagerly write dask arrays
maybe_fork_session = writer.write_lazy(
chunkmanager_store_kwargs=chunkmanager_store_kwargs
)

return maybe_fork_session


# overload because several kwargs are currently forbidden for DataTree, and ``write_inherited_coords`` only applies to DataTree
@overload
def to_icechunk(
obj: DataTree,
session: Session,
*,
mode: ZarrWriteModes | None = None,
safe_chunks: bool = True,
encoding: Mapping[Any, Any] | None = None,
chunkmanager_store_kwargs: MutableMapping[Any, Any] | None = None,
write_inherited_coords: bool = False,
split_every: int | None = None,
) -> None: ...


@overload
def to_icechunk(
obj: DataArray | Dataset,
session: Session,
@@ -200,14 +276,32 @@ def to_icechunk(
encoding: Mapping[Any, Any] | None = None,
chunkmanager_store_kwargs: MutableMapping[Any, Any] | None = None,
split_every: int | None = None,
) -> None: ...


def to_icechunk(
obj: DataArray | Dataset | DataTree,
session: Session,
*,
group: str | None = None,
mode: ZarrWriteModes | None = None,
safe_chunks: bool = True,
append_dim: Hashable | None = None,
region: Region = None,
encoding: Mapping[Any, Any] | None = None,
chunkmanager_store_kwargs: MutableMapping[Any, Any] | None = None,
write_inherited_coords: bool = False,
split_every: int | None = None,
) -> None:
"""
Write an Xarray object to a group of an Icechunk store.

Parameters
----------
obj: DataArray or Dataset
Xarray object to write
obj: DataArray, Dataset, or DataTree
Xarray object to write.

Note: When passing a DataTree, the ``append_dim``, ``region``, and ``group`` parameters are not yet supported.
session : icechunk.Session
Writable Icechunk Session
mode : {"w", "w-", "a", "a-", "r+", None}, optional
@@ -265,6 +359,11 @@ def to_icechunk(
Additional keyword arguments passed on to the `ChunkManager.store` method used to store
chunked arrays. For example for a dask array additional kwargs will be passed eventually to
`dask.array.store()`. Experimental API that should not be relied upon.
write_inherited_coords : bool, default: False
If true, replicate inherited coordinates on all descendant nodes.
Otherwise, only write coordinates at the level at which they are
originally defined. This saves disk space, but requires opening the
full tree to load inherited coordinates.
split_every: int, optional
Number of tasks to merge at every level of the tree reduction.

@@ -283,8 +382,20 @@ def to_icechunk(
``append_dim`` at the same time. To create empty arrays to fill
in with ``region``, use the `_XarrayDatasetWriter` directly.
"""

as_dataset = _make_dataset(obj)
# Validate parameters for DataTree
if isinstance(obj, DataTree):
if group is not None:
raise NotImplementedError(
"specifying a root group for the tree has not been implemented"
)
if append_dim is not None:
raise NotImplementedError(
"The 'append_dim' parameter is not yet supported when writing DataTree objects."
)
if region is not None:
raise NotImplementedError(
"The 'region' parameter is not yet supported when writing DataTree objects."
)

# This ugliness is needed so that we allow users to call `to_icechunk` with a dirty Session
# for _serial_ writes
@@ -299,18 +410,54 @@
else:
fork = session

writer = _XarrayDatasetWriter(as_dataset, store=fork.store, safe_chunks=safe_chunks)
if isinstance(obj, DataTree):
dt = obj

writer._open_group(group=group, mode=mode, append_dim=append_dim, region=region)
if encoding is None:
encoding = {}
if set(encoding) - set(dt.groups):
raise ValueError(
f"unexpected encoding group name(s) provided: {set(encoding) - set(dt.groups)}"
)

maybe_forked_sessions: list[ForkSession | None] = []
for rel_path, node in dt.subtree_with_keys:
at_root = node is dt
dataset = node.to_dataset(inherit=write_inherited_coords or at_root)

maybe_fork_session = write_ds(
ds=dataset,
store=fork.store,
[Review comment — Contributor]: Note for the future: this should be safe since each iteration of the loop writes to a different group, so there are no conflicts.
[Reply — Contributor, Author]: Added this as a comment in 7b0fec5
safe_chunks=safe_chunks,
group=dt[rel_path].path,
mode=mode,
append_dim=append_dim,
region=region,
encoding=encoding,
chunkmanager_store_kwargs=chunkmanager_store_kwargs,
)
maybe_forked_sessions.append(maybe_fork_session)

if any(maybe_forked_sessions) and is_dask:
# Note: This should be safe since each iteration of the loop writes to a different group, so there are no conflicts.
maybe_fork_session = merge_sessions(maybe_forked_sessions)
else:
maybe_fork_session = None

else:
as_dataset = _make_dataset(obj)
maybe_fork_session = write_ds(
ds=as_dataset,
store=fork.store,
safe_chunks=safe_chunks,
group=group,
mode=mode,
append_dim=append_dim,
region=region,
encoding=encoding,
chunkmanager_store_kwargs=chunkmanager_store_kwargs,
)

# write metadata
writer.write_metadata(encoding)
# write in-memory arrays
writer.write_eager()
# eagerly write dask arrays
maybe_fork_session = writer.write_lazy(
chunkmanager_store_kwargs=chunkmanager_store_kwargs
)
if is_dask:
if maybe_fork_session is None:
raise RuntimeError(
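For orientation, here is a minimal end-to-end sketch of the new DataTree code path (an editorial example, not part of the diff; the tree layout is illustrative, and the `Repository`/`in_memory_storage` helpers are the same ones used in the tests below):

import xarray as xr
from icechunk import Repository, in_memory_storage
from icechunk.xarray import to_icechunk

dt = xr.DataTree.from_dict(
    {
        "/": xr.Dataset({"bar": ("x", [1, 2])}),
        "/a": xr.Dataset({"foo": ("x", [3, 4])}),
    }
)

repo = Repository.create(in_memory_storage())
session = repo.writable_session("main")
# `encoding`, if given, must be keyed by group path (it is validated against
# `set(dt.groups)` above); `group`, `append_dim`, and `region` raise
# NotImplementedError when the object is a DataTree.
to_icechunk(dt, session=session, mode="w")
session.commit("write datatree")

roundtripped = xr.open_datatree(session.store, engine="zarr", consolidated=False)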
63 changes: 62 additions & 1 deletion icechunk-python/tests/test_xarray.py
@@ -49,6 +49,48 @@ def create_test_data(
return obj


def create_test_datatree() -> xr.DataTree:
return xr.DataTree.from_dict(
{
"/": xr.Dataset(
data_vars={
"bar": ("x", ["hello", "world"]),
},
coords={
"x": (
"x",
[1, 2],
), # inherited dimension coordinate that can't be overridden
"w": (
"x",
[0.1, 0.2],
), # inherited non-dimension coordinate to override
},
),
"/a": xr.Dataset(
data_vars={
"foo": ("x", ["alpha", "beta"]),
},
coords={
"w": ("x", [10, 20]), # override inherited non-dimension coordinate
"z": ("z", ["alpha", "beta"]), # non-inherited dimension coordinate
},
),
"/b": xr.Dataset(
data_vars={
"foo": ("x", ["gamma", "delta"]),
},
coords={
"z": (
"z",
["alpha", "beta", "gamma"],
), # override inherited non-dimension coordinate with different length (i.e. multi-resolution)
},
),
}
)


@contextlib.contextmanager
def roundtrip(
data: xr.Dataset, *, commit: bool = False
@@ -62,12 +104,31 @@ def roundtrip(
yield ds


def test_xarray_to_icechunk() -> None:
def test_xarray_dataset_to_icechunk() -> None:
ds = create_test_data()
with roundtrip(ds) as actual:
assert_identical(actual, ds)


@contextlib.contextmanager
def roundtrip_datatree(
dt: xr.DataTree, *, commit: bool = False
) -> Generator[xr.DataTree, None, None]:
with tempfile.TemporaryDirectory() as tmpdir:
repo = Repository.create(local_filesystem_storage(tmpdir))
session = repo.writable_session("main")
to_icechunk(dt, session=session, mode="w")
session.commit("write")
with xr.open_datatree(session.store, consolidated=False, engine="zarr") as dt:
yield dt


def test_xarray_datatree_to_icechunk() -> None:
dt = create_test_datatree()
with roundtrip_datatree(dt) as actual:
assert_identical(actual, dt)


def test_repeated_to_icechunk_serial() -> None:
ds = create_test_data()
repo = Repository.create(in_memory_storage())
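One natural follow-up, sketched here as an editorial example (the test name and assertion are illustrative, not part of this diff): with `write_inherited_coords=True`, each child group gets its own copy of the inherited coordinates, so it can be opened standalone:

def test_write_inherited_coords_standalone_group() -> None:
    dt = create_test_datatree()
    repo = Repository.create(in_memory_storage())
    session = repo.writable_session("main")
    to_icechunk(dt, session=session, mode="w", write_inherited_coords=True)
    session.commit("write")
    # The inherited dimension coordinate "x" was replicated into /a, so the
    # child group is self-describing when opened on its own.
    child = xr.open_dataset(
        session.store, group="/a", engine="zarr", consolidated=False
    )
    assert "x" in child.coords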