@@ -207,16 +207,6 @@ def init(
207207 assert nodes is None or len (nodes ) == 1 , (
208208 'CheckpointManager only supports single node.' )
209209
210- self ._path_prefix = path_prefix
211- self ._path_type = path_type
212- if self ._metadata_handler :
213- self ._metadata_handler .init (
214- db_prefix = self ._db_prefix ,
215- db_type = self ._db_type ,
216- node_names = [str (self ._node_name )],
217- path_prefix = self ._path_prefix ,
218- path_type = self ._path_type )
219-
220210 with Task (outputs = [self ._blob_names ]) as task :
221211 if retrieve_from_epoch is None :
222212 ops .GetAllBlobNames (
@@ -314,15 +304,66 @@ def save(self, epoch):
314304 return task
315305
316306 def write_checkpoint_metadata (self , epoch ):
307+ """
308+ Write metadata for checkpoint
309+
310+ Args:
311+ epoch: An integer. The epoch-id for which checkpoint metadata is
312+ written
313+ """
317314 if self ._metadata_handler is not None :
318315 self ._metadata_handler .write (epoch = epoch )
319316
320317 def get_resume_from_epoch_id (self , user_epoch = None ):
318+ """
319+ Identify the epoch-id from which Job must resume
320+
321+ Args:
322+ user_epoch: An integer. Optional parameter for user to explicitly
323+ identify the epoch-id to load checkpoint from
324+ Retruns:
325+ epoch: the epoch-id to load checkpoints from
326+ or None if no checkpoints were written
327+ """
321328 last_epoch = user_epoch
322329 if self ._metadata_handler is not None :
323330 last_epoch = self ._metadata_handler .last_epoch (user_epoch = user_epoch )
324331 return last_epoch
325332
333+ def set_params (self , nodes , path_prefix = None , path_type = None ):
334+ """Set parameters associated with CP manager
335+
336+ Args:
337+ nodes: An array of nodes where this checkpoint manager is running.
338+ path_prefix: Used to construct db name or path where checkpoint files are
339+ stored.
340+ path_type: Indicate the type of path where checkpoint files are stored.
341+ """
342+ self ._path_prefix = path_prefix
343+ self ._path_type = path_type
344+ if self ._metadata_handler :
345+ self ._metadata_handler .set_params (
346+ db_prefix = self ._db_prefix ,
347+ db_type = self ._db_type ,
348+ node_names = [str (self ._node_name )],
349+ path_prefix = self ._path_prefix ,
350+ path_type = self ._path_type )
351+
352+ def cp_accessible (self , epoch = None ):
353+ """Returns True if Checkpoint data is accessible
354+
355+ Args:
356+ epoch: An integer. The epoch of the checkpoint. If None,
357+ it implies we need to check if checkpoint directory is accessible
358+
359+ Returns:
360+ is_cp_accessible: A boolean. Returns True if Checkpoint data is accessible
361+ """
362+ if self ._metadata_handler is not None :
363+ return self ._metadata_handler .cp_accessible (epoch )
364+ else :
365+ return True
366+
326367
327368class MultiNodeCheckpointManager (object ):
328369 """
@@ -366,16 +407,6 @@ def init(
366407 assert [node for node , _ in self ._node_managers ] == nodes
367408 return TaskGroup (WorkspaceType .GLOBAL )
368409 self ._node_managers = []
369- self ._path_prefix = path_prefix
370- self ._path_type = path_type
371- self ._node_names = [str (node ) for node in nodes ]
372- if self ._metadata_handler :
373- self ._metadata_handler .init (
374- db_prefix = self ._db_prefix ,
375- db_type = self ._db_type ,
376- node_names = self ._node_names ,
377- path_prefix = self ._path_prefix ,
378- path_type = self ._path_type )
379410 for node in nodes :
380411 with Node (node ):
381412 manager = CheckpointManager (
@@ -450,18 +481,74 @@ def get_ckpt_db_name(self, node_name, epoch):
450481 return db_name (epoch , manager ._node_name , manager ._db_prefix )
451482
452483 def save (self , epoch ):
484+ """
485+ Build a Task that will execute a Save ops to serialize and persist
486+ blobs present in the global workspace.
487+ """
453488 return self ._task_group (CheckpointManager .save , epoch )
454489
455490 def write_checkpoint_metadata (self , epoch ):
491+ """
492+ Write metadata for checkpoint
493+
494+ Args:
495+ epoch: An integer. The epoch-id for which checkpoint metadata is
496+ written
497+ """
456498 if self ._metadata_handler is not None :
457499 self ._metadata_handler .write (epoch = epoch )
458500
459501 def get_resume_from_epoch_id (self , user_epoch = None ):
502+ """
503+ Identify the epoch-id from which Job must resume
504+
505+ Args:
506+ user_epoch: An integer. Optional parameter for user to explicitly
507+ identify the epoch-id to load checkpoint from
508+ Retruns:
509+ epoch: the epoch-id to load checkpoints from
510+ or None if no checkpoints were written
511+ """
460512 last_epoch = user_epoch
461513 if self ._metadata_handler is not None :
462514 last_epoch = self ._metadata_handler .last_epoch (user_epoch = user_epoch )
463515 return last_epoch
464516
517+ def set_params (self , nodes , path_prefix = None , path_type = None ):
518+ """Set parameters associated with CP manager
519+
520+ Args:
521+ nodes: An array of nodes where this checkpoint manager is running.
522+ path_prefix: Used to construct db name or path where checkpoint files are
523+ stored.
524+ path_type: Indicate the type of path where checkpoint files are stored.
525+ """
526+ self ._path_prefix = path_prefix
527+ self ._path_type = path_type
528+ self ._node_names = [str (node ) for node in nodes ]
529+ if self ._metadata_handler :
530+ self ._metadata_handler .set_params (
531+ db_prefix = self ._db_prefix ,
532+ db_type = self ._db_type ,
533+ node_names = self ._node_names ,
534+ path_prefix = self ._path_prefix ,
535+ path_type = self ._path_type )
536+
537+ def cp_accessible (self , epoch = None ):
538+ """Returns True if Checkpoint data is accessible
539+
540+ Args:
541+ epoch: An integer. The epoch of the checkpoint. If None,
542+ it implies we need to check if checkpoint directory is accessible
543+
544+ Returns:
545+ is_cp_accessible: A boolean. Returns True if Checkpoint data is accessible
546+ """
547+ if self ._metadata_handler is not None :
548+ return self ._metadata_handler .cp_accessible (epoch )
549+ else :
550+ return True
551+
465552
466553class UploadTaskGroupBuilder (object ):
467554 """A simple class to upload checkpoints."""
@@ -525,6 +612,7 @@ def __call__(self, session):
525612 """
526613 # identify the epoch we must resume from
527614 if self .checkpoint_manager :
615+ self .checkpoint_manager .set_params (nodes = self .job .nodes_to_checkpoint ())
528616 self .resume_from_epoch = self .checkpoint_manager .\
529617 get_resume_from_epoch_id (self .resume_from_epoch )
530618 if self .resume_from_epoch is not None :
@@ -627,10 +715,14 @@ def save_checkpoints(self, epoch, session):
627715 if not self .checkpoint_manager :
628716 raise ValueError ('Checkpoint manager is None' )
629717 try :
630- logger .info ('Saving checkpoints for epoch {}' .format (epoch ))
631- session .run (self .checkpoint_manager .save (epoch ))
632- self .checkpoint_manager .write_checkpoint_metadata (epoch )
633- logger .info ('Checkpoints saved' )
718+ is_accessible = self .checkpoint_manager .cp_accessible (epoch = None )
719+ if is_accessible :
720+ logger .info ('Saving checkpoints for epoch {}' .format (epoch ))
721+ session .run (self .checkpoint_manager .save (epoch ))
722+ self .checkpoint_manager .write_checkpoint_metadata (epoch )
723+ logger .info ('Checkpoints saved' )
724+ else :
725+ logger .warning ("Checkpoint files cannot be accessed!" )
634726 except Exception as ex :
635727 logger .warning ("Unable to write checkpoint for epoch {}. Error={}" .
636728 format (epoch , ex ))
0 commit comments