@@ -649,6 +649,7 @@ def __init__(
649
649
650
650
self .task = self .ioloop .create_task (self ._run ())
651
651
self .saving_task = None
652
+ self .callbacks = []
652
653
if in_ipynb () and not self .ioloop .is_running ():
653
654
warnings .warn (
654
655
"The runner has been scheduled, but the asyncio "
@@ -753,6 +754,31 @@ def elapsed_time(self):
753
754
end_time = time .time ()
754
755
return end_time - self .start_time
755
756
757
+ def add_periodic_callback (
758
+ self ,
759
+ method : Callable [[AsyncRunner ]],
760
+ interval : int = 30 ,
761
+ ):
762
+ """Start a periodic callback that calls the given method on the runner.
763
+
764
+ Parameters
765
+ ----------
766
+ method : callable
767
+ The method to call periodically.
768
+ interval : int
769
+ The interval in seconds between the calls.
770
+ """
771
+
772
+ async def _callback ():
773
+ while self .status () == "running" :
774
+ method (self )
775
+ await asyncio .sleep (interval )
776
+ method (self ) # one last time
777
+
778
+ task = self .ioloop .create_task (_callback ())
779
+ self .callbacks .append (task )
780
+ return task
781
+
756
782
def start_periodic_saving (
757
783
self ,
758
784
save_kwargs : dict [str , Any ] | None = None ,
@@ -781,6 +807,8 @@ def start_periodic_saving(
781
807
... save_kwargs=dict(fname='data/test.pickle'),
782
808
... interval=600)
783
809
"""
810
+ if self .saving_task is not None :
811
+ raise RuntimeError ("Already saving." )
784
812
785
813
def default_save (learner ):
786
814
learner .save (** save_kwargs )
@@ -790,13 +818,7 @@ def default_save(learner):
790
818
if save_kwargs is None :
791
819
raise ValueError ("Must provide `save_kwargs` if method=None." )
792
820
793
- async def _saver ():
794
- while self .status () == "running" :
795
- method (self .learner )
796
- await asyncio .sleep (interval )
797
- method (self .learner ) # one last time
798
-
799
- self .saving_task = self .ioloop .create_task (_saver ())
821
+ self .saving_task = self .add_periodic_callback (method , interval = interval )
800
822
return self .saving_task
801
823
802
824
0 commit comments