Skip to content

Commit 697193b

Browse files
diego-urgellfacebook-github-bot
authored andcommitted
Fix ThroughputLogger error on intra-epoch restore (#829)
Summary: Pull Request resolved: #829 Reviewed By: galrotem, JKSenthil Differential Revision: D57183157 fbshipit-source-id: 839dfb149a89d4c92c87f725a613b8e7104b499c
1 parent 735bfbc commit 697193b

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

tests/framework/callbacks/test_throughput_logger.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,10 @@ def test_log_for_epoch(self) -> None:
272272
throughput_logger.on_train_step_end(state, unit)
273273
self.assertEqual(throughput_logger._steps_in_epoch[ActivePhase.TRAIN], 1)
274274

275+
# Make sure we don't log or fail if the _epoch_start_times dict is not initialized
276+
throughput_logger._log_for_epoch(state, epoch_logging_for=15)
277+
logger.log.assert_not_called()
278+
275279
with patch("time.perf_counter", return_value=0.5):
276280
throughput_logger.on_train_epoch_start(state, MagicMock(spec=TrainUnit))
277281
self.assertEqual(throughput_logger._epoch_start_times[ActivePhase.TRAIN], 0.5)

torchtnt/framework/callbacks/throughput_logger.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# pyre-strict
88

99

10+
import logging
1011
import time
1112
from collections import defaultdict
1213
from typing import Dict, Mapping
@@ -30,6 +31,8 @@
3031
ActivePhase.PREDICT: "Predict",
3132
}
3233

34+
logger: logging.Logger = logging.getLogger(__name__)
35+
3336

3437
class ThroughputLogger(Callback):
3538
"""
@@ -192,6 +195,18 @@ def _log_for_epoch(
192195
*,
193196
epoch_logging_for: int,
194197
) -> None:
198+
199+
# Avoid key errors if active phase is not set. This may happen if we restore
200+
# from an intra-epoch checkpoint.
201+
if (
202+
state.active_phase not in self._epoch_start_times
203+
or state.active_phase not in self._steps_in_epoch
204+
):
205+
logger.warning(
206+
f"Missing troughput data for epoch {epoch_logging_for}, phase {state.active_phase}. Ommiting troughput logging."
207+
)
208+
return
209+
195210
time_since_epoch_start = (
196211
time.perf_counter() - self._epoch_start_times[state.active_phase]
197212
)

0 commit comments

Comments
 (0)