Commit 15a4de5

Wang Zhou authored and facebook-github-bot committed
Call finish_checkpoint_resume to flush optimizer step (#2884)
Summary:
X-link: pytorch/torchrec#2244
Pull Request resolved: #2884

The `iter` number is a critical counter used together with `prev_iter` to compute how much decay is needed. However, `iter` is not checkpointed, since it is not part of `split_optimizer_state`. Instead, it relies on external calls to `set_optimizer_step` to flush the counter, which in turn relies on explicit calls to `finish_checkpoint_resume` to trigger `flush_state` in `FullSyncOptimizer`. This chain of actions is not properly set up in the existing code, so after a preemption `iter` gets reset to zero.

Reviewed By: yuxihu

Differential Revision: D60155781

fbshipit-source-id: 51b85fe10fb196df52bfdfc22424e6ab6f8c5bfb
1 parent 0e561d4 commit 15a4de5

File tree

1 file changed: +2, -1 lines changed

1 file changed

+2
-1
lines changed

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

+2 -1
@@ -2125,11 +2125,12 @@ def _set_learning_rate(self, lr: float) -> float:
         self.optimizer_args = self.optimizer_args._replace(learning_rate=lr)
         return 0.0
 
-    @torch.jit.export
+    @torch.jit.ignore
     def set_optimizer_step(self, step: int) -> None:
         """
         Sets the optimizer step.
         """
+        self.log(f"set_optimizer_step from {self.iter[0]} to {step}")
         if self.optimizer == OptimType.NONE:
             raise NotImplementedError(
                 f"Setting optimizer step is not supported for {self.optimizer}"
