Skip to content

Commit 8f65fb8

Browse files
committed
Merge remote-tracking branch 'origin/master' into mypy
2 parents 80518a1 + 50fae43 commit 8f65fb8

15 files changed

+113
-26
lines changed

adaptive/learner/average_learner.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ def __init__(
7474
self.sum_f: Real = 0.0
7575
self.sum_f_sq: Real = 0.0
7676

77+
def new(self) -> AverageLearner:
78+
"""Create a copy of `~adaptive.AverageLearner` without the data."""
79+
return AverageLearner(self.function, self.atol, self.rtol, self.min_npoints)
80+
7781
@property
7882
def n_requested(self) -> int:
7983
return self.npoints + len(self.pending_points)

adaptive/learner/average_learner1D.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,20 @@ def __init__(
125125
# {xii: error[xii]/min(_distances[xi], _distances[xii], ...}
126126
self.rescaled_error: dict[Real, float] = decreasing_dict()
127127

128+
def new(self) -> AverageLearner1D:
129+
"""Create a copy of `~adaptive.AverageLearner1D` without the data."""
130+
return AverageLearner1D(
131+
self.function,
132+
self.bounds,
133+
self.loss_per_interval,
134+
self.delta,
135+
self.alpha,
136+
self.neighbor_sampling,
137+
self.min_samples,
138+
self.max_samples,
139+
self.min_error,
140+
)
141+
128142
@property
129143
def nsamples(self) -> int:
130144
"""Returns the total number of samples"""

adaptive/learner/balancing_learner.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,14 @@ def __init__(
118118

119119
self.strategy: STRATEGY_TYPE = strategy
120120

121+
def new(self) -> BalancingLearner:
122+
"""Create a new `BalancingLearner` with the same parameters."""
123+
return BalancingLearner(
124+
[learner.new() for learner in self.learners],
125+
cdims=self._cdims_default,
126+
strategy=self.strategy,
127+
)
128+
121129
@property
122130
def data(self) -> dict[tuple[int, Any], Any]:
123131
data = {}

adaptive/learner/base_learner.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,11 @@ def _get_data(self):
157157
def _set_data(self, data: Any):
158158
pass
159159

160+
@abc.abstractmethod
161+
def new(self):
162+
"""Return a new learner with the same function and parameters."""
163+
pass
164+
160165
def copy_from(self, other):
161166
"""Copy over the data from another learner.
162167

adaptive/learner/data_saver.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ def __init__(self, learner: BaseLearner, arg_picker: itemgetter) -> None:
4747
self.function = learner.function
4848
self.arg_picker = arg_picker
4949

50+
def new(self) -> DataSaver:
51+
"""Return a new `DataSaver` with the same `arg_picker` and `learner`."""
52+
return DataSaver(self.learner.new(), self.arg_picker)
53+
5054
def __getattr__(self, attr: str) -> Any:
5155
return getattr(self.learner, attr)
5256

adaptive/learner/integrator_learner.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,10 @@ def __init__(self, function: Callable, bounds: tuple[int, int], tol: float) -> N
390390
self.add_ival(ival)
391391
self.first_ival = ival
392392

393+
def new(self) -> IntegratorLearner:
394+
"""Create a copy of `~adaptive.Learner2D` without the data."""
395+
return IntegratorLearner(self.function, self.bounds, self.tol)
396+
393397
@property
394398
def approximating_intervals(self) -> set[_Interval]:
395399
return self.first_ival.done_leaves

adaptive/learner/learner1D.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@
2222
partial_function_from_dataframe,
2323
)
2424

25+
try:
26+
from typing import TypeAlias
27+
except ImportError:
28+
# Remove this when we drop support for Python 3.9
29+
from typing_extensions import TypeAlias
30+
2531
try:
2632
import pandas
2733

@@ -33,21 +39,21 @@
3339
# -- types --
3440

3541
# Commonly used types
36-
Interval = Union[Tuple[float, float], Tuple[float, float, int]]
37-
NeighborsType = Dict[float, List[Union[float, None]]]
42+
Interval: TypeAlias = Union[Tuple[float, float], Tuple[float, float, int]]
43+
NeighborsType: TypeAlias = Dict[float, List[Union[float, None]]]
3844

3945
# Types for loss_per_interval functions
40-
NoneFloat = Union[Float, None]
41-
NoneArray = Union[np.ndarray, None]
42-
XsType0 = Tuple[Float, Float]
43-
YsType0 = Union[Tuple[Float, Float], Tuple[np.ndarray, np.ndarray]]
44-
XsType1 = Tuple[NoneFloat, NoneFloat, NoneFloat, NoneFloat]
45-
YsType1 = Union[
46+
NoneFloat: TypeAlias = Union[Float, None]
47+
NoneArray: TypeAlias = Union[np.ndarray, None]
48+
XsType0: TypeAlias = Tuple[Float, Float]
49+
YsType0: TypeAlias = Union[Tuple[Float, Float], Tuple[np.ndarray, np.ndarray]]
50+
XsType1: TypeAlias = Tuple[NoneFloat, NoneFloat, NoneFloat, NoneFloat]
51+
YsType1: TypeAlias = Union[
4652
Tuple[NoneFloat, NoneFloat, NoneFloat, NoneFloat],
4753
Tuple[NoneArray, NoneArray, NoneArray, NoneArray],
4854
]
49-
XsTypeN = Tuple[NoneFloat, ...]
50-
YsTypeN = Union[Tuple[NoneFloat, ...], Tuple[NoneArray, ...]]
55+
XsTypeN: TypeAlias = Tuple[NoneFloat, ...]
56+
YsTypeN: TypeAlias = Union[Tuple[NoneFloat, ...], Tuple[NoneArray, ...]]
5157

5258

5359
__all__ = [
@@ -303,11 +309,15 @@ def __init__(
303309
# The precision in 'x' below which we set losses to 0.
304310
self._dx_eps = 2 * max(np.abs(bounds)) * np.finfo(float).eps
305311

306-
self.bounds = list(bounds)
312+
self.bounds = tuple(bounds)
307313
self.__missing_bounds = set(self.bounds) # cache of missing bounds
308314

309315
self._vdim: int | None = None
310316

317+
def new(self) -> Learner1D:
318+
"""Create a copy of `~adaptive.Learner1D` without the data."""
319+
return Learner1D(self.function, self.bounds, self.loss_per_interval)
320+
311321
@property
312322
def vdim(self) -> int:
313323
"""Length of the output of ``learner.function``.

adaptive/learner/learner2D.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,9 @@ def __init__(
393393

394394
self.stack_size = 10
395395

396+
def new(self) -> Learner2D:
397+
return Learner2D(self.function, self.bounds, self.loss_per_triangle)
398+
396399
@property
397400
def xy_scale(self) -> np.ndarray:
398401
xy_scale = self._xy_scale

adaptive/learner/learnerND.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,10 @@ def __init__(
395395
# _pop_highest_existing_simplex
396396
self._simplex_queue = SortedKeyList(key=_simplex_evaluation_priority)
397397

398+
def new(self) -> LearnerND:
399+
"""Create a new learner with the same function and bounds."""
400+
return LearnerND(self.function, self.bounds, self.loss_per_simplex)
401+
398402
@property
399403
def npoints(self) -> int:
400404
"""Number of evaluated points."""

adaptive/learner/sequence_learner.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ def __init__(self, function: Callable, sequence: Iterable) -> None:
8181
self.data = SortedDict()
8282
self.pending_points = set()
8383

84+
def new(self) -> SequenceLearner:
85+
"""Return a new `~adaptive.SequenceLearner` without the data."""
86+
return SequenceLearner(self._original_function, self.sequence)
87+
8488
def ask(self, n: int, tell_pending: bool = True) -> tuple[Any, list[float]]:
8589
indices = []
8690
points = []
@@ -179,7 +183,7 @@ def to_dataframe(
179183
df.attrs["inputs"] = [index_name]
180184
df.attrs["output"] = y_name
181185
if with_default_function_args:
182-
assign_defaults(self.function, df, function_prefix)
186+
assign_defaults(self._original_function, df, function_prefix)
183187
return df
184188

185189
def load_dataframe(
@@ -216,7 +220,7 @@ def load_dataframe(
216220
self.tell_many(df[[index_name, x_name]].values, df[y_name].values)
217221
if with_default_function_args:
218222
self.function = partial_function_from_dataframe(
219-
self.function, df, function_prefix
223+
self._original_function, df, function_prefix
220224
)
221225

222226
def _get_data(self) -> SortedDict:

adaptive/learner/skopt_learner.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,13 @@ def __init__(self, function: Callable, **kwargs) -> None:
3030
self.function = function # type: ignore
3131
self.pending_points = set()
3232
self.data = collections.OrderedDict()
33+
self._kwargs = kwargs
3334
super().__init__(**kwargs)
3435

36+
def new(self) -> SKOptLearner:
37+
"""Return a new `~adaptive.SKOptLearner` without the data."""
38+
return SKOptLearner(self.function, **self._kwargs)
39+
3540
def tell(self, x: float | list[float], y: float, fit: bool = True) -> None:
3641
if isinstance(x, collections.abc.Iterable):
3742
self.pending_points.discard(tuple(x))

adaptive/tests/test_learners.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def test_adding_existing_data_is_idempotent(learner_type, f, learner_kwargs):
294294
"""
295295
f = generate_random_parametrization(f)
296296
learner = learner_type(f, **learner_kwargs)
297-
control = learner_type(f, **learner_kwargs)
297+
control = learner.new()
298298
if learner_type in (Learner1D, AverageLearner1D):
299299
learner._recompute_losses_factor = 1
300300
control._recompute_losses_factor = 1
@@ -345,7 +345,7 @@ def test_adding_non_chosen_data(learner_type, f, learner_kwargs):
345345
# XXX: learner, control and bounds are not defined
346346
f = generate_random_parametrization(f)
347347
learner = learner_type(f, **learner_kwargs)
348-
control = learner_type(f, **learner_kwargs)
348+
control = learner.new()
349349

350350
if learner_type is Learner2D:
351351
# If the stack_size is bigger then the number of points added,
@@ -395,7 +395,7 @@ def test_point_adding_order_is_irrelevant(learner_type, f, learner_kwargs):
395395
"""
396396
f = generate_random_parametrization(f)
397397
learner = learner_type(f, **learner_kwargs)
398-
control = learner_type(f, **learner_kwargs)
398+
control = learner.new()
399399

400400
if learner_type in (Learner1D, AverageLearner1D):
401401
learner._recompute_losses_factor = 1
@@ -581,7 +581,7 @@ def test_balancing_learner(learner_type, f, learner_kwargs):
581581
def test_saving(learner_type, f, learner_kwargs):
582582
f = generate_random_parametrization(f)
583583
learner = learner_type(f, **learner_kwargs)
584-
control = learner_type(f, **learner_kwargs)
584+
control = learner.new()
585585
if learner_type in (Learner1D, AverageLearner1D):
586586
learner._recompute_losses_factor = 1
587587
control._recompute_losses_factor = 1
@@ -614,7 +614,7 @@ def test_saving(learner_type, f, learner_kwargs):
614614
def test_saving_of_balancing_learner(learner_type, f, learner_kwargs):
615615
f = generate_random_parametrization(f)
616616
learner = BalancingLearner([learner_type(f, **learner_kwargs)])
617-
control = BalancingLearner([learner_type(f, **learner_kwargs)])
617+
control = learner.new()
618618

619619
if learner_type in (Learner1D, AverageLearner1D):
620620
for l, c in zip(learner.learners, control.learners):
@@ -654,7 +654,7 @@ def test_saving_with_datasaver(learner_type, f, learner_kwargs):
654654
g = lambda x: {"y": f(x), "t": random.random()} # noqa: E731
655655
arg_picker = operator.itemgetter("y")
656656
learner = DataSaver(learner_type(g, **learner_kwargs), arg_picker)
657-
control = DataSaver(learner_type(g, **learner_kwargs), arg_picker)
657+
control = learner.new()
658658

659659
if learner_type in (Learner1D, AverageLearner1D):
660660
learner.learner._recompute_losses_factor = 1
@@ -742,7 +742,7 @@ def test_to_dataframe(learner_type, f, learner_kwargs):
742742
assert len(df) == learner.npoints
743743

744744
# Add points from the DataFrame to a new empty learner
745-
learner2 = learner_type(learner.function, **learner_kwargs)
745+
learner2 = learner.new()
746746
learner2.load_dataframe(df, **kw)
747747
assert learner2.npoints == learner.npoints
748748

@@ -787,8 +787,7 @@ def test_to_dataframe(learner_type, f, learner_kwargs):
787787
assert len(df) == data_saver.npoints
788788

789789
# Test loading from a DataFrame into a new DataSaver
790-
learner2 = learner_type(learner.function, **learner_kwargs)
791-
data_saver2 = DataSaver(learner2, operator.itemgetter("result"))
790+
data_saver2 = data_saver.new()
792791
data_saver2.load_dataframe(df, **kw)
793792
assert data_saver2.extra_data.keys() == data_saver.extra_data.keys()
794793
assert all(

adaptive/types.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22

33
import numpy as np
44

5-
Float = Union[float, np.float_]
6-
Int = Union[int, np.int_]
7-
Real = Union[Float, Int]
5+
try:
6+
from typing import TypeAlias
7+
except ImportError:
8+
# Remove this when we drop support for Python 3.9
9+
from typing_extensions import TypeAlias
10+
11+
Float: TypeAlias = Union[float, np.float_]
12+
Int: TypeAlias = Union[int, np.int_]
13+
Real: TypeAlias = Union[Float, Int]

adaptive/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import inspect
77
import os
88
import pickle
9+
import warnings
910
from contextlib import _GeneratorContextManager, contextmanager
1011
from itertools import product
1112
from typing import Any, Callable, Mapping, Sequence
@@ -138,4 +139,19 @@ def partial_function_from_dataframe(function, df, function_prefix: str = "functi
138139
kwargs[k] = v
139140
if not kwargs:
140141
return function
142+
143+
sig = inspect.signature(function)
144+
for k, v in kwargs.items():
145+
if k not in sig.parameters:
146+
raise ValueError(
147+
f"The DataFrame contains a default parameter"
148+
f" ({k}={v}) but the function does not have that parameter."
149+
)
150+
default = sig.parameters[k].default
151+
if default != inspect._empty and kwargs[k] != default:
152+
warnings.warn(
153+
f"The DataFrame contains a default parameter"
154+
f" ({k}={v}) but the function already has a default ({k}={default})."
155+
" The DataFrame's value will be used."
156+
)
141157
return functools.partial(function, **kwargs)

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def get_version_and_cmdclass(package_name):
3131
"cloudpickle",
3232
"loky >= 2.9",
3333
]
34-
if sys.version_info < (3, 8):
34+
if sys.version_info < (3, 10):
3535
install_requires.append("typing_extensions")
3636

3737
extras_require = {
@@ -80,6 +80,7 @@ def get_version_and_cmdclass(package_name):
8080
"Programming Language :: Python :: 3.7",
8181
"Programming Language :: Python :: 3.8",
8282
"Programming Language :: Python :: 3.9",
83+
"Programming Language :: Python :: 3.10",
8384
],
8485
packages=find_packages("."),
8586
install_requires=install_requires,

0 commit comments

Comments
 (0)