Skip to content

Commit 3194534

Browse files
authored
Fix create_dataset with data kwarg (#2638)
* Add failing test for #2631 * Fix create_dataset with data argument
1 parent 584d66d commit 3194534

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

src/zarr/core/group.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1165,7 +1165,16 @@ async def create_dataset(
11651165
.. deprecated:: 3.0.0
11661166
The h5py compatibility methods will be removed in 3.1.0. Use `AsyncGroup.create_array` instead.
11671167
"""
1168-
return await self.create_array(name, shape=shape, **kwargs)
1168+
data = kwargs.pop("data", None)
1169+
# create_dataset in zarr 2.x requires shape but not dtype if data is
1170+
# provided. Allow this configuration by inferring dtype from data if
1171+
# necessary and passing it to create_array
1172+
if "dtype" not in kwargs and data is not None:
1173+
kwargs["dtype"] = data.dtype
1174+
array = await self.create_array(name, shape=shape, **kwargs)
1175+
if data is not None:
1176+
await array.setitem(slice(None), data)
1177+
return array
11691178

11701179
@deprecated("Use AsyncGroup.require_array instead.")
11711180
async def require_dataset(

tests/test_group.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1137,6 +1137,18 @@ async def test_require_groups(store: LocalStore | MemoryStore, zarr_format: Zarr
11371137
assert no_group == ()
11381138

11391139

1140+
def test_create_dataset_with_data(store: Store, zarr_format: ZarrFormat) -> None:
1141+
"""Check that deprecated create_dataset method allows input data.
1142+
1143+
See https://github.com/zarr-developers/zarr-python/issues/2631.
1144+
"""
1145+
root = Group.from_store(store=store, zarr_format=zarr_format)
1146+
arr = np.random.random((5, 5))
1147+
with pytest.warns(DeprecationWarning):
1148+
data = root.create_dataset("random", data=arr, shape=arr.shape)
1149+
np.testing.assert_array_equal(np.asarray(data), arr)
1150+
1151+
11401152
async def test_create_dataset(store: Store, zarr_format: ZarrFormat) -> None:
11411153
root = await AsyncGroup.from_store(store=store, zarr_format=zarr_format)
11421154
with pytest.warns(DeprecationWarning):

0 commit comments

Comments
 (0)