20
20
from adaptive .learner import (
21
21
AverageLearner ,
22
22
BalancingLearner ,
23
+ BaseLearner ,
23
24
DataSaver ,
24
25
IntegratorLearner ,
25
26
Learner1D ,
@@ -92,28 +93,15 @@ def uniform(a: Union[int, float], b: int) -> Callable:
92
93
learner_function_combos = collections .defaultdict (list )
93
94
94
95
95
- def learn_with (
96
- learner_type : Union [
97
- Type [Learner2D ],
98
- Type [SequenceLearner ],
99
- Type [AverageLearner ],
100
- Type [Learner1D ],
101
- Type [LearnerND ],
102
- ],
103
- ** init_kwargs ,
104
- ) -> Callable :
96
+ def learn_with (learner_type : Type [BaseLearner ], ** init_kwargs ,) -> Callable :
105
97
def _ (f ):
106
98
learner_function_combos [learner_type ].append ((f , init_kwargs ))
107
99
return f
108
100
109
101
return _
110
102
111
103
112
- def xfail (
113
- learner : Union [Type [Learner2D ], Type [LearnerND ]]
114
- ) -> Union [
115
- Tuple [MarkDecorator , Type [Learner2D ]], Tuple [MarkDecorator , Type [LearnerND ]]
116
- ]:
104
+ def xfail (learner : Type [BaseLearner ]) -> Tuple [MarkDecorator , Type [BaseLearner ]]:
117
105
return pytest .mark .xfail , learner
118
106
119
107
@@ -141,14 +129,7 @@ def linear_with_peak(x: Union[int, float], d: uniform(-1, 1)) -> float:
141
129
@learn_with (Learner2D , bounds = ((- 1 , 1 ), (- 1 , 1 )))
142
130
@learn_with (SequenceLearner , sequence = np .random .rand (1000 , 2 ))
143
131
def ring_of_fire (
144
- xy : Union [
145
- Tuple [float , float ],
146
- np .ndarray ,
147
- Tuple [int , int ],
148
- Tuple [float , float ],
149
- Tuple [float , float ],
150
- ],
151
- d : uniform (0.2 , 1 ),
132
+ xy : Union [np .ndarray , Tuple [float , float ]], d : uniform (0.2 , 1 ),
152
133
) -> float :
153
134
a = 0.2
154
135
x , y = xy
@@ -158,8 +139,7 @@ def ring_of_fire(
158
139
@learn_with (LearnerND , bounds = ((- 1 , 1 ), (- 1 , 1 ), (- 1 , 1 )))
159
140
@learn_with (SequenceLearner , sequence = np .random .rand (1000 , 3 ))
160
141
def sphere_of_fire (
161
- xyz : Union [Tuple [float , float , float ], Tuple [int , int , int ], np .ndarray ],
162
- d : uniform (0.2 , 1 ),
142
+ xyz : Union [Tuple [float , float , float ], np .ndarray ], d : uniform (0.2 , 1 ),
163
143
) -> float :
164
144
a = 0.2
165
145
x , y , z = xyz
@@ -177,16 +157,7 @@ def gaussian(n: int) -> float:
177
157
178
158
# Create a sequence of learner parameters by adding all
179
159
# possible loss functions to an existing parameter set.
180
- def add_loss_to_params (
181
- learner_type : Union [
182
- Type [Learner2D ],
183
- Type [SequenceLearner ],
184
- Type [AverageLearner ],
185
- Type [Learner1D ],
186
- Type [LearnerND ],
187
- ],
188
- existing_params : Dict [str , Any ],
189
- ) -> Any :
160
+ def add_loss_to_params (learner_type , existing_params : Dict [str , Any ],) -> Any :
190
161
if learner_type not in LOSS_FUNCTIONS :
191
162
return [existing_params ]
192
163
loss_param , loss_functions = LOSS_FUNCTIONS [learner_type ]
@@ -216,12 +187,7 @@ def ask_randomly(
216
187
learner : Union [Learner1D , LearnerND , Learner2D ],
217
188
rounds : Tuple [int , int ],
218
189
points : Tuple [int , int ],
219
- ) -> Union [
220
- Tuple [List [Union [Tuple [float , float , float ], Tuple [int , int , int ]]], List [float ]],
221
- Tuple [List [Union [Tuple [float , float ], Tuple [int , int ]]], List [float ]],
222
- Tuple [List [float ], List [float ]],
223
- Tuple [List [Union [Tuple [int , int ], Tuple [float , float ]]], List [float ]],
224
- ]:
190
+ ):
225
191
n_rounds = random .randrange (* rounds )
226
192
n_points = [random .randrange (* points ) for _ in range (n_rounds )]
227
193
@@ -240,7 +206,7 @@ def ask_randomly(
240
206
241
207
@run_with (Learner1D )
242
208
def test_uniform_sampling1D (
243
- learner_type : Type [ Learner1D ] ,
209
+ learner_type ,
244
210
f : Callable ,
245
211
learner_kwargs : Dict [str , Union [Tuple [int , int ], Callable ]],
246
212
) -> None :
@@ -262,7 +228,7 @@ def test_uniform_sampling1D(
262
228
@pytest .mark .xfail
263
229
@run_with (Learner2D , LearnerND )
264
230
def test_uniform_sampling2D (
265
- learner_type : Union [ Type [ Learner2D ], Type [ LearnerND ]] ,
231
+ learner_type ,
266
232
f : Callable ,
267
233
learner_kwargs : Dict [
268
234
str ,
@@ -304,8 +270,7 @@ def test_uniform_sampling2D(
304
270
],
305
271
)
306
272
def test_learner_accepts_lists (
307
- learner_type : Union [Type [Learner2D ], Type [LearnerND ], Type [Learner1D ]],
308
- bounds : Union [Tuple [int , int ], List [Tuple [int , int ]]],
273
+ learner_type , bounds : Union [Tuple [int , int ], List [Tuple [int , int ]]],
309
274
) -> None :
310
275
def f (x ):
311
276
return [0 , 1 ]
@@ -316,11 +281,7 @@ def f(x):
316
281
317
282
@run_with (Learner1D , Learner2D , LearnerND , SequenceLearner )
318
283
def test_adding_existing_data_is_idempotent (
319
- learner_type : Union [
320
- Type [SequenceLearner ], Type [LearnerND ], Type [Learner1D ], Type [Learner2D ]
321
- ],
322
- f : Callable ,
323
- learner_kwargs : Dict [str , Any ],
284
+ learner_type , f : Callable , learner_kwargs : Dict [str , Any ],
324
285
) -> None :
325
286
"""Adding already existing data is an idempotent operation.
326
287
@@ -369,15 +330,7 @@ def test_adding_existing_data_is_idempotent(
369
330
# but we xfail it now, as Learner2D will be deprecated anyway
370
331
@run_with (Learner1D , xfail (Learner2D ), LearnerND , AverageLearner , SequenceLearner )
371
332
def test_adding_non_chosen_data (
372
- learner_type : Union [
373
- Type [Learner2D ],
374
- Type [SequenceLearner ],
375
- Type [AverageLearner ],
376
- Type [Learner1D ],
377
- Type [LearnerND ],
378
- ],
379
- f : Callable ,
380
- learner_kwargs : Dict [str , Any ],
333
+ learner_type , f : Callable , learner_kwargs : Dict [str , Any ],
381
334
) -> None :
382
335
"""Adding data for a point that was not returned by 'ask'."""
383
336
# XXX: learner, control and bounds are not defined
@@ -421,9 +374,7 @@ def test_adding_non_chosen_data(
421
374
422
375
@run_with (Learner1D , xfail (Learner2D ), xfail (LearnerND ), AverageLearner )
423
376
def test_point_adding_order_is_irrelevant (
424
- learner_type : Union [
425
- Type [AverageLearner ], Type [LearnerND ], Type [Learner1D ], Type [Learner2D ]
426
- ],
377
+ learner_type ,
427
378
f : Callable ,
428
379
learner_kwargs : Dict [
429
380
str ,
@@ -478,9 +429,7 @@ def test_point_adding_order_is_irrelevant(
478
429
# see https://github.com/python-adaptive/adaptive/issues/55
479
430
@run_with (Learner1D , xfail (Learner2D ), LearnerND , AverageLearner )
480
431
def test_expected_loss_improvement_is_less_than_total_loss (
481
- learner_type : Union [
482
- Type [AverageLearner ], Type [LearnerND ], Type [Learner1D ], Type [Learner2D ]
483
- ],
432
+ learner_type ,
484
433
f : Callable ,
485
434
learner_kwargs : Dict [
486
435
str ,
@@ -519,7 +468,7 @@ def test_expected_loss_improvement_is_less_than_total_loss(
519
468
# but we xfail it now, as Learner2D will be deprecated anyway
520
469
@run_with (Learner1D , xfail (Learner2D ), LearnerND )
521
470
def test_learner_performance_is_invariant_under_scaling (
522
- learner_type : Union [ Type [ Learner2D ], Type [ LearnerND ], Type [ Learner1D ]] ,
471
+ learner_type ,
523
472
f : Callable ,
524
473
learner_kwargs : Dict [
525
474
str ,
@@ -583,15 +532,7 @@ def test_learner_performance_is_invariant_under_scaling(
583
532
with_all_loss_functions = False ,
584
533
)
585
534
def test_balancing_learner (
586
- learner_type : Union [
587
- Type [Learner2D ],
588
- Type [SequenceLearner ],
589
- Type [AverageLearner ],
590
- Type [Learner1D ],
591
- Type [LearnerND ],
592
- ],
593
- f : Callable ,
594
- learner_kwargs : Dict [str , Any ],
535
+ learner_type , f : Callable , learner_kwargs : Dict [str , Any ],
595
536
) -> None :
596
537
"""Test if the BalancingLearner works with the different types of learners."""
597
538
learners = [
@@ -638,17 +579,7 @@ def test_balancing_learner(
638
579
SequenceLearner ,
639
580
with_all_loss_functions = False ,
640
581
)
641
- def test_saving (
642
- learner_type : Union [
643
- Type [Learner2D ],
644
- Type [SequenceLearner ],
645
- Type [AverageLearner ],
646
- Type [Learner1D ],
647
- Type [LearnerND ],
648
- ],
649
- f : Callable ,
650
- learner_kwargs : Dict [str , Any ],
651
- ) -> None :
582
+ def test_saving (learner_type , f : Callable , learner_kwargs : Dict [str , Any ],) -> None :
652
583
f = generate_random_parametrization (f )
653
584
learner = learner_type (f , ** learner_kwargs )
654
585
control = learner_type (f , ** learner_kwargs )
@@ -680,15 +611,7 @@ def test_saving(
680
611
with_all_loss_functions = False ,
681
612
)
682
613
def test_saving_of_balancing_learner (
683
- learner_type : Union [
684
- Type [Learner2D ],
685
- Type [SequenceLearner ],
686
- Type [AverageLearner ],
687
- Type [Learner1D ],
688
- Type [LearnerND ],
689
- ],
690
- f : Callable ,
691
- learner_kwargs : Dict [str , Any ],
614
+ learner_type , f : Callable , learner_kwargs : Dict [str , Any ],
692
615
) -> None :
693
616
f = generate_random_parametrization (f )
694
617
learner = BalancingLearner ([learner_type (f , ** learner_kwargs )])
@@ -727,9 +650,7 @@ def fname(learner):
727
650
with_all_loss_functions = False ,
728
651
)
729
652
def test_saving_with_datasaver (
730
- learner_type : Union [
731
- Type [Learner2D ], Type [AverageLearner ], Type [LearnerND ], Type [Learner1D ]
732
- ],
653
+ learner_type ,
733
654
f : Callable ,
734
655
learner_kwargs : Dict [
735
656
str ,
@@ -770,7 +691,7 @@ def test_saving_with_datasaver(
770
691
@pytest .mark .xfail
771
692
@run_with (Learner1D , Learner2D , LearnerND )
772
693
def test_convergence_for_arbitrary_ordering (
773
- learner_type : Union [ Type [ Learner2D ], Type [ LearnerND ], Type [ Learner1D ]] ,
694
+ learner_type ,
774
695
f : Callable ,
775
696
learner_kwargs : Dict [
776
697
str ,
@@ -794,17 +715,7 @@ def test_convergence_for_arbitrary_ordering(
794
715
@pytest .mark .xfail
795
716
@run_with (Learner1D , Learner2D , LearnerND )
796
717
def test_learner_subdomain (
797
- learner_type : Union [Type [Learner2D ], Type [LearnerND ], Type [Learner1D ]],
798
- f : Callable ,
799
- learner_kwargs : Dict [
800
- str ,
801
- Union [
802
- Tuple [Tuple [int , int ], Tuple [int , int ]],
803
- Callable ,
804
- Tuple [int , int ],
805
- Tuple [Tuple [int , int ], Tuple [int , int ], Tuple [int , int ]],
806
- ],
807
- ],
718
+ learner_type , f : Callable , learner_kwargs ,
808
719
):
809
720
"""Learners that never receive data outside of a subdomain should
810
721
perform 'similarly' to learners defined on that subdomain only."""
0 commit comments