1
1
from __future__ import annotations
2
2
3
3
from copy import copy
4
- from typing import Any , Callable , Iterable
4
+ from typing import Any , Callable , Iterable , Tuple
5
5
6
6
import cloudpickle
7
- import numpy as np
8
7
from sortedcontainers import SortedDict , SortedSet
9
8
10
9
from adaptive .learner .base_learner import BaseLearner
10
+ from adaptive .types import Int
11
11
from adaptive .utils import assign_defaults , partial_function_from_dataframe
12
12
13
13
try :
18
18
except ModuleNotFoundError :
19
19
with_pandas = False
20
20
21
+ try :
22
+ from typing import TypeAlias
23
+ except ImportError :
24
+ from typing_extensions import TypeAlias
25
+
26
+
27
+ PointType : TypeAlias = Tuple [Int , Any ]
28
+
21
29
22
30
class _IgnoreFirstArgument :
23
31
"""Remove the first argument from the call signature.
@@ -32,9 +40,7 @@ class _IgnoreFirstArgument:
32
40
def __init__ (self , function : Callable ) -> None :
33
41
self .function = function # type: ignore
34
42
35
- def __call__ (
36
- self , index_point : tuple [int , float | np .ndarray ], * args , ** kwargs
37
- ) -> float :
43
+ def __call__ (self , index_point : PointType , * args , ** kwargs ):
38
44
index , point = index_point
39
45
return self .function (point , * args , ** kwargs )
40
46
@@ -85,7 +91,9 @@ def new(self) -> SequenceLearner:
85
91
"""Return a new `~adaptive.SequenceLearner` without the data."""
86
92
return SequenceLearner (self ._original_function , self .sequence )
87
93
88
- def ask (self , n : int , tell_pending : bool = True ) -> tuple [Any , list [float ]]:
94
+ def ask (
95
+ self , n : int , tell_pending : bool = True
96
+ ) -> tuple [list [PointType ], list [float ]]:
89
97
indices = []
90
98
points = []
91
99
loss_improvements = []
@@ -105,31 +113,31 @@ def ask(self, n: int, tell_pending: bool = True) -> tuple[Any, list[float]]:
105
113
106
114
def loss (self , real : bool = True ) -> float :
107
115
if not (self ._to_do_indices or self .pending_points ):
108
- return 0
116
+ return 0.0
109
117
else :
110
118
npoints = self .npoints + (0 if real else len (self .pending_points ))
111
119
return (self ._ntotal - npoints ) / self ._ntotal
112
120
113
- def remove_unfinished (self ):
121
+ def remove_unfinished (self ) -> None :
114
122
for i in self .pending_points :
115
123
self ._to_do_indices .add (i )
116
124
self .pending_points = set ()
117
125
118
- def tell (self , point : tuple [ int , Any ] , value : Any ) -> None :
126
+ def tell (self , point : PointType , value : Any ) -> None :
119
127
index , point = point
120
128
self .data [index ] = value
121
129
self .pending_points .discard (index )
122
130
self ._to_do_indices .discard (index )
123
131
124
- def tell_pending (self , point : Any ) -> None :
132
+ def tell_pending (self , point : PointType ) -> None :
125
133
index , point = point
126
134
self .pending_points .add (index )
127
135
self ._to_do_indices .discard (index )
128
136
129
- def done (self ):
137
+ def done (self ) -> bool :
130
138
return not self ._to_do_indices and not self .pending_points
131
139
132
- def result (self ):
140
+ def result (self ) -> list [ Any ] :
133
141
"""Get the function values in the same order as ``sequence``."""
134
142
if not self .done ():
135
143
raise Exception ("Learner is not yet complete." )
@@ -217,16 +225,18 @@ def load_dataframe(
217
225
y_name : str, optional
218
226
The ``y_name`` used in ``to_dataframe``, by default "y"
219
227
"""
220
- self .tell_many (df [[index_name , x_name ]].values , df [y_name ].values )
228
+ indices = df [index_name ].values
229
+ xs = df [x_name ].values
230
+ self .tell_many (zip (indices , xs ), df [y_name ].values )
221
231
if with_default_function_args :
222
232
self .function = partial_function_from_dataframe (
223
233
self ._original_function , df , function_prefix
224
234
)
225
235
226
- def _get_data (self ) -> SortedDict :
236
+ def _get_data (self ) -> dict [ int , Any ] :
227
237
return self .data
228
238
229
- def _set_data (self , data : SortedDict ) -> None :
239
+ def _set_data (self , data : dict [ int , Any ] ) -> None :
230
240
if data :
231
241
indices , values = zip (* data .items ())
232
242
# the points aren't used by tell, so we can safely pass None
0 commit comments