1
+ from __future__ import annotations
2
+
1
3
import abc
2
4
import functools
3
5
import gzip
4
6
import inspect
5
7
import os
6
8
import pickle
7
9
import warnings
8
- from contextlib import contextmanager
10
+ from contextlib import _GeneratorContextManager , contextmanager
9
11
from itertools import product
12
+ from typing import Any , Callable , Mapping , Sequence
10
13
11
14
import cloudpickle
12
15
13
16
14
- def named_product (** items ):
17
+ def named_product (** items : Mapping [ str , Sequence [ Any ]] ):
15
18
names = items .keys ()
16
19
vals = items .values ()
17
20
return [dict (zip (names , res )) for res in product (* vals )]
18
21
19
22
20
23
@contextmanager
21
- def restore (* learners ):
24
+ def restore (* learners ) -> _GeneratorContextManager :
22
25
states = [learner .__getstate__ () for learner in learners ]
23
26
try :
24
27
yield
@@ -27,7 +30,7 @@ def restore(*learners):
27
30
learner .__setstate__ (state )
28
31
29
32
30
- def cache_latest (f ) :
33
+ def cache_latest (f : Callable ) -> Callable :
31
34
"""Cache the latest return value of the function and add it
32
35
as 'self._cache[f.__name__]'."""
33
36
@@ -42,7 +45,7 @@ def wrapper(*args, **kwargs):
42
45
return wrapper
43
46
44
47
45
- def save (fname , data , compress = True ):
48
+ def save (fname : str , data : Any , compress : bool = True ) -> None :
46
49
fname = os .path .expanduser (fname )
47
50
dirname = os .path .dirname (fname )
48
51
if dirname :
@@ -71,14 +74,14 @@ def save(fname, data, compress=True):
71
74
return True
72
75
73
76
74
- def load (fname , compress = True ):
77
+ def load (fname : str , compress : bool = True ):
75
78
fname = os .path .expanduser (fname )
76
79
_open = gzip .open if compress else open
77
80
with _open (fname , "rb" ) as f :
78
81
return cloudpickle .load (f )
79
82
80
83
81
- def copy_docstring_from (other ) :
84
+ def copy_docstring_from (other : Callable ) -> Callable :
82
85
def decorator (method ):
83
86
return functools .wraps (other )(method )
84
87
0 commit comments