@@ -69,6 +69,22 @@ def _wrapped(loss_per_interval):
69
69
return _wrapped
70
70
71
71
72
+ def loss_returns (return_type , return_length ):
73
+ def _wrapped (loss_per_interval ):
74
+ loss_per_interval .return_type = return_type
75
+ loss_per_interval .return_length = return_length
76
+ return loss_per_interval
77
+ return _wrapped
78
+
79
+
80
+ def inf_format (return_type , return_len = None ):
81
+ is_iterable = hasattr (return_type , '__iter__' )
82
+ if is_iterable :
83
+ return return_type (return_len * [np .inf ])
84
+ else :
85
+ return return_type (np .inf )
86
+
87
+
72
88
@uses_nth_neighbors (0 )
73
89
def uniform_loss (xs , ys ):
74
90
"""Loss function that samples the domain uniformly.
@@ -287,7 +303,8 @@ def npoints(self):
287
303
def loss (self , real = True ):
288
304
losses = self .losses if real else self .losses_combined
289
305
if not losses :
290
- return np .inf
306
+ return inf_format (self .loss_per_interval .return_type ,
307
+ self .loss_per_interval .return_length )
291
308
max_interval , max_loss = losses .peekitem (0 )
292
309
return max_loss
293
310
@@ -660,8 +677,14 @@ def _set_data(self, data):
660
677
661
678
def loss_manager (x_scale ):
662
679
def sort_key (ival , loss ):
663
- loss , ival = finite_loss (ival , loss , x_scale )
664
- return - loss , ival
680
+ if isinstance (loss , Iterable ):
681
+ loss , ival = zip (* [finite_loss (ival , l , x_scale ) for l in loss ])
682
+ loss = tuple (- x for x in loss )
683
+ ival = ival [0 ]
684
+ else :
685
+ loss , ival = finite_loss (ival , loss , x_scale )
686
+ loss = - loss
687
+ return loss , ival
665
688
sorted_dict = sortedcollections .ItemSortedDict (sort_key )
666
689
return sorted_dict
667
690
0 commit comments