Skip to content

Commit 4b4546c

Browse files
committed
type hint fixes for adaptive/tests/test_learners.py
1 parent cc296f4 commit 4b4546c

File tree

1 file changed

+21
-110
lines changed

1 file changed

+21
-110
lines changed

Diff for: adaptive/tests/test_learners.py

+21-110
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from adaptive.learner import (
2121
AverageLearner,
2222
BalancingLearner,
23+
BaseLearner,
2324
DataSaver,
2425
IntegratorLearner,
2526
Learner1D,
@@ -92,28 +93,15 @@ def uniform(a: Union[int, float], b: int) -> Callable:
9293
learner_function_combos = collections.defaultdict(list)
9394

9495

95-
def learn_with(
96-
learner_type: Union[
97-
Type[Learner2D],
98-
Type[SequenceLearner],
99-
Type[AverageLearner],
100-
Type[Learner1D],
101-
Type[LearnerND],
102-
],
103-
**init_kwargs,
104-
) -> Callable:
96+
def learn_with(learner_type: Type[BaseLearner], **init_kwargs,) -> Callable:
10597
def _(f):
10698
learner_function_combos[learner_type].append((f, init_kwargs))
10799
return f
108100

109101
return _
110102

111103

112-
def xfail(
113-
learner: Union[Type[Learner2D], Type[LearnerND]]
114-
) -> Union[
115-
Tuple[MarkDecorator, Type[Learner2D]], Tuple[MarkDecorator, Type[LearnerND]]
116-
]:
104+
def xfail(learner: Type[BaseLearner]) -> Tuple[MarkDecorator, Type[BaseLearner]]:
117105
return pytest.mark.xfail, learner
118106

119107

@@ -141,14 +129,7 @@ def linear_with_peak(x: Union[int, float], d: uniform(-1, 1)) -> float:
141129
@learn_with(Learner2D, bounds=((-1, 1), (-1, 1)))
142130
@learn_with(SequenceLearner, sequence=np.random.rand(1000, 2))
143131
def ring_of_fire(
144-
xy: Union[
145-
Tuple[float, float],
146-
np.ndarray,
147-
Tuple[int, int],
148-
Tuple[float, float],
149-
Tuple[float, float],
150-
],
151-
d: uniform(0.2, 1),
132+
xy: Union[np.ndarray, Tuple[float, float]], d: uniform(0.2, 1),
152133
) -> float:
153134
a = 0.2
154135
x, y = xy
@@ -158,8 +139,7 @@ def ring_of_fire(
158139
@learn_with(LearnerND, bounds=((-1, 1), (-1, 1), (-1, 1)))
159140
@learn_with(SequenceLearner, sequence=np.random.rand(1000, 3))
160141
def sphere_of_fire(
161-
xyz: Union[Tuple[float, float, float], Tuple[int, int, int], np.ndarray],
162-
d: uniform(0.2, 1),
142+
xyz: Union[Tuple[float, float, float], np.ndarray], d: uniform(0.2, 1),
163143
) -> float:
164144
a = 0.2
165145
x, y, z = xyz
@@ -177,16 +157,7 @@ def gaussian(n: int) -> float:
177157

178158
# Create a sequence of learner parameters by adding all
179159
# possible loss functions to an existing parameter set.
180-
def add_loss_to_params(
181-
learner_type: Union[
182-
Type[Learner2D],
183-
Type[SequenceLearner],
184-
Type[AverageLearner],
185-
Type[Learner1D],
186-
Type[LearnerND],
187-
],
188-
existing_params: Dict[str, Any],
189-
) -> Any:
160+
def add_loss_to_params(learner_type, existing_params: Dict[str, Any],) -> Any:
190161
if learner_type not in LOSS_FUNCTIONS:
191162
return [existing_params]
192163
loss_param, loss_functions = LOSS_FUNCTIONS[learner_type]
@@ -216,12 +187,7 @@ def ask_randomly(
216187
learner: Union[Learner1D, LearnerND, Learner2D],
217188
rounds: Tuple[int, int],
218189
points: Tuple[int, int],
219-
) -> Union[
220-
Tuple[List[Union[Tuple[float, float, float], Tuple[int, int, int]]], List[float]],
221-
Tuple[List[Union[Tuple[float, float], Tuple[int, int]]], List[float]],
222-
Tuple[List[float], List[float]],
223-
Tuple[List[Union[Tuple[int, int], Tuple[float, float]]], List[float]],
224-
]:
190+
):
225191
n_rounds = random.randrange(*rounds)
226192
n_points = [random.randrange(*points) for _ in range(n_rounds)]
227193

@@ -240,7 +206,7 @@ def ask_randomly(
240206

241207
@run_with(Learner1D)
242208
def test_uniform_sampling1D(
243-
learner_type: Type[Learner1D],
209+
learner_type,
244210
f: Callable,
245211
learner_kwargs: Dict[str, Union[Tuple[int, int], Callable]],
246212
) -> None:
@@ -262,7 +228,7 @@ def test_uniform_sampling1D(
262228
@pytest.mark.xfail
263229
@run_with(Learner2D, LearnerND)
264230
def test_uniform_sampling2D(
265-
learner_type: Union[Type[Learner2D], Type[LearnerND]],
231+
learner_type,
266232
f: Callable,
267233
learner_kwargs: Dict[
268234
str,
@@ -304,8 +270,7 @@ def test_uniform_sampling2D(
304270
],
305271
)
306272
def test_learner_accepts_lists(
307-
learner_type: Union[Type[Learner2D], Type[LearnerND], Type[Learner1D]],
308-
bounds: Union[Tuple[int, int], List[Tuple[int, int]]],
273+
learner_type, bounds: Union[Tuple[int, int], List[Tuple[int, int]]],
309274
) -> None:
310275
def f(x):
311276
return [0, 1]
@@ -316,11 +281,7 @@ def f(x):
316281

317282
@run_with(Learner1D, Learner2D, LearnerND, SequenceLearner)
318283
def test_adding_existing_data_is_idempotent(
319-
learner_type: Union[
320-
Type[SequenceLearner], Type[LearnerND], Type[Learner1D], Type[Learner2D]
321-
],
322-
f: Callable,
323-
learner_kwargs: Dict[str, Any],
284+
learner_type, f: Callable, learner_kwargs: Dict[str, Any],
324285
) -> None:
325286
"""Adding already existing data is an idempotent operation.
326287
@@ -369,15 +330,7 @@ def test_adding_existing_data_is_idempotent(
369330
# but we xfail it now, as Learner2D will be deprecated anyway
370331
@run_with(Learner1D, xfail(Learner2D), LearnerND, AverageLearner, SequenceLearner)
371332
def test_adding_non_chosen_data(
372-
learner_type: Union[
373-
Type[Learner2D],
374-
Type[SequenceLearner],
375-
Type[AverageLearner],
376-
Type[Learner1D],
377-
Type[LearnerND],
378-
],
379-
f: Callable,
380-
learner_kwargs: Dict[str, Any],
333+
learner_type, f: Callable, learner_kwargs: Dict[str, Any],
381334
) -> None:
382335
"""Adding data for a point that was not returned by 'ask'."""
383336
# XXX: learner, control and bounds are not defined
@@ -421,9 +374,7 @@ def test_adding_non_chosen_data(
421374

422375
@run_with(Learner1D, xfail(Learner2D), xfail(LearnerND), AverageLearner)
423376
def test_point_adding_order_is_irrelevant(
424-
learner_type: Union[
425-
Type[AverageLearner], Type[LearnerND], Type[Learner1D], Type[Learner2D]
426-
],
377+
learner_type,
427378
f: Callable,
428379
learner_kwargs: Dict[
429380
str,
@@ -478,9 +429,7 @@ def test_point_adding_order_is_irrelevant(
478429
# see https://github.com/python-adaptive/adaptive/issues/55
479430
@run_with(Learner1D, xfail(Learner2D), LearnerND, AverageLearner)
480431
def test_expected_loss_improvement_is_less_than_total_loss(
481-
learner_type: Union[
482-
Type[AverageLearner], Type[LearnerND], Type[Learner1D], Type[Learner2D]
483-
],
432+
learner_type,
484433
f: Callable,
485434
learner_kwargs: Dict[
486435
str,
@@ -519,7 +468,7 @@ def test_expected_loss_improvement_is_less_than_total_loss(
519468
# but we xfail it now, as Learner2D will be deprecated anyway
520469
@run_with(Learner1D, xfail(Learner2D), LearnerND)
521470
def test_learner_performance_is_invariant_under_scaling(
522-
learner_type: Union[Type[Learner2D], Type[LearnerND], Type[Learner1D]],
471+
learner_type,
523472
f: Callable,
524473
learner_kwargs: Dict[
525474
str,
@@ -583,15 +532,7 @@ def test_learner_performance_is_invariant_under_scaling(
583532
with_all_loss_functions=False,
584533
)
585534
def test_balancing_learner(
586-
learner_type: Union[
587-
Type[Learner2D],
588-
Type[SequenceLearner],
589-
Type[AverageLearner],
590-
Type[Learner1D],
591-
Type[LearnerND],
592-
],
593-
f: Callable,
594-
learner_kwargs: Dict[str, Any],
535+
learner_type, f: Callable, learner_kwargs: Dict[str, Any],
595536
) -> None:
596537
"""Test if the BalancingLearner works with the different types of learners."""
597538
learners = [
@@ -638,17 +579,7 @@ def test_balancing_learner(
638579
SequenceLearner,
639580
with_all_loss_functions=False,
640581
)
641-
def test_saving(
642-
learner_type: Union[
643-
Type[Learner2D],
644-
Type[SequenceLearner],
645-
Type[AverageLearner],
646-
Type[Learner1D],
647-
Type[LearnerND],
648-
],
649-
f: Callable,
650-
learner_kwargs: Dict[str, Any],
651-
) -> None:
582+
def test_saving(learner_type, f: Callable, learner_kwargs: Dict[str, Any],) -> None:
652583
f = generate_random_parametrization(f)
653584
learner = learner_type(f, **learner_kwargs)
654585
control = learner_type(f, **learner_kwargs)
@@ -680,15 +611,7 @@ def test_saving(
680611
with_all_loss_functions=False,
681612
)
682613
def test_saving_of_balancing_learner(
683-
learner_type: Union[
684-
Type[Learner2D],
685-
Type[SequenceLearner],
686-
Type[AverageLearner],
687-
Type[Learner1D],
688-
Type[LearnerND],
689-
],
690-
f: Callable,
691-
learner_kwargs: Dict[str, Any],
614+
learner_type, f: Callable, learner_kwargs: Dict[str, Any],
692615
) -> None:
693616
f = generate_random_parametrization(f)
694617
learner = BalancingLearner([learner_type(f, **learner_kwargs)])
@@ -727,9 +650,7 @@ def fname(learner):
727650
with_all_loss_functions=False,
728651
)
729652
def test_saving_with_datasaver(
730-
learner_type: Union[
731-
Type[Learner2D], Type[AverageLearner], Type[LearnerND], Type[Learner1D]
732-
],
653+
learner_type,
733654
f: Callable,
734655
learner_kwargs: Dict[
735656
str,
@@ -770,7 +691,7 @@ def test_saving_with_datasaver(
770691
@pytest.mark.xfail
771692
@run_with(Learner1D, Learner2D, LearnerND)
772693
def test_convergence_for_arbitrary_ordering(
773-
learner_type: Union[Type[Learner2D], Type[LearnerND], Type[Learner1D]],
694+
learner_type,
774695
f: Callable,
775696
learner_kwargs: Dict[
776697
str,
@@ -794,17 +715,7 @@ def test_convergence_for_arbitrary_ordering(
794715
@pytest.mark.xfail
795716
@run_with(Learner1D, Learner2D, LearnerND)
796717
def test_learner_subdomain(
797-
learner_type: Union[Type[Learner2D], Type[LearnerND], Type[Learner1D]],
798-
f: Callable,
799-
learner_kwargs: Dict[
800-
str,
801-
Union[
802-
Tuple[Tuple[int, int], Tuple[int, int]],
803-
Callable,
804-
Tuple[int, int],
805-
Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]],
806-
],
807-
],
718+
learner_type, f: Callable, learner_kwargs,
808719
):
809720
"""Learners that never receive data outside of a subdomain should
810721
perform 'similarly' to learners defined on that subdomain only."""

0 commit comments

Comments
 (0)