11import os
2+ import shutil
23
3- from toolz import curry
4-
5- import xarray as xr
64import dask
5+ import xarray as xr
6+ from toolz import curry
77
88from . import settings
99
10- __all__ = ["persisted_Dataset" , " persist_ds" ]
10+ __all__ = ['PersistedDataset' , ' persist_ds' ]
1111
1212_actions = {'read_cache_trusted' , 'read_cache_verified' , 'overwrite_cache' , 'create_cache' }
13- _formats = {'nc' }
13+ _formats = {'nc' , 'zarr' }
1414
15- class persisted_Dataset (object ):
15+
16+ class PersistedDataset (object ):
1617 """
1718 Generate an `xarray.Dataset` from a function and cache the result to file.
1819 If the cache file exists, don't recompute, but read back in from file.
@@ -30,8 +31,16 @@ class persisted_Dataset(object):
3031 # class property
3132 _actions = {}
3233
33- def __init__ (self , func , name = None , path = None , trust_cache = False , clobber = False ,
34- format = 'nc' , open_ds_kwargs = {}):
34+ def __init__ (
35+ self ,
36+ func ,
37+ name = None ,
38+ path = None ,
39+ trust_cache = False ,
40+ clobber = False ,
41+ format = 'nc' ,
42+ open_ds_kwargs = {},
43+ ):
3544 """set instance attributes"""
3645 self ._func = func
3746 self ._name = name
@@ -52,34 +61,37 @@ def _check_token_assign_action(self, token):
5261
5362 # if we don't yet know about this file, assume it's the right one;
5463 # this enables usage on first call in a Python session, for instance
55- known_cache = self ._cache_file in persisted_Dataset ._tokens
64+ known_cache = self ._cache_file in PersistedDataset ._tokens
5665 if not known_cache or self ._trust_cache and not self ._clobber :
5766 print (f'assuming cache is correct' )
58- persisted_Dataset ._tokens [self ._cache_file ] = token
59- persisted_Dataset ._actions [self ._cache_file ] = 'read_cache_trusted'
67+ PersistedDataset ._tokens [self ._cache_file ] = token
68+ PersistedDataset ._actions [self ._cache_file ] = 'read_cache_trusted'
6069
6170 # if the cache file is present and we know about it,
6271 # check the token; if the token doesn't match, remove the file
6372 elif known_cache :
64- if token != persisted_Dataset ._tokens [self ._cache_file ] or self ._clobber :
73+ if token != PersistedDataset ._tokens [self ._cache_file ] or self ._clobber :
6574 print (f'name mismatch, removing: { self ._cache_file } ' )
66- os .remove (self ._cache_file )
67- persisted_Dataset ._actions [self ._cache_file ] = 'overwrite_cache'
75+ if self ._format != 'zarr' :
76+ os .remove (self ._cache_file )
77+ else :
78+ shutil .rmtree (self ._cache_file , ignore_errors = True )
79+ PersistedDataset ._actions [self ._cache_file ] = 'overwrite_cache'
6880 else :
69- persisted_Dataset ._actions [self ._cache_file ] = 'read_cache_verified'
81+ PersistedDataset ._actions [self ._cache_file ] = 'read_cache_verified'
7082
7183 else :
72- persisted_Dataset ._tokens [self ._cache_file ] = token
73- persisted_Dataset ._actions [self ._cache_file ] = 'create_cache'
84+ PersistedDataset ._tokens [self ._cache_file ] = token
85+ PersistedDataset ._actions [self ._cache_file ] = 'create_cache'
7486 if os .path .dirname (self ._cache_file ) and not os .path .exists (self ._path ):
7587 print (f'making { self ._path } ' )
7688 os .makedirs (self ._path )
7789
78- assert persisted_Dataset ._actions [self ._cache_file ] in _actions
90+ assert PersistedDataset ._actions [self ._cache_file ] in _actions
7991
8092 @property
8193 def _basename (self ):
82- if self ._name .endswith ('.' + self ._format ):
94+ if self ._name .endswith ('.' + self ._format ):
8395 return self ._name
8496 else :
8597 return f'{ self ._name } .{ self ._format } '
@@ -95,36 +107,49 @@ def _cache_exists(self):
95107
96108 def __call__ (self , * args , ** kwargs ):
97109 """call function or read cache"""
98-
110+ # Generate Deterministic token
99111 token = dask .base .tokenize (self ._func , args , kwargs )
100112 if self ._name is None :
101- self ._name = f'persisted_Dataset -{ token } '
113+ self ._name = f'PersistedDataset -{ token } '
102114
103115 if self ._path is None :
104116 self ._path = settings ['cache_dir' ]
105117
106118 self ._check_token_assign_action (token )
107119
108- if {'read_cache_trusted' , 'read_cache_verified' }.intersection ({self ._actions [self ._cache_file ]}):
120+ if {'read_cache_trusted' , 'read_cache_verified' }.intersection (
121+ {self ._actions [self ._cache_file ]}
122+ ):
109123 print (f'reading cached file: { self ._cache_file } ' )
110- return xr .open_dataset (self ._cache_file , ** self ._open_ds_kwargs )
124+ if self ._format == 'nc' :
125+ return xr .open_dataset (self ._cache_file , ** self ._open_ds_kwargs )
126+ elif self ._format == 'zarr' :
127+ if 'consolidated' not in self ._open_ds_kwargs :
128+ zarr_kwargs = self ._open_ds_kwargs .copy ()
129+ zarr_kwargs ['consolidated' ] = True
130+ return xr .open_zarr (self ._cache_file , ** zarr_kwargs )
111131
112132 elif {'create_cache' , 'overwrite_cache' }.intersection ({self ._actions [self ._cache_file ]}):
113133 # generate dataset
114134 ds = self ._func (* args , ** kwargs )
115135
116136 # write dataset
117137 print (f'writing cache file: { self ._cache_file } ' )
118- ds .to_netcdf (self ._cache_file )
119138
120- return ds
139+ if self ._format == 'nc' :
140+ ds .to_netcdf (self ._cache_file )
121141
142+ elif self ._format == 'zarr' :
143+ ds .to_zarr (self ._cache_file , consolidated = True )
144+
145+ return ds
122146
123147
124148@curry
125- def persist_ds (func , name = None , path = None , trust_cache = False , clobber = False ,
126- format = 'nc' , open_ds_kwargs = {}):
127- """Wraps a function to produce a ``persisted_Dataset``.
149+ def persist_ds (
150+ func , name = None , path = None , trust_cache = False , clobber = False , format = 'nc' , open_ds_kwargs = {}
151+ ):
152+ """Wraps a function to produce a ``PersistedDataset``.
128153
129154 Parameters
130155 ----------
@@ -182,4 +207,4 @@ def persist_ds(func, name=None, path=None, trust_cache=False, clobber=False,
182207 if not callable (func ):
183208 raise ValueError ('func must be callable' )
184209
185- return persisted_Dataset (func , name , path , trust_cache , clobber , format , open_ds_kwargs )
210+ return PersistedDataset (func , name , path , trust_cache , clobber , format , open_ds_kwargs )
0 commit comments