 import logging
 import pickle
 import time
-from typing import (Any, Callable, Container, Hashable, Iterable, Optional,
-                    Union)
+from typing import (Any, Callable, Container, Dict, Hashable, Iterable,
+                    Optional, Union)
 
 import cloudpickle
 import dask
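
Note: the switch from a bare ``dict`` to ``Dict[Hashable, Any]`` matches the
dask graph specification, in which keys are hashable and values are literals
or task tuples. A minimal sketch of a graph that satisfies the new annotation
(the task function is illustrative):

    from typing import Any, Dict, Hashable

    def inc(x: int) -> int:
        """Illustrative task function."""
        return x + 1

    # A dask graph maps hashable keys to literals or task tuples.
    dsk: Dict[Hashable, Any] = {
        'x': 1,           # a literal value
        'y': (inc, 'x'),  # a task: (callable, *arguments)
    }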
@@ -23,11 +23,11 @@ class CachedComputation:
 
     def __init__(
             self,
-            dsk: dict,
+            dsk: Dict[Hashable, Any],
             key: Hashable,
             computation: Any,
             location: Union[str, fs.base.FS],
-            write_to_cache: Union[bool, str]='auto') -> None:
+            write_to_cache: Union[bool, str] = 'auto') -> None:
         """Cache a dask graph computation.
 
         Parameters
@@ -51,9 +51,9 @@ def __init__(
 
         Returns
         -------
-            CachedComputation
-                A wrapper for the computation object to replace the original
-                computation with in the dask graph.
+        CachedComputation
+            A wrapper for the computation object to replace the original
+            computation with in the dask graph.
         """
         self.dsk = dsk
         self.key = key
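
For illustration, instantiating the wrapper directly per the signature above
(a sketch only; in normal use ``optimize`` builds these wrappers, and the
graph and key below are made up):

    # Hypothetical graph and key; mirrors the constructor signature above.
    dsk = {'x': 1, 'y': (sum, [1, 2])}
    cached = CachedComputation(
        dsk, 'y', dsk['y'],
        location='./__graphchain_cache__',  # local cache directory
        write_to_cache='auto')              # let graphchain decide per task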
@@ -62,7 +62,7 @@ def __init__(
         self.write_to_cache = write_to_cache
 
     @property  # type: ignore
-    @functools.lru_cache()  # type: ignore
+    @functools.lru_cache()
     def cache_fs(self) -> fs.base.FS:
         """Open a PyFilesystem FS to the cache directory."""
         # create=True does not yet work for S3FS [1]. This should probably be
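
Stacking ``@property`` on top of ``@functools.lru_cache()`` memoizes the
getter, so the filesystem is opened at most once per instance; the
``# type: ignore`` survives only on the ``@property`` line because mypy
cannot type the combination. A standalone sketch of the same pattern:

    import functools

    class Example:
        @property  # type: ignore
        @functools.lru_cache()
        def expensive(self) -> int:
            # Runs once per (hashable) instance; later accesses hit the cache.
            print('computing...')
            return 42

    e = Example()
    e.expensive  # prints 'computing...' and returns 42
    e.expensive  # served from the cache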
@@ -133,7 +133,7 @@ def estimate_load_time(self, result: Any) -> float:
             500e6 if isinstance(self.cache_fs, fs.osfs.OSFS) else 50e6))
         return read_latency + size / read_throughput
 
-    @functools.lru_cache()  # type: ignore
+    @functools.lru_cache()
     def read_time(self, timing_type: str) -> float:
         """Read the time to load, compute, or store from file."""
         time_filename = f'{self.hash}.time.{timing_type}'
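
The f-string implies one small timing file per computation and per phase,
named ``<hash>.time.<timing_type>``, inside the cache filesystem. A hedged
sketch of reading one with PyFilesystem (the hash value and the float
payload are assumptions based on the return type):

    import fs

    cache_fs = fs.open_fs('./__graphchain_cache__')
    time_filename = 'deadbeef.time.load'  # hypothetical hash
    if cache_fs.exists(time_filename):
        with cache_fs.open(time_filename) as fid:
            load_time = float(fid.read())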
@@ -154,7 +154,7 @@ def write_log(self, log_type: str) -> None:
         with self.cache_fs.open(log_filename, 'w') as fid:  # type: ignore
             fid.write(self.hash)
 
-    def time_to_result(self, memoize: bool=True) -> float:
+    def time_to_result(self, memoize: bool = True) -> float:
         """Estimate the time to load or compute this computation."""
         if hasattr(self, '_time_to_result'):
             return self._time_to_result  # type: ignore
@@ -286,10 +286,11 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
 
 
 def optimize(
-        dsk: dict,
-        keys: Optional[Union[Hashable, Iterable[Hashable]]]=None,
-        skip_keys: Optional[Container[Hashable]]=None,
-        location: Union[str, fs.base.FS]="./__graphchain_cache__") -> dict:
+        dsk: Dict[Hashable, Any],
+        keys: Optional[Union[Hashable, Iterable[Hashable]]] = None,
+        skip_keys: Optional[Container[Hashable]] = None,
+        location: Union[str, fs.base.FS] = "./__graphchain_cache__") \
+        -> Dict[Hashable, Any]:
     """Optimize a dask graph with cached computations.
 
     According to the dask graph specification [1]_, a dask graph is a
@@ -318,23 +319,23 @@ def optimize(
 
     Parameters
     ----------
-        dsk
-            The dask graph to optimize with caching computations.
-        keys
-            Not used. Is present for compatibility with dask optimizers [2]_.
-        skip_keys
-            A container of keys not to cache.
-        location
-            A PyFilesystem FS URL to store the cached computations in. Can be a
-            local directory such as ``'./__graphchain_cache__'`` or a remote
-            directory such as ``'s3://bucket/__graphchain_cache__'``. You can
-            also pass a PyFilesystem itself instead.
+    dsk
+        The dask graph to optimize with caching computations.
+    keys
+        Not used. Is present for compatibility with dask optimizers [2]_.
+    skip_keys
+        A container of keys not to cache.
+    location
+        A PyFilesystem FS URL to store the cached computations in. Can be a
+        local directory such as ``'./__graphchain_cache__'`` or a remote
+        directory such as ``'s3://bucket/__graphchain_cache__'``. You can
+        also pass a PyFilesystem itself instead.
 
     Returns
     -------
-        dict
-            A copy of the dask graph where the computations have been replaced
-            by ``CachedComputation``'s.
+    dict
+        A copy of the dask graph where the computations have been replaced by
+        ``CachedComputation``'s.
 
     References
     ----------
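
Since ``optimize`` is a drop-in dask graph optimizer, it can be registered
through dask's configuration. A sketch of typical usage (the delayed function
is illustrative; ``delayed_optimize`` is dask's config key for optimizing
``Delayed`` objects):

    import dask
    import graphchain

    @dask.delayed
    def add(x: int, y: int) -> int:
        return x + y

    result = add(1, 2)
    # Route graph optimization through graphchain so repeated runs
    # are served from the cache.
    with dask.config.set(scheduler='sync',
                         delayed_optimize=graphchain.optimize):
        print(result.compute())  # 3; loaded from cache on later runs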
@@ -361,11 +362,14 @@ def optimize(
 
 
 def get(
-        dsk: dict,
+        dsk: Dict[Hashable, Any],
         keys: Union[Hashable, Iterable[Hashable]],
-        skip_keys: Optional[Container[Hashable]]=None,
-        location: Union[str, fs.base.FS]="./__graphchain_cache__",
-        scheduler: Optional[Callable]=None) -> Any:
+        skip_keys: Optional[Container[Hashable]] = None,
+        location: Union[str, fs.base.FS] = "./__graphchain_cache__",
+        scheduler: Optional[Callable[
+            [Dict[Hashable, Any], Union[Hashable, Iterable[Hashable]]],
+            Any
+        ]] = None) -> Any:
     """Get one or more keys from a dask graph with caching.
 
     Optimizes a dask graph with ``graphchain.optimize`` and then computes the
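
The narrowed ``Callable`` annotation spells out the ``(dsk, keys)`` signature
that dask's ``get`` functions share, so any of them can be passed as the
scheduler. For example (a sketch):

    import dask
    import graphchain

    def double(x: int) -> int:
        return 2 * x

    dsk = {'a': 1, 'b': (double, 'a')}
    # dask.get (synchronous) and dask.threaded.get both match the
    # (dsk, keys) -> Any shape the annotation now requires.
    print(graphchain.get(dsk, 'b', scheduler=dask.get))  # 2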
@@ -377,24 +381,24 @@ def get(
 
     Parameters
     ----------
-        dsk
-            The dask graph to query.
-        keys
-            The keys to compute.
-        skip_keys
-            A container of keys not to cache.
-        location
-            A PyFilesystem FS URL to store the cached computations in. Can be a
-            local directory such as ``'./__graphchain_cache__'`` or a remote
-            directory such as ``'s3://bucket/__graphchain_cache__'``. You can
-            also pass a PyFilesystem itself instead.
-        scheduler
-            The dask scheduler to use to retrieve the keys from the graph.
+    dsk
+        The dask graph to query.
+    keys
+        The keys to compute.
+    skip_keys
+        A container of keys not to cache.
+    location
+        A PyFilesystem FS URL to store the cached computations in. Can be a
+        local directory such as ``'./__graphchain_cache__'`` or a remote
+        directory such as ``'s3://bucket/__graphchain_cache__'``. You can also
+        pass a PyFilesystem itself instead.
+    scheduler
+        The dask scheduler to use to retrieve the keys from the graph.
 
     Returns
     -------
-        Any
-            The computed values corresponding to the given keys.
+    Any
+        The computed values corresponding to the given keys.
     """
     cached_dsk = optimize(dsk, keys, skip_keys=skip_keys, location=location)
     scheduler = \