@@ -207,16 +207,6 @@ def init(
207
207
assert nodes is None or len (nodes ) == 1 , (
208
208
'CheckpointManager only supports single node.' )
209
209
210
- self ._path_prefix = path_prefix
211
- self ._path_type = path_type
212
- if self ._metadata_handler :
213
- self ._metadata_handler .init (
214
- db_prefix = self ._db_prefix ,
215
- db_type = self ._db_type ,
216
- node_names = [str (self ._node_name )],
217
- path_prefix = self ._path_prefix ,
218
- path_type = self ._path_type )
219
-
220
210
with Task (outputs = [self ._blob_names ]) as task :
221
211
if retrieve_from_epoch is None :
222
212
ops .GetAllBlobNames (
@@ -314,15 +304,66 @@ def save(self, epoch):
314
304
return task
315
305
316
306
def write_checkpoint_metadata(self, epoch):
    """
    Record checkpoint metadata for the given epoch.

    The call is forwarded to the metadata handler's ``write``; when no
    metadata handler is configured this method does nothing.

    Args:
        epoch: An integer. The epoch-id for which checkpoint metadata is
            written.
    """
    handler = self._metadata_handler
    if handler is None:
        # Nothing to record without a metadata handler.
        return
    handler.write(epoch=epoch)
319
316
320
317
def get_resume_from_epoch_id(self, user_epoch=None):
    """
    Identify the epoch-id from which the Job must resume.

    Args:
        user_epoch: An integer. Optional parameter for the user to
            explicitly identify the epoch-id to load the checkpoint from.

    Returns:
        The epoch-id to load checkpoints from. When a metadata handler is
        configured, this is whatever its ``last_epoch`` reports for
        ``user_epoch``; otherwise ``user_epoch`` is returned unchanged.
        May be None (presumably when no checkpoints were written; that
        decision is made by the handler).
    """
    handler = self._metadata_handler
    if handler is None:
        # Without metadata the caller's choice is the only information.
        return user_epoch
    return handler.last_epoch(user_epoch=user_epoch)
325
332
333
def set_params(self, nodes, path_prefix=None, path_type=None):
    """Set parameters associated with the checkpoint manager.

    Stores the path settings on this manager and, when a metadata handler
    is present, forwards them together with the db settings and this
    manager's node name to the handler's ``set_params``.

    Args:
        nodes: An array of nodes where this checkpoint manager is running.
            Not read by this method.
        path_prefix: Used to construct db name or path where checkpoint
            files are stored.
        path_type: Indicate the type of path where checkpoint files are
            stored.
    """
    self._path_prefix = path_prefix
    self._path_type = path_type

    handler = self._metadata_handler
    if not handler:
        return

    # A single-node manager reports exactly one node name.
    handler_args = {
        'db_prefix': self._db_prefix,
        'db_type': self._db_type,
        'node_names': [str(self._node_name)],
        'path_prefix': self._path_prefix,
        'path_type': self._path_type,
    }
    handler.set_params(**handler_args)
351
+
352
def cp_accessible(self, epoch=None):
    """Report whether checkpoint data is accessible.

    Args:
        epoch: An integer. The epoch of the checkpoint. If None, it
            implies we need to check if the checkpoint directory is
            accessible (the check itself is delegated to the metadata
            handler).

    Returns:
        A boolean-like value: the metadata handler's ``cp_accessible``
        result when a handler is configured, otherwise True.
    """
    handler = self._metadata_handler
    if handler is None:
        # No handler means there is nothing to consult; report accessible.
        return True
    return handler.cp_accessible(epoch)
366
+
326
367
327
368
class MultiNodeCheckpointManager (object ):
328
369
"""
@@ -366,16 +407,6 @@ def init(
366
407
assert [node for node , _ in self ._node_managers ] == nodes
367
408
return TaskGroup (WorkspaceType .GLOBAL )
368
409
self ._node_managers = []
369
- self ._path_prefix = path_prefix
370
- self ._path_type = path_type
371
- self ._node_names = [str (node ) for node in nodes ]
372
- if self ._metadata_handler :
373
- self ._metadata_handler .init (
374
- db_prefix = self ._db_prefix ,
375
- db_type = self ._db_type ,
376
- node_names = self ._node_names ,
377
- path_prefix = self ._path_prefix ,
378
- path_type = self ._path_type )
379
410
for node in nodes :
380
411
with Node (node ):
381
412
manager = CheckpointManager (
@@ -450,18 +481,74 @@ def get_ckpt_db_name(self, node_name, epoch):
450
481
return db_name (epoch , manager ._node_name , manager ._db_prefix )
451
482
452
483
def save(self, epoch):
    """
    Build a Task that will execute Save ops to serialize and persist
    blobs present in the global workspace.

    The work is delegated to ``self._task_group`` with
    ``CheckpointManager.save`` as the function to apply and ``epoch`` as
    its argument.
    """
    build_group = self._task_group
    return build_group(CheckpointManager.save, epoch)
454
489
455
490
def write_checkpoint_metadata(self, epoch):
    """
    Record checkpoint metadata for the given epoch.

    The call is forwarded to the metadata handler's ``write``; when no
    metadata handler is configured this method does nothing.

    Args:
        epoch: An integer. The epoch-id for which checkpoint metadata is
            written.
    """
    handler = self._metadata_handler
    if handler is None:
        # Nothing to record without a metadata handler.
        return
    handler.write(epoch=epoch)
458
500
459
501
def get_resume_from_epoch_id(self, user_epoch=None):
    """
    Identify the epoch-id from which the Job must resume.

    Args:
        user_epoch: An integer. Optional parameter for the user to
            explicitly identify the epoch-id to load the checkpoint from.

    Returns:
        The epoch-id to load checkpoints from. When a metadata handler is
        configured, this is whatever its ``last_epoch`` reports for
        ``user_epoch``; otherwise ``user_epoch`` is returned unchanged.
        May be None (presumably when no checkpoints were written; that
        decision is made by the handler).
    """
    handler = self._metadata_handler
    if handler is None:
        # Without metadata the caller's choice is the only information.
        return user_epoch
    return handler.last_epoch(user_epoch=user_epoch)
464
516
517
def set_params(self, nodes, path_prefix=None, path_type=None):
    """Set parameters associated with the checkpoint manager.

    Stores the path settings and the stringified node names on this
    manager and, when a metadata handler is present, forwards them
    together with the db settings to the handler's ``set_params``.

    Args:
        nodes: An array of nodes where this checkpoint manager is running.
            Each entry is converted with ``str`` to form the node names.
        path_prefix: Used to construct db name or path where checkpoint
            files are stored.
        path_type: Indicate the type of path where checkpoint files are
            stored.
    """
    self._path_prefix = path_prefix
    self._path_type = path_type
    self._node_names = list(map(str, nodes))

    handler = self._metadata_handler
    if not handler:
        return

    handler_args = {
        'db_prefix': self._db_prefix,
        'db_type': self._db_type,
        'node_names': self._node_names,
        'path_prefix': self._path_prefix,
        'path_type': self._path_type,
    }
    handler.set_params(**handler_args)
536
+
537
def cp_accessible(self, epoch=None):
    """Report whether checkpoint data is accessible.

    Args:
        epoch: An integer. The epoch of the checkpoint. If None, it
            implies we need to check if the checkpoint directory is
            accessible (the check itself is delegated to the metadata
            handler).

    Returns:
        A boolean-like value: the metadata handler's ``cp_accessible``
        result when a handler is configured, otherwise True.
    """
    handler = self._metadata_handler
    if handler is None:
        # No handler means there is nothing to consult; report accessible.
        return True
    return handler.cp_accessible(epoch)
551
+
465
552
466
553
class UploadTaskGroupBuilder (object ):
467
554
"""A simple class to upload checkpoints."""
@@ -525,6 +612,7 @@ def __call__(self, session):
525
612
"""
526
613
# identify the epoch we must resume from
527
614
if self .checkpoint_manager :
615
+ self .checkpoint_manager .set_params (nodes = self .job .nodes_to_checkpoint ())
528
616
self .resume_from_epoch = self .checkpoint_manager .\
529
617
get_resume_from_epoch_id (self .resume_from_epoch )
530
618
if self .resume_from_epoch is not None :
@@ -627,10 +715,14 @@ def save_checkpoints(self, epoch, session):
627
715
if not self .checkpoint_manager :
628
716
raise ValueError ('Checkpoint manager is None' )
629
717
try :
630
- logger .info ('Saving checkpoints for epoch {}' .format (epoch ))
631
- session .run (self .checkpoint_manager .save (epoch ))
632
- self .checkpoint_manager .write_checkpoint_metadata (epoch )
633
- logger .info ('Checkpoints saved' )
718
+ is_accessible = self .checkpoint_manager .cp_accessible (epoch = None )
719
+ if is_accessible :
720
+ logger .info ('Saving checkpoints for epoch {}' .format (epoch ))
721
+ session .run (self .checkpoint_manager .save (epoch ))
722
+ self .checkpoint_manager .write_checkpoint_metadata (epoch )
723
+ logger .info ('Checkpoints saved' )
724
+ else :
725
+ logger .warning ("Checkpoint files cannot be accessed!" )
634
726
except Exception as ex :
635
727
logger .warning ("Unable to write checkpoint for epoch {}. Error={}" .
636
728
format (epoch , ex ))
0 commit comments