Skip to content

Commit 28f42cc

Browse files
Aarti Basantfacebook-github-bot
Aarti Basant
authored and committed
separating set_params and init() for checkpoint managers.
Summary: separating set_params and init() for checkpoint managers. Reviewed By: anshulverma Differential Revision: D6852255 fbshipit-source-id: 061f16ce0c49953ca8a5fe9546af5c9945a3be48
1 parent 1d044dc commit 28f42cc

File tree

1 file changed

+116
-24
lines changed

1 file changed

+116
-24
lines changed

caffe2/python/checkpoint.py

+116-24
Original file line numberDiff line numberDiff line change
@@ -207,16 +207,6 @@ def init(
207207
assert nodes is None or len(nodes) == 1, (
208208
'CheckpointManager only supports single node.')
209209

210-
self._path_prefix = path_prefix
211-
self._path_type = path_type
212-
if self._metadata_handler:
213-
self._metadata_handler.init(
214-
db_prefix=self._db_prefix,
215-
db_type=self._db_type,
216-
node_names=[str(self._node_name)],
217-
path_prefix=self._path_prefix,
218-
path_type=self._path_type)
219-
220210
with Task(outputs=[self._blob_names]) as task:
221211
if retrieve_from_epoch is None:
222212
ops.GetAllBlobNames(
@@ -314,15 +304,66 @@ def save(self, epoch):
314304
return task
315305

316306
def write_checkpoint_metadata(self, epoch):
307+
"""
308+
Write metadata for checkpoint
309+
310+
Args:
311+
epoch: An integer. The epoch-id for which checkpoint metadata is
312+
written
313+
"""
317314
if self._metadata_handler is not None:
318315
self._metadata_handler.write(epoch=epoch)
319316

320317
def get_resume_from_epoch_id(self, user_epoch=None):
318+
"""
319+
Identify the epoch-id from which Job must resume
320+
321+
Args:
322+
user_epoch: An integer. Optional parameter for user to explicitly
323+
identify the epoch-id to load checkpoint from
324+
Returns:
325+
epoch: the epoch-id to load checkpoints from
326+
or None if no checkpoints were written
327+
"""
321328
last_epoch = user_epoch
322329
if self._metadata_handler is not None:
323330
last_epoch = self._metadata_handler.last_epoch(user_epoch=user_epoch)
324331
return last_epoch
325332

333+
def set_params(self, nodes, path_prefix=None, path_type=None):
334+
"""Set parameters associated with CP manager
335+
336+
Args:
337+
nodes: An array of nodes where this checkpoint manager is running.
338+
path_prefix: Used to construct db name or path where checkpoint files are
339+
stored.
340+
path_type: Indicate the type of path where checkpoint files are stored.
341+
"""
342+
self._path_prefix = path_prefix
343+
self._path_type = path_type
344+
if self._metadata_handler:
345+
self._metadata_handler.set_params(
346+
db_prefix=self._db_prefix,
347+
db_type=self._db_type,
348+
node_names=[str(self._node_name)],
349+
path_prefix=self._path_prefix,
350+
path_type=self._path_type)
351+
352+
def cp_accessible(self, epoch=None):
353+
"""Returns True if Checkpoint data is accessible
354+
355+
Args:
356+
epoch: An integer. The epoch of the checkpoint. If None,
357+
it implies we need to check if checkpoint directory is accessible
358+
359+
Returns:
360+
is_cp_accessible: A boolean. Returns True if Checkpoint data is accessible
361+
"""
362+
if self._metadata_handler is not None:
363+
return self._metadata_handler.cp_accessible(epoch)
364+
else:
365+
return True
366+
326367

327368
class MultiNodeCheckpointManager(object):
328369
"""
@@ -366,16 +407,6 @@ def init(
366407
assert [node for node, _ in self._node_managers] == nodes
367408
return TaskGroup(WorkspaceType.GLOBAL)
368409
self._node_managers = []
369-
self._path_prefix = path_prefix
370-
self._path_type = path_type
371-
self._node_names = [str(node) for node in nodes]
372-
if self._metadata_handler:
373-
self._metadata_handler.init(
374-
db_prefix=self._db_prefix,
375-
db_type=self._db_type,
376-
node_names=self._node_names,
377-
path_prefix=self._path_prefix,
378-
path_type=self._path_type)
379410
for node in nodes:
380411
with Node(node):
381412
manager = CheckpointManager(
@@ -450,18 +481,74 @@ def get_ckpt_db_name(self, node_name, epoch):
450481
return db_name(epoch, manager._node_name, manager._db_prefix)
451482

452483
def save(self, epoch):
484+
"""
485+
Build a Task that will execute Save ops to serialize and persist
486+
blobs present in the global workspace.
487+
"""
453488
return self._task_group(CheckpointManager.save, epoch)
454489

455490
def write_checkpoint_metadata(self, epoch):
491+
"""
492+
Write metadata for checkpoint
493+
494+
Args:
495+
epoch: An integer. The epoch-id for which checkpoint metadata is
496+
written
497+
"""
456498
if self._metadata_handler is not None:
457499
self._metadata_handler.write(epoch=epoch)
458500

459501
def get_resume_from_epoch_id(self, user_epoch=None):
502+
"""
503+
Identify the epoch-id from which Job must resume
504+
505+
Args:
506+
user_epoch: An integer. Optional parameter for user to explicitly
507+
identify the epoch-id to load checkpoint from
508+
Returns:
509+
epoch: the epoch-id to load checkpoints from
510+
or None if no checkpoints were written
511+
"""
460512
last_epoch = user_epoch
461513
if self._metadata_handler is not None:
462514
last_epoch = self._metadata_handler.last_epoch(user_epoch=user_epoch)
463515
return last_epoch
464516

517+
def set_params(self, nodes, path_prefix=None, path_type=None):
518+
"""Set parameters associated with CP manager
519+
520+
Args:
521+
nodes: An array of nodes where this checkpoint manager is running.
522+
path_prefix: Used to construct db name or path where checkpoint files are
523+
stored.
524+
path_type: Indicate the type of path where checkpoint files are stored.
525+
"""
526+
self._path_prefix = path_prefix
527+
self._path_type = path_type
528+
self._node_names = [str(node) for node in nodes]
529+
if self._metadata_handler:
530+
self._metadata_handler.set_params(
531+
db_prefix=self._db_prefix,
532+
db_type=self._db_type,
533+
node_names=self._node_names,
534+
path_prefix=self._path_prefix,
535+
path_type=self._path_type)
536+
537+
def cp_accessible(self, epoch=None):
538+
"""Returns True if Checkpoint data is accessible
539+
540+
Args:
541+
epoch: An integer. The epoch of the checkpoint. If None,
542+
it implies we need to check if checkpoint directory is accessible
543+
544+
Returns:
545+
is_cp_accessible: A boolean. Returns True if Checkpoint data is accessible
546+
"""
547+
if self._metadata_handler is not None:
548+
return self._metadata_handler.cp_accessible(epoch)
549+
else:
550+
return True
551+
465552

466553
class UploadTaskGroupBuilder(object):
467554
"""A simple class to upload checkpoints."""
@@ -525,6 +612,7 @@ def __call__(self, session):
525612
"""
526613
# identify the epoch we must resume from
527614
if self.checkpoint_manager:
615+
self.checkpoint_manager.set_params(nodes=self.job.nodes_to_checkpoint())
528616
self.resume_from_epoch = self.checkpoint_manager.\
529617
get_resume_from_epoch_id(self.resume_from_epoch)
530618
if self.resume_from_epoch is not None:
@@ -627,10 +715,14 @@ def save_checkpoints(self, epoch, session):
627715
if not self.checkpoint_manager:
628716
raise ValueError('Checkpoint manager is None')
629717
try:
630-
logger.info('Saving checkpoints for epoch {}'.format(epoch))
631-
session.run(self.checkpoint_manager.save(epoch))
632-
self.checkpoint_manager.write_checkpoint_metadata(epoch)
633-
logger.info('Checkpoints saved')
718+
is_accessible = self.checkpoint_manager.cp_accessible(epoch=None)
719+
if is_accessible:
720+
logger.info('Saving checkpoints for epoch {}'.format(epoch))
721+
session.run(self.checkpoint_manager.save(epoch))
722+
self.checkpoint_manager.write_checkpoint_metadata(epoch)
723+
logger.info('Checkpoints saved')
724+
else:
725+
logger.warning("Checkpoint files cannot be accessed!")
634726
except Exception as ex:
635727
logger.warning("Unable to write checkpoint for epoch {}. Error={}".
636728
format(epoch, ex))

0 commit comments

Comments
 (0)