 import os
+import shutil

-from toolz import curry
-
-import xarray as xr
 import dask
+import xarray as xr
+from toolz import curry

 from . import settings

-__all__ = ["persisted_Dataset", "persist_ds"]
+__all__ = ['PersistedDataset', 'persist_ds']

 _actions = {'read_cache_trusted', 'read_cache_verified', 'overwrite_cache', 'create_cache'}
-_formats = {'nc'}
+_formats = {'nc', 'zarr'}

-class persisted_Dataset(object):
+
+class PersistedDataset(object):
     """
     Generate an `xarray.Dataset` from a function and cache the result to file.
     If the cache file exists, don't recompute, but read back in from file.
@@ -30,8 +31,16 @@ class persisted_Dataset(object):
     # class property
     _actions = {}

-    def __init__(self, func, name=None, path=None, trust_cache=False, clobber=False,
-                 format='nc', open_ds_kwargs={}):
+    def __init__(
+        self,
+        func,
+        name=None,
+        path=None,
+        trust_cache=False,
+        clobber=False,
+        format='nc',
+        open_ds_kwargs={},
+    ):
         """set instance attributes"""
         self._func = func
         self._name = name
@@ -52,34 +61,37 @@ def _check_token_assign_action(self, token):

         # if we don't yet know about this file, assume it's the right one;
         # this enables usage on first call in a Python session, for instance
-        known_cache = self._cache_file in persisted_Dataset._tokens
+        known_cache = self._cache_file in PersistedDataset._tokens
         if not known_cache or self._trust_cache and not self._clobber:
             print(f'assuming cache is correct')
-            persisted_Dataset._tokens[self._cache_file] = token
-            persisted_Dataset._actions[self._cache_file] = 'read_cache_trusted'
+            PersistedDataset._tokens[self._cache_file] = token
+            PersistedDataset._actions[self._cache_file] = 'read_cache_trusted'

         # if the cache file is present and we know about it,
         # check the token; if the token doesn't match, remove the file
         elif known_cache:
-            if token != persisted_Dataset._tokens[self._cache_file] or self._clobber:
+            if token != PersistedDataset._tokens[self._cache_file] or self._clobber:
                 print(f'name mismatch, removing: {self._cache_file}')
-                os.remove(self._cache_file)
-                persisted_Dataset._actions[self._cache_file] = 'overwrite_cache'
+                if self._format != 'zarr':
+                    os.remove(self._cache_file)
+                else:
+                    shutil.rmtree(self._cache_file, ignore_errors=True)
+                PersistedDataset._actions[self._cache_file] = 'overwrite_cache'
             else:
-                persisted_Dataset._actions[self._cache_file] = 'read_cache_verified'
+                PersistedDataset._actions[self._cache_file] = 'read_cache_verified'

         else:
-            persisted_Dataset._tokens[self._cache_file] = token
-            persisted_Dataset._actions[self._cache_file] = 'create_cache'
+            PersistedDataset._tokens[self._cache_file] = token
+            PersistedDataset._actions[self._cache_file] = 'create_cache'
         if os.path.dirname(self._cache_file) and not os.path.exists(self._path):
             print(f'making {self._path}')
             os.makedirs(self._path)

-        assert persisted_Dataset._actions[self._cache_file] in _actions
+        assert PersistedDataset._actions[self._cache_file] in _actions

     @property
     def _basename(self):
-        if self._name.endswith('.' + self._format):
+        if self._name.endswith('.' + self._format):
             return self._name
         else:
             return f'{self._name}.{self._format}'
@@ -95,36 +107,49 @@ def _cache_exists(self):

     def __call__(self, *args, **kwargs):
         """call function or read cache"""
-
+        # generate deterministic token
         token = dask.base.tokenize(self._func, args, kwargs)
         if self._name is None:
-            self._name = f'persisted_Dataset-{token}'
+            self._name = f'PersistedDataset-{token}'

         if self._path is None:
             self._path = settings['cache_dir']

         self._check_token_assign_action(token)

-        if {'read_cache_trusted', 'read_cache_verified'}.intersection({self._actions[self._cache_file]}):
+        if {'read_cache_trusted', 'read_cache_verified'}.intersection(
+            {self._actions[self._cache_file]}
+        ):
             print(f'reading cached file: {self._cache_file}')
-            return xr.open_dataset(self._cache_file, **self._open_ds_kwargs)
+            if self._format == 'nc':
+                return xr.open_dataset(self._cache_file, **self._open_ds_kwargs)
+            elif self._format == 'zarr':
+                zarr_kwargs = self._open_ds_kwargs.copy()
+                if 'consolidated' not in zarr_kwargs:
+                    zarr_kwargs['consolidated'] = True
+                return xr.open_zarr(self._cache_file, **zarr_kwargs)

         elif {'create_cache', 'overwrite_cache'}.intersection({self._actions[self._cache_file]}):
             # generate dataset
             ds = self._func(*args, **kwargs)

             # write dataset
             print(f'writing cache file: {self._cache_file}')
-            ds.to_netcdf(self._cache_file)

-            return ds
+            if self._format == 'nc':
+                ds.to_netcdf(self._cache_file)

+            elif self._format == 'zarr':
+                ds.to_zarr(self._cache_file, consolidated=True)
+
+            return ds


 @curry
-def persist_ds(func, name=None, path=None, trust_cache=False, clobber=False,
-               format='nc', open_ds_kwargs={}):
-    """Wraps a function to produce a ``persisted_Dataset``.
+def persist_ds(
+    func, name=None, path=None, trust_cache=False, clobber=False, format='nc', open_ds_kwargs={}
+):
+    """Wraps a function to produce a ``PersistedDataset``.

     Parameters
     ----------
@@ -182,4 +207,4 @@ def persist_ds(func, name=None, path=None, trust_cache=False, clobber=False,
     if not callable(func):
         raise ValueError('func must be callable')

-    return persisted_Dataset(func, name, path, trust_cache, clobber, format, open_ds_kwargs)
+    return PersistedDataset(func, name, path, trust_cache, clobber, format, open_ds_kwargs)
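
For reference, a minimal usage sketch of the curried decorator form this change enables (not part of the diff; the import path `mypackage` and the computed dataset are placeholders, and `settings['cache_dir']` is assumed to be configured):

import numpy as np
import xarray as xr

from mypackage import persist_ds  # hypothetical import path; use the real package name

# persist_ds is curried (toolz.curry), so it can take keyword options first and
# the wrapped function later, i.e. it works as a decorator factory.
@persist_ds(name='annual_mean', format='zarr')
def annual_mean():
    # stand-in computation; any function returning an xarray.Dataset works
    ds = xr.Dataset({'tas': ('time', np.random.rand(12))})
    return ds.mean('time')

ds1 = annual_mean()  # first call: computes and writes annual_mean.zarr to the cache dir
ds2 = annual_mean()  # second call: reads the cached zarr store instead of recomputing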