9
9
import numpy as np
10
10
from scipy import interpolate
11
11
import scipy .spatial
12
+ from sortedcontainers import SortedKeyList
12
13
13
14
from adaptive .learner .base_learner import BaseLearner
14
15
from adaptive .notebook_integration import ensure_holoviews , ensure_plotly
@@ -91,7 +92,6 @@ def choose_point_in_simplex(simplex, transform=None):
91
92
distance_matrix = scipy .spatial .distance .squareform (distances )
92
93
i , j = np .unravel_index (np .argmax (distance_matrix ),
93
94
distance_matrix .shape )
94
-
95
95
point = (simplex [i , :] + simplex [j , :]) / 2
96
96
97
97
if transform is not None :
@@ -100,6 +100,15 @@ def choose_point_in_simplex(simplex, transform=None):
100
100
return point
101
101
102
102
103
+ def _simplex_evaluation_priority (key ):
104
+ # We round the loss to 8 digits such that losses
105
+ # are equal up to numerical precision will be considered
106
+ # to be equal. This is needed because we want the learner
107
+ # to behave in a deterministic fashion.
108
+ loss , simplex , subsimplex = key
109
+ return - round (loss , ndigits = 8 ), simplex , subsimplex or (0 ,)
110
+
111
+
103
112
class LearnerND (BaseLearner ):
104
113
"""Learns and predicts a function 'f: ℝ^N → ℝ^M'.
105
114
@@ -200,7 +209,7 @@ def __init__(self, func, bounds, loss_per_simplex=None):
200
209
# so when popping an item, you should check that the simplex that has
201
210
# been returned has not been deleted. This checking is done by
202
211
# _pop_highest_existing_simplex
203
- self ._simplex_queue = [] # heap
212
+ self ._simplex_queue = SortedKeyList ( key = _simplex_evaluation_priority )
204
213
205
214
@property
206
215
def npoints (self ):
@@ -344,9 +353,7 @@ def _update_subsimplex_losses(self, simplex, new_subsimplices):
344
353
subtriangulation = self ._subtriangulations [simplex ]
345
354
for subsimplex in new_subsimplices :
346
355
subloss = subtriangulation .volume (subsimplex ) * loss_density
347
- subloss = round (subloss , ndigits = 8 )
348
- heapq .heappush (self ._simplex_queue ,
349
- (- subloss , simplex , subsimplex ))
356
+ self ._simplex_queue .add ((subloss , simplex , subsimplex ))
350
357
351
358
def _ask_and_tell_pending (self , n = 1 ):
352
359
xs , losses = zip (* (self ._ask () for _ in range (n )))
@@ -386,7 +393,7 @@ def _pop_highest_existing_simplex(self):
386
393
# find the simplex with the highest loss, we do need to check that the
387
394
# simplex hasn't been deleted yet
388
395
while len (self ._simplex_queue ):
389
- loss , simplex , subsimplex = heapq . heappop ( self ._simplex_queue )
396
+ loss , simplex , subsimplex = self ._simplex_queue . pop ( 0 )
390
397
if (subsimplex is None
391
398
and simplex in self .tri .simplices
392
399
and simplex not in self ._subtriangulations ):
@@ -462,8 +469,7 @@ def update_losses(self, to_delete: set, to_add: set):
462
469
self ._try_adding_pending_point_to_simplex (p , simplex )
463
470
464
471
if simplex not in self ._subtriangulations :
465
- loss = round (loss , ndigits = 8 )
466
- heapq .heappush (self ._simplex_queue , (- loss , simplex , None ))
472
+ self ._simplex_queue .add ((loss , simplex , None ))
467
473
continue
468
474
469
475
self ._update_subsimplex_losses (
@@ -488,7 +494,7 @@ def recompute_all_losses(self):
488
494
return
489
495
490
496
# reset the _simplex_queue
491
- self ._simplex_queue = []
497
+ self ._simplex_queue = SortedKeyList ( key = _simplex_evaluation_priority )
492
498
493
499
# recompute all losses
494
500
for simplex in self .tri .simplices :
@@ -497,8 +503,7 @@ def recompute_all_losses(self):
497
503
498
504
# now distribute it around the the children if they are present
499
505
if simplex not in self ._subtriangulations :
500
- loss = round (loss , ndigits = 8 )
501
- heapq .heappush (self ._simplex_queue , (- loss , simplex , None ))
506
+ self ._simplex_queue .add ((loss , simplex , None ))
502
507
continue
503
508
504
509
self ._update_subsimplex_losses (
0 commit comments