From e8b5bba0b356523a35c02650c9de4a97e1ee5a24 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Mon, 22 Sep 2025 12:59:34 -0400 Subject: [PATCH 1/8] typing overloads --- icechunk-python/python/icechunk/xarray.py | 64 +++++++++++++++++++++-- 1 file changed, 60 insertions(+), 4 deletions(-) diff --git a/icechunk-python/python/icechunk/xarray.py b/icechunk-python/python/icechunk/xarray.py index 26a2df5ff..93eb27c22 100644 --- a/icechunk-python/python/icechunk/xarray.py +++ b/icechunk-python/python/icechunk/xarray.py @@ -11,7 +11,7 @@ from icechunk import IcechunkStore, Session from icechunk.session import ForkSession from icechunk.vendor.xarray import _choose_default_mode -from xarray import DataArray, Dataset +from xarray import DataArray, Dataset, DataTree from xarray.backends.common import ArrayWriter from xarray.backends.zarr import ZarrStore @@ -181,6 +181,22 @@ def write_lazy( return session_merge_reduction(stored_arrays, split_every=split_every) +# Overload for DataTree - restricted parameters +@overload +def to_icechunk( + obj: DataTree, + session: Session, + *, + mode: ZarrWriteModes | None = None, + safe_chunks: bool = True, + encoding: Mapping[Any, Any] | None = None, + chunkmanager_store_kwargs: MutableMapping[Any, Any] | None = None, + split_every: int | None = None, +) -> None: ... + + +# Overload for DataArray/Dataset - full parameters +@overload def to_icechunk( obj: DataArray | Dataset, session: Session, @@ -193,14 +209,31 @@ def to_icechunk( encoding: Mapping[Any, Any] | None = None, chunkmanager_store_kwargs: MutableMapping[Any, Any] | None = None, split_every: int | None = None, +) -> None: ... + + +def to_icechunk( + obj: DataArray | Dataset | DataTree, + session: Session, + *, + group: str | None = None, + mode: ZarrWriteModes | None = None, + safe_chunks: bool = True, + append_dim: Hashable | None = None, + region: Region = None, + encoding: Mapping[Any, Any] | None = None, + chunkmanager_store_kwargs: MutableMapping[Any, Any] | None = None, + split_every: int | None = None, ) -> None: """ Write an Xarray object to a group of an Icechunk store. Parameters ---------- - obj: DataArray or Dataset - Xarray object to write + obj: DataArray, Dataset, or DataTree + Xarray object to write. + + Note: When passing a DataTree, the ``append_dim``, ``region``, and ``group`` parameters are not yet supported. session : icechunk.Session Writable Icechunk Session mode : {"w", "w-", "a", "a-", r+", None}, optional @@ -276,6 +309,20 @@ def to_icechunk( ``append_dim`` at the same time. To create empty arrays to fill in with ``region``, use the `_XarrayDatasetWriter` directly. """ + # Validate parameters for DataTree + if isinstance(obj, DataTree): + if group is not None: + raise NotImplementedError( + "specifying a root group for the tree has not been implemented" + ) + if append_dim is not None: + raise NotImplementedError( + "The 'append_dim' parameter is not yet supported when writing DataTree objects." + ) + if region is not None: + raise NotImplementedError( + "The 'region' parameter is not yet supported when writing DataTree objects." + ) as_dataset = _make_dataset(obj) @@ -321,7 +368,9 @@ def to_icechunk( def _make_dataset(obj: DataArray) -> Dataset: ... @overload def _make_dataset(obj: Dataset) -> Dataset: ... -def _make_dataset(obj: DataArray | Dataset) -> Dataset: +@overload +def _make_dataset(obj: "DataTree") -> Dataset: ... +def _make_dataset(obj: DataArray | Dataset | "DataTree") -> Dataset: """Copied from DataArray.to_zarr""" DATAARRAY_NAME = "__xarray_dataarray_name__" DATAARRAY_VARIABLE = "__xarray_dataarray_variable__" @@ -329,6 +378,13 @@ def _make_dataset(obj: DataArray | Dataset) -> Dataset: if isinstance(obj, Dataset): return obj + if DataTree is not None and isinstance(obj, DataTree): + # For DataTree, we currently only support the root dataset + # Implementation will be provided later + raise NotImplementedError( + "DataTree support is not yet implemented. Please provide the implementation." + ) + assert isinstance(obj, DataArray) if obj.name is None: From 29e76829dd731d7bb964e23ba569c0aa341dfdba Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Mon, 22 Sep 2025 18:01:01 -0400 Subject: [PATCH 2/8] test writing datatree --- icechunk-python/tests/test_xarray.py | 56 +++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/icechunk-python/tests/test_xarray.py b/icechunk-python/tests/test_xarray.py index 54f236eb2..e3263da3b 100644 --- a/icechunk-python/tests/test_xarray.py +++ b/icechunk-python/tests/test_xarray.py @@ -49,6 +49,39 @@ def create_test_data( return obj +def create_test_datatree() -> xr.DataTree: + return xr.DataTree.from_dict( + { + "/": xr.Dataset( + data_vars={ + "bar": ("x", ["hello", "world"]), + }, + coords={ + "x": ("x", [1, 2]), # inherited dimension coordinate that can't be overriden + "w": ("x", [0.1, 0.2]), # inherited non-dimension coordinate to override + }, + ), + "/a": xr.Dataset( + data_vars={ + "foo": ("x", ["alpha", "beta"]), + }, + coords={ + "w": ("x", [10, 20]), # override inherited non-dimension coordinate + "z": ("z", ["alpha", "beta"]), # non-inherited dimension coordinate + }, + ), + "/b": xr.Dataset( + data_vars={ + "foo": ("x", ["gamma", "delta"]), + }, + coords={ + "z": ("z", ["alpha", "beta", "gamma"]), # override inherited non-dimension coordinate with different length (i.e. multi-resolution) + }, + ), + } + ) + + @contextlib.contextmanager def roundtrip( data: xr.Dataset, *, commit: bool = False @@ -62,12 +95,33 @@ def roundtrip( yield ds -def test_xarray_to_icechunk() -> None: +def test_xarray_dataset_to_icechunk() -> None: ds = create_test_data() with roundtrip(ds) as actual: assert_identical(actual, ds) +@contextlib.contextmanager +def roundtrip_datatree( + dt: xr.DataTree, *, commit: bool = False +) -> Generator[xr.DataTree, None, None]: + with tempfile.TemporaryDirectory() as tmpdir: + repo = Repository.create(local_filesystem_storage(tmpdir)) + session = repo.writable_session("main") + to_icechunk(dt, session=session, mode="w") + session.commit("write") + with xr.open_datatree(session.store, consolidated=False, engine="zarr") as dt: + yield dt + + +def test_xarray_datatree_to_icechunk() -> None: + dt = create_test_datatree() + with roundtrip_datatree(dt) as actual: + print(actual) + print(dt) + assert_identical(actual, dt) + + def test_repeated_to_icechunk_serial() -> None: ds = create_test_data() repo = Repository.create(in_memory_storage()) From 99a5b05b9d13af4df0178b2578124ebd8085c7a3 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Mon, 22 Sep 2025 18:01:15 -0400 Subject: [PATCH 3/8] basic implementation --- icechunk-python/python/icechunk/xarray.py | 69 +++++++++++++++-------- 1 file changed, 45 insertions(+), 24 deletions(-) diff --git a/icechunk-python/python/icechunk/xarray.py b/icechunk-python/python/icechunk/xarray.py index de28eb9a2..71f3129a6 100644 --- a/icechunk-python/python/icechunk/xarray.py +++ b/icechunk-python/python/icechunk/xarray.py @@ -185,7 +185,25 @@ def write_lazy( return session_merge_reduction(stored_arrays, split_every=split_every) -# Overload for DataTree - restricted parameters + +def write_ds(ds, store, safe_chunks, group, mode, append_dim, region, encoding, chunkmanager_store_kwargs): + writer = _XarrayDatasetWriter(ds, store=store, safe_chunks=safe_chunks) + writer._open_group(group=group, mode=mode, append_dim=append_dim, region=region) + + # write metadata + writer.write_metadata(encoding) + # write in-memory arrays + writer.write_eager() + # eagerly write dask arrays + maybe_fork_session = writer.write_lazy( + chunkmanager_store_kwargs=chunkmanager_store_kwargs + ) + + return maybe_fork_session + + + +# overload because several kwargs are currently forbidden for DataTree @overload def to_icechunk( obj: DataTree, @@ -199,7 +217,6 @@ def to_icechunk( ) -> None: ... -# Overload for DataArray/Dataset - full parameters @overload def to_icechunk( obj: DataArray | Dataset, @@ -328,10 +345,11 @@ def to_icechunk( "The 'region' parameter is not yet supported when writing DataTree objects." ) - as_dataset = _make_dataset(obj) - # This ugliness is needed so that we allow users to call `to_icechunk` with a dirty Session # for _serial_ writes + + # TODO DataTree does not implement `__dask_graph__`, unlike `Dataset`, so will this ever trigger? + # TODO this doesn't even trigger for the current Dataset tests, because they use non-dask data... is_dask = is_dask_collection(obj) fork: Session | ForkSession if is_dask: @@ -343,18 +361,30 @@ def to_icechunk( else: fork = session - writer = _XarrayDatasetWriter(as_dataset, store=fork.store, safe_chunks=safe_chunks) + if isinstance(obj, DataTree): + dt = obj + + if encoding is None: + encoding = {} + if set(encoding) - set(dt.groups): + raise ValueError( + f"unexpected encoding group name(s) provided: {set(encoding) - set(dt.groups)}" + ) + + # TODO expose this + write_inherited_coords = False + + for rel_path, node in dt.subtree_with_keys: + at_root = node is dt + dataset = node.to_dataset(inherit=write_inherited_coords or at_root) - writer._open_group(group=group, mode=mode, append_dim=append_dim, region=region) + # TODO what do I do with all these maybe_fork_sessions here? + maybe_fork_session = write_ds(ds=dataset, store=fork.store, safe_chunks=safe_chunks, group=dt[rel_path].path, mode=mode, append_dim=append_dim, region=region, encoding=encoding, chunkmanager_store_kwargs=chunkmanager_store_kwargs) - # write metadata - writer.write_metadata(encoding) - # write in-memory arrays - writer.write_eager() - # eagerly write dask arrays - maybe_fork_session = writer.write_lazy( - chunkmanager_store_kwargs=chunkmanager_store_kwargs - ) + else: + as_dataset = _make_dataset(obj) + maybe_fork_session = write_ds(ds=as_dataset, store=fork.store, safe_chunks=safe_chunks, group=group, mode=mode, append_dim=append_dim, region=region, encoding=encoding, chunkmanager_store_kwargs=chunkmanager_store_kwargs) + if is_dask: if maybe_fork_session is None: raise RuntimeError( @@ -372,9 +402,7 @@ def to_icechunk( def _make_dataset(obj: DataArray) -> Dataset: ... @overload def _make_dataset(obj: Dataset) -> Dataset: ... -@overload -def _make_dataset(obj: "DataTree") -> Dataset: ... -def _make_dataset(obj: DataArray | Dataset | "DataTree") -> Dataset: +def _make_dataset(obj: DataArray | Dataset) -> Dataset: """Copied from DataArray.to_zarr""" DATAARRAY_NAME = "__xarray_dataarray_name__" DATAARRAY_VARIABLE = "__xarray_dataarray_variable__" @@ -382,13 +410,6 @@ def _make_dataset(obj: DataArray | Dataset | "DataTree") -> Dataset: if isinstance(obj, Dataset): return obj - if DataTree is not None and isinstance(obj, DataTree): - # For DataTree, we currently only support the root dataset - # Implementation will be provided later - raise NotImplementedError( - "DataTree support is not yet implemented. Please provide the implementation." - ) - assert isinstance(obj, DataArray) if obj.name is None: From 95ad390d5835cf3616b82519344b26d914625627 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Mon, 22 Sep 2025 18:04:23 -0400 Subject: [PATCH 4/8] expose write_inherited_coords --- icechunk-python/python/icechunk/xarray.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/icechunk-python/python/icechunk/xarray.py b/icechunk-python/python/icechunk/xarray.py index 71f3129a6..b77d53d7c 100644 --- a/icechunk-python/python/icechunk/xarray.py +++ b/icechunk-python/python/icechunk/xarray.py @@ -203,7 +203,7 @@ def write_ds(ds, store, safe_chunks, group, mode, append_dim, region, encoding, -# overload because several kwargs are currently forbidden for DataTree +# overload because several kwargs are currently forbidden for DataTree, and ``write_inherited_coords`` only applies to DataTree @overload def to_icechunk( obj: DataTree, @@ -213,6 +213,7 @@ def to_icechunk( safe_chunks: bool = True, encoding: Mapping[Any, Any] | None = None, chunkmanager_store_kwargs: MutableMapping[Any, Any] | None = None, + write_inherited_coords: bool = False, split_every: int | None = None, ) -> None: ... @@ -230,6 +231,7 @@ def to_icechunk( encoding: Mapping[Any, Any] | None = None, chunkmanager_store_kwargs: MutableMapping[Any, Any] | None = None, split_every: int | None = None, + ) -> None: ... @@ -244,6 +246,7 @@ def to_icechunk( region: Region = None, encoding: Mapping[Any, Any] | None = None, chunkmanager_store_kwargs: MutableMapping[Any, Any] | None = None, + write_inherited_coords: bool = False, split_every: int | None = None, ) -> None: """ @@ -312,6 +315,11 @@ def to_icechunk( Additional keyword arguments passed on to the `ChunkManager.store` method used to store chunked arrays. For example for a dask array additional kwargs will be passed eventually to `dask.array.store()`. Experimental API that should not be relied upon. + write_inherited_coords : bool, default: False + If true, replicate inherited coordinates on all descendant nodes. + Otherwise, only write coordinates at the level at which they are + originally defined. This saves disk space, but requires opening the + full tree to load inherited coordinates. split_every: int, optional Number of tasks to merge at every level of the tree reduction. @@ -349,7 +357,6 @@ def to_icechunk( # for _serial_ writes # TODO DataTree does not implement `__dask_graph__`, unlike `Dataset`, so will this ever trigger? - # TODO this doesn't even trigger for the current Dataset tests, because they use non-dask data... is_dask = is_dask_collection(obj) fork: Session | ForkSession if is_dask: @@ -370,9 +377,6 @@ def to_icechunk( raise ValueError( f"unexpected encoding group name(s) provided: {set(encoding) - set(dt.groups)}" ) - - # TODO expose this - write_inherited_coords = False for rel_path, node in dt.subtree_with_keys: at_root = node is dt From 70adcc0d2638be859083b7a943ba7ac392e5f4d7 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Mon, 22 Sep 2025 18:05:29 -0400 Subject: [PATCH 5/8] lint --- icechunk-python/python/icechunk/xarray.py | 52 ++++++++++++++++++----- icechunk-python/tests/test_xarray.py | 17 +++++--- 2 files changed, 53 insertions(+), 16 deletions(-) diff --git a/icechunk-python/python/icechunk/xarray.py b/icechunk-python/python/icechunk/xarray.py index b77d53d7c..effe04ce9 100644 --- a/icechunk-python/python/icechunk/xarray.py +++ b/icechunk-python/python/icechunk/xarray.py @@ -32,7 +32,10 @@ ) if Version(xr.__version__) > Version("2025.09.0"): - from xarray.backends.writers import _validate_dataset_names, dump_to_store # type: ignore[import-not-found] + from xarray.backends.writers import ( # type: ignore[import-not-found] + _validate_dataset_names, + dump_to_store, + ) else: from xarray.backends.api import _validate_dataset_names, dump_to_store @@ -185,8 +188,17 @@ def write_lazy( return session_merge_reduction(stored_arrays, split_every=split_every) - -def write_ds(ds, store, safe_chunks, group, mode, append_dim, region, encoding, chunkmanager_store_kwargs): +def write_ds( + ds, + store, + safe_chunks, + group, + mode, + append_dim, + region, + encoding, + chunkmanager_store_kwargs, +): writer = _XarrayDatasetWriter(ds, store=store, safe_chunks=safe_chunks) writer._open_group(group=group, mode=mode, append_dim=append_dim, region=region) @@ -202,7 +214,6 @@ def write_ds(ds, store, safe_chunks, group, mode, append_dim, region, encoding, return maybe_fork_session - # overload because several kwargs are currently forbidden for DataTree, and ``write_inherited_coords`` only applies to DataTree @overload def to_icechunk( @@ -231,7 +242,6 @@ def to_icechunk( encoding: Mapping[Any, Any] | None = None, chunkmanager_store_kwargs: MutableMapping[Any, Any] | None = None, split_every: int | None = None, - ) -> None: ... @@ -255,8 +265,8 @@ def to_icechunk( Parameters ---------- obj: DataArray, Dataset, or DataTree - Xarray object to write. - + Xarray object to write. + Note: When passing a DataTree, the ``append_dim``, ``region``, and ``group`` parameters are not yet supported. session : icechunk.Session Writable Icechunk Session @@ -370,7 +380,7 @@ def to_icechunk( if isinstance(obj, DataTree): dt = obj - + if encoding is None: encoding = {} if set(encoding) - set(dt.groups): @@ -383,12 +393,32 @@ def to_icechunk( dataset = node.to_dataset(inherit=write_inherited_coords or at_root) # TODO what do I do with all these maybe_fork_sessions here? - maybe_fork_session = write_ds(ds=dataset, store=fork.store, safe_chunks=safe_chunks, group=dt[rel_path].path, mode=mode, append_dim=append_dim, region=region, encoding=encoding, chunkmanager_store_kwargs=chunkmanager_store_kwargs) + maybe_fork_session = write_ds( + ds=dataset, + store=fork.store, + safe_chunks=safe_chunks, + group=dt[rel_path].path, + mode=mode, + append_dim=append_dim, + region=region, + encoding=encoding, + chunkmanager_store_kwargs=chunkmanager_store_kwargs, + ) else: as_dataset = _make_dataset(obj) - maybe_fork_session = write_ds(ds=as_dataset, store=fork.store, safe_chunks=safe_chunks, group=group, mode=mode, append_dim=append_dim, region=region, encoding=encoding, chunkmanager_store_kwargs=chunkmanager_store_kwargs) - + maybe_fork_session = write_ds( + ds=as_dataset, + store=fork.store, + safe_chunks=safe_chunks, + group=group, + mode=mode, + append_dim=append_dim, + region=region, + encoding=encoding, + chunkmanager_store_kwargs=chunkmanager_store_kwargs, + ) + if is_dask: if maybe_fork_session is None: raise RuntimeError( diff --git a/icechunk-python/tests/test_xarray.py b/icechunk-python/tests/test_xarray.py index e3263da3b..b0cf7bb41 100644 --- a/icechunk-python/tests/test_xarray.py +++ b/icechunk-python/tests/test_xarray.py @@ -57,8 +57,14 @@ def create_test_datatree() -> xr.DataTree: "bar": ("x", ["hello", "world"]), }, coords={ - "x": ("x", [1, 2]), # inherited dimension coordinate that can't be overriden - "w": ("x", [0.1, 0.2]), # inherited non-dimension coordinate to override + "x": ( + "x", + [1, 2], + ), # inherited dimension coordinate that can't be overridden + "w": ( + "x", + [0.1, 0.2], + ), # inherited non-dimension coordinate to override }, ), "/a": xr.Dataset( @@ -75,7 +81,10 @@ def create_test_datatree() -> xr.DataTree: "foo": ("x", ["gamma", "delta"]), }, coords={ - "z": ("z", ["alpha", "beta", "gamma"]), # override inherited non-dimension coordinate with different length (i.e. multi-resolution) + "z": ( + "z", + ["alpha", "beta", "gamma"], + ), # override inherited non-dimension coordinate with different length (i.e. multi-resolution) }, ), } @@ -117,8 +126,6 @@ def roundtrip_datatree( def test_xarray_datatree_to_icechunk() -> None: dt = create_test_datatree() with roundtrip_datatree(dt) as actual: - print(actual) - print(dt) assert_identical(actual, dt) From 59de7f1af81495e550bad17ebc5d9e10671b8ff8 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Mon, 22 Sep 2025 22:09:08 -0400 Subject: [PATCH 6/8] merge fork sessions from writing to each group --- icechunk-python/python/icechunk/xarray.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/icechunk-python/python/icechunk/xarray.py b/icechunk-python/python/icechunk/xarray.py index effe04ce9..3a12dcbbf 100644 --- a/icechunk-python/python/icechunk/xarray.py +++ b/icechunk-python/python/icechunk/xarray.py @@ -9,6 +9,7 @@ import xarray as xr import zarr from icechunk import IcechunkStore, Session +from icechunk.distributed import merge_sessions from icechunk.session import ForkSession from icechunk.vendor.xarray import _choose_default_mode from xarray import DataArray, Dataset, DataTree @@ -198,7 +199,7 @@ def write_ds( region, encoding, chunkmanager_store_kwargs, -): +) -> ForkSession | None: writer = _XarrayDatasetWriter(ds, store=store, safe_chunks=safe_chunks) writer._open_group(group=group, mode=mode, append_dim=append_dim, region=region) @@ -388,6 +389,7 @@ def to_icechunk( f"unexpected encoding group name(s) provided: {set(encoding) - set(dt.groups)}" ) + maybe_forked_sessions: list[ForkSession | None] = [] for rel_path, node in dt.subtree_with_keys: at_root = node is dt dataset = node.to_dataset(inherit=write_inherited_coords or at_root) @@ -404,6 +406,14 @@ def to_icechunk( encoding=encoding, chunkmanager_store_kwargs=chunkmanager_store_kwargs, ) + maybe_forked_sessions.append(maybe_fork_session) + + # TODO assumes bool(ForkSession) evaluates to True + # TODO add is_dask check here, once its actually supported + if any(maybe_forked_sessions): + maybe_fork_session = merge_sessions(maybe_forked_sessions) + else: + maybe_fork_session = None else: as_dataset = _make_dataset(obj) From 7b0fec54becbb38ee7843c9502b58c2bfc44e9b0 Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Mon, 22 Sep 2025 22:15:03 -0400 Subject: [PATCH 7/8] merge the forked sessions for each group --- icechunk-python/python/icechunk/xarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/icechunk-python/python/icechunk/xarray.py b/icechunk-python/python/icechunk/xarray.py index 3a12dcbbf..64d909bfb 100644 --- a/icechunk-python/python/icechunk/xarray.py +++ b/icechunk-python/python/icechunk/xarray.py @@ -394,7 +394,6 @@ def to_icechunk( at_root = node is dt dataset = node.to_dataset(inherit=write_inherited_coords or at_root) - # TODO what do I do with all these maybe_fork_sessions here? maybe_fork_session = write_ds( ds=dataset, store=fork.store, @@ -411,6 +410,7 @@ def to_icechunk( # TODO assumes bool(ForkSession) evaluates to True # TODO add is_dask check here, once its actually supported if any(maybe_forked_sessions): + # Note: This should be safe since each iteration of the loop writes to a different group, so there are no conflicts. maybe_fork_session = merge_sessions(maybe_forked_sessions) else: maybe_fork_session = None From 26e4c119ed0b92a576744f0485bf86a857fd46af Mon Sep 17 00:00:00 2001 From: TomNicholas Date: Fri, 26 Sep 2025 15:30:23 -0400 Subject: [PATCH 8/8] add detection of dask arrays inside DataTree objects --- icechunk-python/python/icechunk/xarray.py | 41 +++++++++++++++++++---- 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/icechunk-python/python/icechunk/xarray.py b/icechunk-python/python/icechunk/xarray.py index 64d909bfb..540d3cf22 100644 --- a/icechunk-python/python/icechunk/xarray.py +++ b/icechunk-python/python/icechunk/xarray.py @@ -24,8 +24,13 @@ try: has_dask = importlib.util.find_spec("dask") is not None + if has_dask: + from dask.highlevelgraph import HighLevelGraph + else: + HighLevelGraph = None except ImportError: has_dask = False + HighLevelGraph = None if Version(xr.__version__) < Version("2024.10.0"): raise ValueError( @@ -45,11 +50,39 @@ def is_dask_collection(x: Any) -> bool: if has_dask: import dask - return dask.base.is_dask_collection(x) + if isinstance(x, DataTree): + return bool(datatree_dask_graph(x)) + else: + return dask.base.is_dask_collection(x) else: return False +def datatree_dask_graph(dt: DataTree) -> "HighLevelGraph | None": # type: ignore[name-defined] + # copied from `Dataset.__dask_graph__()`. + # Should ideally be upstreamed into xarray as part of making DataTree a true dask collection - see https://github.com/pydata/xarray/issues/9355. + + all_variables = { + f"{path}/{var_name}" if path != "." else var_name: variable + for path, node in dt.subtree_with_keys + for var_name, variable in node.variables.items() + } + + graphs = {k: v.__dask_graph__() for k, v in all_variables.items()} + graphs = {k: v for k, v in graphs.items() if v is not None} + if not graphs: + return None + else: + try: + from dask.highlevelgraph import HighLevelGraph + + return HighLevelGraph.merge(*graphs.values()) + except ImportError: + from dask import sharedict + + return sharedict.merge(*graphs.values()) + + class LazyArrayWriter(ArrayWriter): def __init__(self) -> None: super().__init__() # type: ignore[no-untyped-call] @@ -366,8 +399,6 @@ def to_icechunk( # This ugliness is needed so that we allow users to call `to_icechunk` with a dirty Session # for _serial_ writes - - # TODO DataTree does not implement `__dask_graph__`, unlike `Dataset`, so will this ever trigger? is_dask = is_dask_collection(obj) fork: Session | ForkSession if is_dask: @@ -407,9 +438,7 @@ def to_icechunk( ) maybe_forked_sessions.append(maybe_fork_session) - # TODO assumes bool(ForkSession) evaluates to True - # TODO add is_dask check here, once its actually supported - if any(maybe_forked_sessions): + if any(maybe_forked_sessions) and is_dask: # Note: This should be safe since each iteration of the loop writes to a different group, so there are no conflicts. maybe_fork_session = merge_sessions(maybe_forked_sessions) else: