@@ -676,12 +676,16 @@ def elapsed_time(self):
676
676
return end_time - self .start_time
677
677
678
678
def cancel_point (
679
- self , point : Any | None = None , future : asyncio . Future | None = None
679
+ self , future : asyncio . Future | None = None , point : Any | None = None
680
680
):
681
- """Cancel a point that is currently being evaluated.
681
+ """Cancel a future or point that is currently being evaluated.
682
+
683
+ Either the ``future`` or the ``point`` must be provided.
682
684
683
685
Parameters
684
686
----------
687
+ future : asyncio.Future
688
+ The future that is currently being evaluated.
685
689
point
686
690
The point that should be cancelled.
687
691
"""
@@ -691,9 +695,9 @@ def cancel_point(
691
695
future = next (fut for fut , p in self .pending_points if p == point )
692
696
future .cancel ()
693
697
694
- def add_periodic_callback (
698
+ def start_periodic_callback (
695
699
self ,
696
- method : Callable [[AsyncRunner ]],
700
+ method : Callable [[AsyncRunner ], None ],
697
701
interval : int = 30 ,
698
702
):
699
703
"""Start a periodic callback that calls the given method on the runner.
@@ -753,9 +757,11 @@ def default_save(learner):
753
757
if method is None :
754
758
method = default_save
755
759
if save_kwargs is None :
756
- raise ValueError ("Must provide `save_kwargs` if method=None." )
760
+ raise ValueError ("Must provide `save_kwargs` if ` method=None` ." )
757
761
758
- self .saving_task = self .add_periodic_callback (method , interval = interval )
762
+ self .saving_task = self .start_periodic_callback (
763
+ lambda r : method (r .learner ), interval = interval
764
+ )
759
765
return self .saving_task
760
766
761
767
0 commit comments