Skip to content

Commit 0ebd036

Browse files
committed
WIP: allow the loss to be a tuple in the Learner1D
1 parent 0e1aa7c commit 0ebd036

File tree

1 file changed

+26
-3
lines changed

1 file changed

+26
-3
lines changed

adaptive/learner/learner1D.py

+26-3
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,22 @@ def _wrapped(loss_per_interval):
6969
return _wrapped
7070

7171

72+
def loss_returns(return_type, return_length):
    """Decorator factory recording a loss function's return signature.

    Attaches ``return_type`` and ``return_length`` attributes to the
    decorated ``loss_per_interval`` callable so the learner can later
    build a matching "infinite loss" placeholder.

    Parameters
    ----------
    return_type : type
        Type of the value the loss function returns (e.g. ``float`` or
        ``tuple``).
    return_length : int
        Number of components in the returned value (relevant when the
        loss is a container such as a tuple).
    """
    def _wrapped(loss_per_interval):
        # Annotate the callable in place and hand it back unchanged.
        setattr(loss_per_interval, 'return_type', return_type)
        setattr(loss_per_interval, 'return_length', return_length)
        return loss_per_interval
    return _wrapped
78+
79+
80+
def inf_format(return_type, return_len=None):
    """Build an "infinite loss" value shaped like the loss's return value.

    Parameters
    ----------
    return_type : type
        Type returned by the loss function (e.g. ``float`` or ``tuple``).
    return_len : int, optional
        Number of components when ``return_type`` is a container.

    Returns
    -------
    Either ``return_type(np.inf)`` for scalar types, or a container of
    ``return_len`` copies of ``np.inf`` for iterable types.
    """
    # Container classes (tuple, list, ...) expose ``__iter__`` on the
    # class itself; scalar types such as ``float`` do not.
    if hasattr(return_type, '__iter__'):
        return return_type([np.inf] * return_len)
    return return_type(np.inf)
86+
87+
7288
@uses_nth_neighbors(0)
7389
def uniform_loss(xs, ys):
7490
"""Loss function that samples the domain uniformly.
@@ -287,7 +303,8 @@ def npoints(self):
287303
def loss(self, real=True):
    """Return the largest loss over all intervals.

    Parameters
    ----------
    real : bool
        If True, consider only losses of intervals between real
        (evaluated) points; otherwise use ``self.losses_combined``,
        which also includes pending points.

    Returns
    -------
    The maximum loss, or an "infinite" placeholder shaped like the
    loss function's return value when no intervals exist yet.
    """
    losses = self.losses if real else self.losses_combined
    if not losses:
        # No intervals yet: emit inf in the same format (scalar or
        # container) that the loss function produces.
        return inf_format(self.loss_per_interval.return_type,
                          self.loss_per_interval.return_length)
    _, max_loss = losses.peekitem(0)
    return max_loss
293310

@@ -660,8 +677,14 @@ def _set_data(self, data):
660677

661678
def loss_manager(x_scale):
    """Create a dict of intervals kept sorted by descending finite loss.

    Supports scalar losses as well as iterable (e.g. tuple) losses: in
    the iterable case every component is made finite independently and
    the negated tuple sorts lexicographically.

    Parameters
    ----------
    x_scale : float
        Scale of the x-domain, used to turn infinite losses into
        finite sortable values.

    Returns
    -------
    sortedcollections.ItemSortedDict
        Mapping of interval -> loss, ordered by the key below.
    """
    def sort_key(ival, loss):
        if isinstance(loss, Iterable):
            # Make each component finite; negate so larger losses
            # sort first.
            finite = [finite_loss(ival, component, x_scale)
                      for component in loss]
            key = tuple(-fl for fl, _ in finite)
            # Every component yields the same interval; use the first.
            return key, finite[0][1]
        finite_l, finite_ival = finite_loss(ival, loss, x_scale)
        return -finite_l, finite_ival

    return sortedcollections.ItemSortedDict(sort_key)
667690

0 commit comments

Comments
 (0)