1
+ from __future__ import annotations
2
+
1
3
import collections .abc
2
4
import itertools
3
5
import math
4
6
from copy import copy , deepcopy
5
- from typing import Any , Callable , Dict , List , Optional , Sequence , Set , Tuple , Union
7
+ from typing import Any , Callable , Dict , List , Sequence , Tuple , Union
6
8
7
9
import cloudpickle
8
10
import numpy as np
@@ -170,7 +172,7 @@ def curvature_loss(xs: XsType1, ys: YsType1) -> Float:
170
172
return curvature_loss
171
173
172
174
173
- def linspace (x_left : Real , x_right : Real , n : Int ) -> List [Float ]:
175
+ def linspace (x_left : Real , x_right : Real , n : Int ) -> list [Float ]:
174
176
"""This is equivalent to
175
177
'np.linspace(x_left, x_right, n, endpoint=False)[1:]',
176
178
but it is 15-30 times faster for small 'n'."""
@@ -194,7 +196,7 @@ def _get_neighbors_from_array(xs: np.ndarray) -> NeighborsType:
194
196
195
197
def _get_intervals (
196
198
x : float , neighbors : NeighborsType , nth_neighbors : int
197
- ) -> List [ Tuple [float , float ]]:
199
+ ) -> list [ tuple [float , float ]]:
198
200
nn = nth_neighbors
199
201
i = neighbors .index (x )
200
202
start = max (0 , i - nn - 1 )
@@ -249,9 +251,9 @@ class Learner1D(BaseLearner):
249
251
250
252
def __init__ (
251
253
self ,
252
- function : Callable [[Real ], Union [ Float , np .ndarray ] ],
253
- bounds : Tuple [Real , Real ],
254
- loss_per_interval : Optional [ Callable [[XsTypeN , YsTypeN ], Float ]] = None ,
254
+ function : Callable [[Real ], Float | np .ndarray ],
255
+ bounds : tuple [Real , Real ],
256
+ loss_per_interval : Callable [[XsTypeN , YsTypeN ], Float ] | None = None ,
255
257
):
256
258
self .function = function # type: ignore
257
259
@@ -267,8 +269,8 @@ def __init__(
267
269
# the learners behavior in the tests.
268
270
self ._recompute_losses_factor = 2
269
271
270
- self .data : Dict [Real , Real ] = {}
271
- self .pending_points : Set [Real ] = set ()
272
+ self .data : dict [Real , Real ] = {}
273
+ self .pending_points : set [Real ] = set ()
272
274
273
275
# A dict {x_n: [x_{n-1}, x_{n+1}]} for quick checking of local
274
276
# properties.
@@ -292,7 +294,7 @@ def __init__(
292
294
self .bounds = list (bounds )
293
295
self .__missing_bounds = set (self .bounds ) # cache of missing bounds
294
296
295
- self ._vdim : Optional [ int ] = None
297
+ self ._vdim : int | None = None
296
298
297
299
@property
298
300
def vdim (self ) -> int :
@@ -334,20 +336,18 @@ def loss(self, real: bool = True) -> float:
334
336
max_interval , max_loss = losses .peekitem (0 )
335
337
return max_loss
336
338
337
- def _scale_x (self , x : Optional [ Float ] ) -> Optional [ Float ] :
339
+ def _scale_x (self , x : Float | None ) -> Float | None :
338
340
if x is None :
339
341
return None
340
342
return x / self ._scale [0 ]
341
343
342
- def _scale_y (
343
- self , y : Union [Float , np .ndarray , None ]
344
- ) -> Union [Float , np .ndarray , None ]:
344
+ def _scale_y (self , y : Float | np .ndarray | None ) -> Float | np .ndarray | None :
345
345
if y is None :
346
346
return None
347
347
y_scale = self ._scale [1 ] or 1
348
348
return y / y_scale
349
349
350
- def _get_point_by_index (self , ind : int ) -> Optional [ float ] :
350
+ def _get_point_by_index (self , ind : int ) -> float | None :
351
351
if ind < 0 or ind >= len (self .neighbors ):
352
352
return None
353
353
return self .neighbors .keys ()[ind ]
@@ -449,7 +449,7 @@ def _update_neighbors(self, x: float, neighbors: NeighborsType) -> None:
449
449
neighbors .get (x_left , [None , None ])[1 ] = x
450
450
neighbors .get (x_right , [None , None ])[0 ] = x
451
451
452
- def _update_scale (self , x : float , y : Union [ Float , np .ndarray ] ) -> None :
452
+ def _update_scale (self , x : float , y : Float | np .ndarray ) -> None :
453
453
"""Update the scale with which the x and y-values are scaled.
454
454
455
455
For a learner where the function returns a single scalar the scale
@@ -476,7 +476,7 @@ def _update_scale(self, x: float, y: Union[Float, np.ndarray]) -> None:
476
476
self ._bbox [1 ][1 ] = max (self ._bbox [1 ][1 ], y )
477
477
self ._scale [1 ] = self ._bbox [1 ][1 ] - self ._bbox [1 ][0 ]
478
478
479
- def tell (self , x : float , y : Union [ Float , Sequence [Float ], np .ndarray ] ) -> None :
479
+ def tell (self , x : float , y : Float | Sequence [Float ] | np .ndarray ) -> None :
480
480
if x in self .data :
481
481
# The point is already evaluated before
482
482
return
@@ -522,13 +522,9 @@ def tell_pending(self, x: float) -> None:
522
522
def tell_many (
523
523
self ,
524
524
xs : Sequence [Float ],
525
- ys : Union [
526
- Sequence [Float ],
527
- Sequence [Sequence [Float ]],
528
- Sequence [np .ndarray ],
529
- ],
525
+ ys : (Sequence [Float ] | Sequence [Sequence [Float ]] | Sequence [np .ndarray ]),
530
526
* ,
531
- force : bool = False
527
+ force : bool = False ,
532
528
) -> None :
533
529
if not force and not (len (xs ) > 0.5 * len (self .data ) and len (xs ) > 2 ):
534
530
# Only run this more efficient method if there are
@@ -597,7 +593,7 @@ def tell_many(
597
593
# have an inf loss.
598
594
self ._update_interpolated_loss_in_interval (* ival )
599
595
600
- def ask (self , n : int , tell_pending : bool = True ) -> Tuple [ List [float ], List [float ]]:
596
+ def ask (self , n : int , tell_pending : bool = True ) -> tuple [ list [float ], list [float ]]:
601
597
"""Return 'n' points that are expected to maximally reduce the loss."""
602
598
points , loss_improvements = self ._ask_points_without_adding (n )
603
599
@@ -607,7 +603,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[float], List[floa
607
603
608
604
return points , loss_improvements
609
605
610
- def _missing_bounds (self ) -> List [Real ]:
606
+ def _missing_bounds (self ) -> list [Real ]:
611
607
missing_bounds = []
612
608
for b in copy (self .__missing_bounds ):
613
609
if b in self .data :
@@ -616,7 +612,7 @@ def _missing_bounds(self) -> List[Real]:
616
612
missing_bounds .append (b )
617
613
return sorted (missing_bounds )
618
614
619
- def _ask_points_without_adding (self , n : int ) -> Tuple [ List [float ], List [float ]]:
615
+ def _ask_points_without_adding (self , n : int ) -> tuple [ list [float ], list [float ]]:
620
616
"""Return 'n' points that are expected to maximally reduce the loss.
621
617
Without altering the state of the learner"""
622
618
# Find out how to divide the n points over the intervals
@@ -691,8 +687,8 @@ def _ask_points_without_adding(self, n: int) -> Tuple[List[float], List[float]]:
691
687
return points , loss_improvements
692
688
693
689
def _loss (
694
- self , mapping : Dict [Interval , float ], ival : Interval
695
- ) -> Tuple [float , Interval ]:
690
+ self , mapping : dict [Interval , float ], ival : Interval
691
+ ) -> tuple [float , Interval ]:
696
692
loss = mapping [ival ]
697
693
return finite_loss (ival , loss , self ._scale [0 ])
698
694
@@ -736,10 +732,10 @@ def remove_unfinished(self) -> None:
736
732
self .losses_combined = deepcopy (self .losses )
737
733
self .neighbors_combined = deepcopy (self .neighbors )
738
734
739
- def _get_data (self ) -> Dict [float , float ]:
735
+ def _get_data (self ) -> dict [float , float ]:
740
736
return self .data
741
737
742
- def _set_data (self , data : Dict [float , float ]) -> None :
738
+ def _set_data (self , data : dict [float , float ]) -> None :
743
739
if data :
744
740
xs , ys = zip (* data .items ())
745
741
self .tell_many (xs , ys )
@@ -763,7 +759,7 @@ def __setstate__(self, state):
763
759
self .losses_combined .update (losses_combined )
764
760
765
761
766
- def loss_manager (x_scale : float ) -> Dict [Interval , float ]:
762
+ def loss_manager (x_scale : float ) -> dict [Interval , float ]:
767
763
def sort_key (ival , loss ):
768
764
loss , ival = finite_loss (ival , loss , x_scale )
769
765
return - loss , ival
@@ -772,7 +768,7 @@ def sort_key(ival, loss):
772
768
return sorted_dict
773
769
774
770
775
- def finite_loss (ival : Interval , loss : float , x_scale : float ) -> Tuple [float , Interval ]:
771
+ def finite_loss (ival : Interval , loss : float , x_scale : float ) -> tuple [float , Interval ]:
776
772
"""Get the so-called finite_loss of an interval in order to be able to
777
773
sort intervals that have infinite loss."""
778
774
# If the loss is infinite we return the
0 commit comments