WIP: allow the loss to be a tuple in the Learner1D #154

Open · wants to merge 1 commit into base: main
adaptive/learner/learner1D.py: 33 additions & 6 deletions
@@ -69,6 +69,28 @@ def _wrapped(loss_per_interval):
     return _wrapped
 
 
+def loss_returns(return_type, return_length):
+    def _wrapped(loss_per_interval):
+        loss_per_interval.return_type = return_type
+        loss_per_interval.return_length = return_length
+        return loss_per_interval
+    return _wrapped
+
+
+def inf_format(return_type, return_len=None):
+    is_iterable = hasattr(return_type, '__iter__')
+    if is_iterable:
+        return return_type(return_len * [np.inf])
+    else:
+        return return_type(np.inf)
+
+
+def ensure_tuple(x):
+    if not isinstance(x, Iterable):
+        x = (x,)
+    return x
+
+
 @uses_nth_neighbors(0)
 def uniform_loss(xs, ys):
     """Loss function that samples the domain uniformly.
@@ -287,7 +309,8 @@ def npoints(self):
     def loss(self, real=True):
         losses = self.losses if real else self.losses_combined
         if not losses:
-            return np.inf
+            return inf_format(self.loss_per_interval.return_type,
+                              self.loss_per_interval.return_length)
         max_interval, max_loss = losses.peekitem(0)
         return max_loss
 
@@ -325,7 +348,7 @@ def _get_loss_in_interval(self, x_left, x_right):
         ys_scaled = tuple(self._scale_y(y) for y in ys)
 
         # we need to compute the loss for this interval
-        return self.loss_per_interval(xs_scaled, ys_scaled)
+        return ensure_tuple(self.loss_per_interval(xs_scaled, ys_scaled))
 
     def _update_interpolated_loss_in_interval(self, x_left, x_right):
         if x_left is None or x_right is None:
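Wrapping the result in ensure_tuple means an existing scalar loss needs no changes: its float result is stored internally as a 1-tuple, so tuple-valued and scalar losses share one code path (presumably the built-in losses would also need loss_returns(float, 1) applied so the inf_format calls have the attributes they expect, though that is not shown in this diff). A tiny sketch of that assumption, with a hypothetical scalar loss:

def default_like_loss(xs, ys):      # hypothetical scalar loss
    return abs(ys[1] - ys[0])

stored = ensure_tuple(default_like_loss((0.0, 1.0), (0.0, 2.0)))
assert stored == (2.0,)             # scalar loss becomes a 1-tuple internally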
@@ -379,13 +402,17 @@ def _update_losses(self, x, real=True):
         left_loss_is_unknown = ((x_left is None) or
                                 (not real and x_right is None))
         if (a is not None) and left_loss_is_unknown:
-            self.losses_combined[a, x] = float('inf')
+            self.losses_combined[a, x] = inf_format(
+                self.loss_per_interval.return_type,
+                self.loss_per_interval.return_length)
 
         # (no real point right of x) or (no real point left of b)
         right_loss_is_unknown = ((x_right is None) or
                                  (not real and x_left is None))
         if (b is not None) and right_loss_is_unknown:
-            self.losses_combined[x, b] = float('inf')
+            self.losses_combined[x, b] = inf_format(
+                self.loss_per_interval.return_type,
+                self.loss_per_interval.return_length)
 
     @staticmethod
     def _find_neighbors(x, neighbors):
@@ -660,8 +687,8 @@ def _set_data(self, data):
 
 def loss_manager(x_scale):
     def sort_key(ival, loss):
-        loss, ival = finite_loss(ival, loss, x_scale)
-        return -loss, ival
+        loss = [-finite_loss(ival, l, x_scale)[0] for l in loss]
+        return loss, ival
     sorted_dict = sortedcollections.ItemSortedDict(sort_key)
     return sorted_dict
 
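With the loss stored as a tuple, the new sort key negates every finite component, so the ItemSortedDict orders intervals by their largest first loss component and uses the remaining components as tie-breakers. A self-contained sketch of that ordering; _finite is a simplified stand-in for adaptive's finite_loss (which replaces infinite losses with a finite value based on the interval size), not the real implementation:

# Sketch of the new interval ordering, using a simplified stand-in for finite_loss.
import sortedcollections

def _finite(ival, loss, x_scale):
    return (min(loss, 1e9), ival)        # clamp inf to a large finite number

def loss_manager(x_scale):
    def sort_key(ival, loss):
        loss = [-_finite(ival, l, x_scale)[0] for l in loss]
        return loss, ival
    return sortedcollections.ItemSortedDict(sort_key)

losses = loss_manager(1.0)
losses[(0.0, 0.5)] = (0.3, 0.1)
losses[(0.5, 1.0)] = (0.7, 0.2)
losses[(1.0, 2.0)] = (0.7, 0.9)
# peekitem(0) yields the interval with the largest first loss component;
# the second component breaks the tie between the last two intervals.
assert losses.peekitem(0)[0] == (1.0, 2.0)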