Skip to content

Commit 3c90103

Browse files
basnijholt and jbweston
authored and committed
change the simplex_queue to a SortedKeyList
1 parent e17a27d commit 3c90103

File tree

2 files changed

+17
-14
lines changed

2 files changed

+17
-14
lines changed

adaptive/learner/learnerND.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import numpy as np
1010
from scipy import interpolate
1111
import scipy.spatial
12+
from sortedcontainers import SortedKeyList
1213

1314
from adaptive.learner.base_learner import BaseLearner
1415
from adaptive.notebook_integration import ensure_holoviews, ensure_plotly
@@ -91,7 +92,6 @@ def choose_point_in_simplex(simplex, transform=None):
9192
distance_matrix = scipy.spatial.distance.squareform(distances)
9293
i, j = np.unravel_index(np.argmax(distance_matrix),
9394
distance_matrix.shape)
94-
9595
point = (simplex[i, :] + simplex[j, :]) / 2
9696

9797
if transform is not None:
@@ -100,6 +100,15 @@ def choose_point_in_simplex(simplex, transform=None):
100100
return point
101101

102102

103+
def _simplex_evaluation_priority(key):
104+
# We round the loss to 8 digits so that losses
105+
# that are equal up to numerical precision will be
106+
# considered equal. This is needed because we want
107+
# the learner to behave in a deterministic fashion.
108+
loss, simplex, subsimplex = key
109+
return -round(loss, ndigits=8), simplex, subsimplex or (0,)
110+
111+
103112
class LearnerND(BaseLearner):
104113
"""Learns and predicts a function 'f: ℝ^N → ℝ^M'.
105114
@@ -200,7 +209,7 @@ def __init__(self, func, bounds, loss_per_simplex=None):
200209
# so when popping an item, you should check that the simplex that has
201210
# been returned has not been deleted. This checking is done by
202211
# _pop_highest_existing_simplex
203-
self._simplex_queue = [] # heap
212+
self._simplex_queue = SortedKeyList(key=_simplex_evaluation_priority)
204213

205214
@property
206215
def npoints(self):
@@ -344,9 +353,7 @@ def _update_subsimplex_losses(self, simplex, new_subsimplices):
344353
subtriangulation = self._subtriangulations[simplex]
345354
for subsimplex in new_subsimplices:
346355
subloss = subtriangulation.volume(subsimplex) * loss_density
347-
subloss = round(subloss, ndigits=8)
348-
heapq.heappush(self._simplex_queue,
349-
(-subloss, simplex, subsimplex))
356+
self._simplex_queue.add((subloss, simplex, subsimplex))
350357

351358
def _ask_and_tell_pending(self, n=1):
352359
xs, losses = zip(*(self._ask() for _ in range(n)))
@@ -386,7 +393,7 @@ def _pop_highest_existing_simplex(self):
386393
# find the simplex with the highest loss, we do need to check that the
387394
# simplex hasn't been deleted yet
388395
while len(self._simplex_queue):
389-
loss, simplex, subsimplex = heapq.heappop(self._simplex_queue)
396+
loss, simplex, subsimplex = self._simplex_queue.pop(0)
390397
if (subsimplex is None
391398
and simplex in self.tri.simplices
392399
and simplex not in self._subtriangulations):
@@ -462,8 +469,7 @@ def update_losses(self, to_delete: set, to_add: set):
462469
self._try_adding_pending_point_to_simplex(p, simplex)
463470

464471
if simplex not in self._subtriangulations:
465-
loss = round(loss, ndigits=8)
466-
heapq.heappush(self._simplex_queue, (-loss, simplex, None))
472+
self._simplex_queue.add((loss, simplex, None))
467473
continue
468474

469475
self._update_subsimplex_losses(
@@ -488,7 +494,7 @@ def recompute_all_losses(self):
488494
return
489495

490496
# reset the _simplex_queue
491-
self._simplex_queue = []
497+
self._simplex_queue = SortedKeyList(key=_simplex_evaluation_priority)
492498

493499
# recompute all losses
494500
for simplex in self.tri.simplices:
@@ -497,8 +503,7 @@ def recompute_all_losses(self):
497503

498504
# now distribute it around the children if they are present
499505
if simplex not in self._subtriangulations:
500-
loss = round(loss, ndigits=8)
501-
heapq.heappush(self._simplex_queue, (-loss, simplex, None))
506+
self._simplex_queue.add((loss, simplex, None))
502507
continue
503508

504509
self._update_subsimplex_losses(

adaptive/tests/test_learners.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -362,9 +362,7 @@ def test_expected_loss_improvement_is_less_than_total_loss(learner_type, f, lear
362362

363363
# XXX: This *should* pass (https://gitlab.kwant-project.org/qt/adaptive/issues/84)
364364
# but we xfail it now, as Learner2D will be deprecated anyway
365-
# The LearnerND fails sometimes, see
366-
# https://gitlab.kwant-project.org/qt/adaptive/merge_requests/128#note_21807
367-
@run_with(Learner1D, xfail(Learner2D), xfail(LearnerND))
365+
@run_with(Learner1D, xfail(Learner2D), LearnerND)
368366
def test_learner_performance_is_invariant_under_scaling(learner_type, f, learner_kwargs):
369367
"""Learners behave identically under transformations that leave
370368
the loss invariant.

0 commit comments

Comments
 (0)