
Commit 403cb55

committed Mar 8, 2019
WIP: allow the loss to be a tuple in the Learner1D
1 parent 0e1aa7c commit 403cb55

File tree

1 file changed: +20 −2 lines changed

 

adaptive/learner/learner1D.py  (+20 −2)
@@ -69,6 +69,22 @@ def _wrapped(loss_per_interval):
     return _wrapped
 
 
+def loss_returns(return_type, return_length):
+    def _wrapped(loss_per_interval):
+        loss_per_interval.return_type = return_type
+        loss_per_interval.return_length = return_length
+        return loss_per_interval
+    return _wrapped
+
+
+def inf_format(return_type, return_len=None):
+    is_iterable = hasattr(return_type, '__iter__')
+    if is_iterable:
+        return return_type(return_len * [np.inf])
+    else:
+        return return_type(np.inf)
+
+
 @uses_nth_neighbors(0)
 def uniform_loss(xs, ys):
     """Loss function that samples the domain uniformly.
@@ -287,7 +303,8 @@ def npoints(self):
     def loss(self, real=True):
         losses = self.losses if real else self.losses_combined
         if not losses:
-            return np.inf
+            return inf_format(self.loss_per_interval.return_type,
+                              self.loss_per_interval.return_length)
         max_interval, max_loss = losses.peekitem(0)
         return max_loss
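
This keeps Learner1D.loss() shape-consistent while the learner is still empty: with a loss function decorated as above it returns (inf, inf) rather than a bare np.inf, so the result compares cleanly against real tuple losses. Note that this assumes loss_per_interval carries the return_type and return_length attributes, i.e. that the decorator has been applied; undecorated scalar loss functions would still need a default.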

@@ -661,7 +678,8 @@ def _set_data(self, data):
 def loss_manager(x_scale):
     def sort_key(ival, loss):
         loss, ival = finite_loss(ival, loss, x_scale)
-        return -loss, ival
+        loss = tuple(-l for l in loss) if isinstance(loss, tuple) else -loss
+        return loss, ival
     sorted_dict = sortedcollections.ItemSortedDict(sort_key)
     return sorted_dict
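
The element-wise negation in sort_key is what makes tuple losses fit the existing priority structure: Python orders tuples lexicographically, so negating every component keeps the largest loss first, with ties on the first component broken by the second. A standalone illustration (not from the commit, and assuming finite_loss passes a tuple loss through element-wise):

losses = [(0.5, 1.0), (0.5, 2.0), (1.0, 0.1)]
keys = [tuple(-l for l in loss) for loss in losses]
sorted(keys)
# -> [(-1.0, -0.1), (-0.5, -2.0), (-0.5, -1.0)]
# i.e. the interval with loss (1.0, 0.1) comes first, and the (0.5, ...) tie
# is broken by the larger second component: (0.5, 2.0) before (0.5, 1.0).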
