Skip to content

Commit a764787

Browse files
authored
Merge pull request #4 from andersy005/master
Add support for caching datasets as zarr stores
2 parents 1beadf3 + 682c23c commit a764787

File tree

5 files changed: +92 -54 lines changed

requirements-dev.txt

+1
Original file line number | Diff line number | Diff line change
@@ -3,5 +3,6 @@ xarray
33
dask
44
toolz
55
netCDF4
6+
zarr
67
pytest
78
pytest-cov

requirements.txt

+2
Original file line number | Diff line number | Diff line change
@@ -2,3 +2,5 @@ numpy
22
xarray
33
dask
44
toolz
5+
netCDF4
6+
zarr

tests/cached_data/test-dset.nc

0 Bytes
Binary file not shown.

tests/test_core.py

+35-25
Original file line number | Diff line number | Diff line change
@@ -1,19 +1,20 @@
11
import os
22
import shutil
33
from glob import glob
4+
from tempfile import TemporaryDirectory
5+
46
import numpy as np
7+
import pytest
58
import xarray as xr
69

710
import xpersist as xp
811

9-
import pytest
10-
1112
here = os.path.abspath(os.path.dirname(__file__))
1213
xp.settings['cache_dir'] = os.path.join(here, 'cached_data')
1314

1415

1516
def rm_tmpfile():
    """Delete temporary cache files from the test cache directory.

    Removes every file under ``<here>/cached_data`` whose name matches
    ``tmp-*.nc`` or ``PersistedDataset-*.nc``.
    """
    cache_dir = os.path.join(here, 'cached_data')
    for pattern in ('tmp-*.nc', 'PersistedDataset-*.nc'):
        for path in glob(os.path.join(cache_dir, pattern)):
            os.remove(path)
1920

@@ -24,26 +25,27 @@ def cleanup():
2425
yield
2526
rm_tmpfile()
2627

28+
2729
def func(scaleby):
    """Return a dataset with one variable ``x``: 50 ones multiplied by ``scaleby``."""
    values = np.ones((50,)) * scaleby
    return xr.Dataset({'x': xr.DataArray(values)})
2931

3032

3133
# must be first test
3234
def test_xpersist_actions():
    """Check the action recorded for a sequence of ``persist_ds`` calls.

    The steps are order-dependent: each call's expected action depends on
    the state left behind by the calls before it.
    """
    # (cache name, argument passed to ``func``, action expected afterwards)
    steps = [
        ('test-dset', 10, 'read_cache_trusted'),
        ('test-dset', 10, 'read_cache_verified'),
        ('test-dset', 11, 'overwrite_cache'),
        ('tmp-test-dset', 11, 'create_cache'),
    ]
    for name, scaleby, expected_action in steps:
        xp.persist_ds(func, name=name)(scaleby)
        _, action = xp.PersistedDataset._actions.popitem()
        assert action == expected_action
4850

4951

@@ -60,42 +62,50 @@ def test_make_cache_dir():
6062
shutil.rmtree(new)
6163
xp.settings['cache_dir'] = new
6264

63-
ds = xp.persist_ds(func, name='test-dset')(10)
65+
_ = xp.persist_ds(func, name='test-dset')(10)
6466

6567
assert os.path.exists(new)
6668

6769
shutil.rmtree(new)
6870
xp.settings['cache_dir'] = old
6971

7072

71-
7273
def test_xpersist_noname():
    """Calling ``persist_ds`` without a name is recorded as ``create_cache``."""
    xp.persist_ds(func)(10)
    _, action = xp.PersistedDataset._actions.popitem()
    assert action == 'create_cache'
7677

7778

7879
def test_clobber():
    """Check the actions recorded without and then with ``clobber=True``.

    The first call (argument 10) is expected to be ``read_cache_verified``;
    the second call, with ``clobber=True`` and argument 11, is expected to be
    ``overwrite_cache``.
    """
    xp.persist_ds(func, name='test-dset')(10)
    _, action = xp.PersistedDataset._actions.popitem()
    assert action == 'read_cache_verified'

    xp.persist_ds(func, name='test-dset', clobber=True)(11)
    _, action = xp.PersistedDataset._actions.popitem()
    assert action == 'overwrite_cache'
8687

8788

8889
def test_trusted():
    """Check the actions recorded without and then with ``trust_cache=True``.

    The first call (argument 10) is expected to be ``read_cache_verified``;
    the second call, with ``trust_cache=True`` and a different argument (11),
    is expected to be ``read_cache_trusted``.
    """
    xp.persist_ds(func, name='test-dset')(10)
    _, action = xp.PersistedDataset._actions.popitem()
    assert action == 'read_cache_verified'

    xp.persist_ds(func, name='test-dset', trust_cache=True)(11)
    _, action = xp.PersistedDataset._actions.popitem()
    assert action == 'read_cache_trusted'
9697

98+
9799
def test_validate_dset():
    """The dataset returned by ``persist_ds`` is identical to the cache file's contents."""
    dsp = xp.persist_ds(func, name='test-dset')(10)
    cache_file, _ = xp.PersistedDataset._actions.popitem()
    # open inside a context manager so the file handle is released afterwards
    with xr.open_dataset(cache_file) as ds:
        xr.testing.assert_identical(dsp, ds)
104+
105+
106+
def test_save_as_zarr():
    """A dataset persisted with ``format='zarr'`` can be read back with ``xr.open_zarr``."""
    with TemporaryDirectory() as local_store:
        dsp = xp.persist_ds(func, name='test-dset', path=local_store, format='zarr')(10)
        zarr_store, _ = xp.PersistedDataset._actions.popitem()
        # compare while the temporary directory still exists
        ds = xr.open_zarr(zarr_store, consolidated=True)
        xr.testing.assert_identical(dsp, ds)

xpersist/core.py

+54-29
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,19 @@
11
import os
2+
import shutil
23

3-
from toolz import curry
4-
5-
import xarray as xr
64
import dask
5+
import xarray as xr
6+
from toolz import curry
77

88
from . import settings
99

10-
__all__ = ["persisted_Dataset", "persist_ds"]
10+
__all__ = ['PersistedDataset', 'persist_ds']
1111

1212
_actions = {'read_cache_trusted', 'read_cache_verified', 'overwrite_cache', 'create_cache'}
13-
_formats = {'nc'}
13+
_formats = {'nc', 'zarr'}
1414

15-
class persisted_Dataset(object):
15+
16+
class PersistedDataset(object):
1617
"""
1718
Generate an `xarray.Dataset` from a function and cache the result to file.
1819
If the cache file exists, don't recompute, but read back in from file.
@@ -30,8 +31,16 @@ class persisted_Dataset(object):
3031
# class property
3132
_actions = {}
3233

33-
def __init__(self, func, name=None, path=None, trust_cache=False, clobber=False,
34-
format='nc', open_ds_kwargs={}):
34+
def __init__(
35+
self,
36+
func,
37+
name=None,
38+
path=None,
39+
trust_cache=False,
40+
clobber=False,
41+
format='nc',
42+
open_ds_kwargs={},
43+
):
3544
"""set instance attributes"""
3645
self._func = func
3746
self._name = name
@@ -52,34 +61,37 @@ def _check_token_assign_action(self, token):
5261

5362
# if we don't yet know about this file, assume it's the right one;
5463
# this enables usage on first call in a Python session, for instance
55-
known_cache = self._cache_file in persisted_Dataset._tokens
64+
known_cache = self._cache_file in PersistedDataset._tokens
5665
if not known_cache or self._trust_cache and not self._clobber:
5766
print(f'assuming cache is correct')
58-
persisted_Dataset._tokens[self._cache_file] = token
59-
persisted_Dataset._actions[self._cache_file] = 'read_cache_trusted'
67+
PersistedDataset._tokens[self._cache_file] = token
68+
PersistedDataset._actions[self._cache_file] = 'read_cache_trusted'
6069

6170
# if the cache file is present and we know about it,
6271
# check the token; if the token doesn't match, remove the file
6372
elif known_cache:
64-
if token != persisted_Dataset._tokens[self._cache_file] or self._clobber:
73+
if token != PersistedDataset._tokens[self._cache_file] or self._clobber:
6574
print(f'name mismatch, removing: {self._cache_file}')
66-
os.remove(self._cache_file)
67-
persisted_Dataset._actions[self._cache_file] = 'overwrite_cache'
75+
if self._format != 'zarr':
76+
os.remove(self._cache_file)
77+
else:
78+
shutil.rmtree(self._cache_file, ignore_errors=True)
79+
PersistedDataset._actions[self._cache_file] = 'overwrite_cache'
6880
else:
69-
persisted_Dataset._actions[self._cache_file] = 'read_cache_verified'
81+
PersistedDataset._actions[self._cache_file] = 'read_cache_verified'
7082

7183
else:
72-
persisted_Dataset._tokens[self._cache_file] = token
73-
persisted_Dataset._actions[self._cache_file] = 'create_cache'
84+
PersistedDataset._tokens[self._cache_file] = token
85+
PersistedDataset._actions[self._cache_file] = 'create_cache'
7486
if os.path.dirname(self._cache_file) and not os.path.exists(self._path):
7587
print(f'making {self._path}')
7688
os.makedirs(self._path)
7789

78-
assert persisted_Dataset._actions[self._cache_file] in _actions
90+
assert PersistedDataset._actions[self._cache_file] in _actions
7991

8092
@property
def _basename(self):
    """Return ``self._name``, with ``.<format>`` appended unless it already ends with it."""
    suffix = '.' + self._format
    if self._name.endswith(suffix):
        return self._name
    return f'{self._name}.{self._format}'
@@ -95,36 +107,49 @@ def _cache_exists(self):
95107

96108
def __call__(self, *args, **kwargs):
    """Call the wrapped function, or read its cached result instead.

    A token is computed from the wrapped function and the call arguments.
    ``_check_token_assign_action`` then records an action for the cache
    file, and that action decides whether the cache is read or (re)written.

    Parameters
    ----------
    *args, **kwargs
        Forwarded to the wrapped function when the cache must be created
        or overwritten; also used to compute the token.

    Returns
    -------
    The dataset read from the cache file (``read_cache_*`` actions), or the
    object returned by the wrapped function after it has been written to
    the cache (``create_cache`` / ``overwrite_cache``).
    """
    # Generate a deterministic token for this function/argument combination
    token = dask.base.tokenize(self._func, args, kwargs)
    if self._name is None:
        self._name = f'PersistedDataset-{token}'

    if self._path is None:
        self._path = settings['cache_dir']

    self._check_token_assign_action(token)
    action = self._actions[self._cache_file]

    if action in ('read_cache_trusted', 'read_cache_verified'):
        print(f'reading cached file: {self._cache_file}')
        if self._format == 'nc':
            return xr.open_dataset(self._cache_file, **self._open_ds_kwargs)
        elif self._format == 'zarr':
            # Default to consolidated metadata, but let a caller-supplied
            # 'consolidated' entry win. The kwargs dict is always built, so
            # passing 'consolidated' explicitly no longer raises NameError.
            zarr_kwargs = {'consolidated': True, **self._open_ds_kwargs}
            return xr.open_zarr(self._cache_file, **zarr_kwargs)
        # NOTE(review): any other format falls through and returns None, as
        # before; presumably the format is validated elsewhere — confirm.

    elif action in ('create_cache', 'overwrite_cache'):
        # generate dataset
        ds = self._func(*args, **kwargs)

        # write dataset
        print(f'writing cache file: {self._cache_file}')

        if self._format == 'nc':
            ds.to_netcdf(self._cache_file)

        elif self._format == 'zarr':
            ds.to_zarr(self._cache_file, consolidated=True)

        return ds
122146

123147

124148
@curry
125-
def persist_ds(func, name=None, path=None, trust_cache=False, clobber=False,
126-
format='nc', open_ds_kwargs={}):
127-
"""Wraps a function to produce a ``persisted_Dataset``.
149+
def persist_ds(
150+
func, name=None, path=None, trust_cache=False, clobber=False, format='nc', open_ds_kwargs={}
151+
):
152+
"""Wraps a function to produce a ``PersistedDataset``.
128153
129154
Parameters
130155
----------
@@ -182,4 +207,4 @@ def persist_ds(func, name=None, path=None, trust_cache=False, clobber=False,
182207
if not callable(func):
183208
raise ValueError('func must be callable')
184209

185-
return persisted_Dataset(func, name, path, trust_cache, clobber, format, open_ds_kwargs)
210+
return PersistedDataset(func, name, path, trust_cache, clobber, format, open_ds_kwargs)

0 commit comments

Comments
 (0)