-
Notifications
You must be signed in to change notification settings - Fork 282
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Pull Request resolved: #790 Adding a time-wait-for-batch logger callback. Reviewed By: JKSenthil Differential Revision: D56315489 fbshipit-source-id: 5fa9210114231c3c7d97d4872252cb8bf659b2d7
- Loading branch information
1 parent
35f9d92
commit 2409e14
Showing
4 changed files
with
221 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
128 changes: 128 additions & 0 deletions
128
tests/framework/callbacks/test_time_wait_for_batch_logger.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import unittest | ||
from unittest.mock import ANY, call, MagicMock | ||
|
||
import torch | ||
from pyre_extensions import none_throws | ||
|
||
from torch.utils.tensorboard import SummaryWriter | ||
from torchtnt.framework._callback_handler import CallbackHandler | ||
from torchtnt.framework._test_utils import ( | ||
DummyAutoUnit, | ||
DummyPredictUnit, | ||
generate_random_dataloader, | ||
) | ||
from torchtnt.framework.callbacks.time_wait_for_batch_logger import ( | ||
TimeWaitForBatchLogger, | ||
) | ||
from torchtnt.framework.predict import predict | ||
|
||
from torchtnt.framework.state import EntryPoint, PhaseState, State | ||
from torchtnt.framework.train import _train_impl | ||
from torchtnt.utils.loggers.logger import MetricLogger | ||
from torchtnt.utils.timer import TimerProtocol | ||
|
||
|
||
class TimeWaitForBatchLoggerTest(unittest.TestCase):
    """Unit tests for the ``TimeWaitForBatchLogger`` callback."""

    def test_log_step_metrics(self) -> None:
        """Off-interval steps log nothing; on-interval steps log the latest wait."""
        # Pair each supported logger type with the method expected to receive
        # the scalar. The method is looked up by name so that only the
        # attribute belonging to the mocked spec is ever accessed.
        logger_specs = (
            (MetricLogger, "log"),
            (SummaryWriter, "add_scalar"),
        )
        for spec, method_name in logger_specs:
            with self.subTest(spec=spec):
                mock_logger = MagicMock(spec=spec)
                log_method = getattr(mock_logger, method_name)

                callback = TimeWaitForBatchLogger(
                    logger=mock_logger, log_every_n_steps=2
                )
                mock_timer = MagicMock(spec=TimerProtocol)
                mock_timer.recorded_durations = {"data_wait_time": [1, 2, 3]}

                # Step 1 is not a multiple of log_every_n_steps=2.
                callback._log_step_metrics(timer=mock_timer, label="foo", step=1)
                log_method.assert_not_called()

                callback._log_step_metrics(timer=mock_timer, label="foo", step=2)
                # Positional args: label, last element of the wait list, step.
                log_method.assert_has_calls([call("foo", 3, 2)])

    def test_comparing_twfb_logging_time(self) -> None:
        """Logged values match the wait times recorded on the state's timers."""
        loader = generate_random_dataloader(num_samples=8, input_dim=2, batch_size=2)
        train_phase = PhaseState(
            dataloader=loader,
            max_epochs=2,
            max_steps_per_epoch=2,
        )
        eval_phase = PhaseState(
            dataloader=loader,
            max_steps_per_epoch=1,
            evaluate_every_n_epochs=1,
        )
        state = State(
            entry_point=EntryPoint.FIT,
            train_state=train_phase,
            eval_state=eval_phase,
        )

        mock_logger = MagicMock(spec=MetricLogger)
        # The state is built by hand and handed to _train_impl so that the
        # logged values can be compared against the timers stored on it
        # afterwards; this mirrors calling fit() and receiving the state back.
        auto_unit = DummyAutoUnit(module=torch.nn.Linear(2, 2))
        callback = TimeWaitForBatchLogger(logger=mock_logger, log_every_n_steps=1)
        _train_impl(state, auto_unit, CallbackHandler([callback]))

        train_timer = none_throws(state.train_state).iteration_timer
        train_waits = train_timer.recorded_durations["data_wait_time"]
        eval_timer = none_throws(state.eval_state).iteration_timer
        eval_waits = eval_timer.recorded_durations["data_wait_time"]

        # Four train entries and two eval entries are expected, each logged
        # with a 1-based step number.
        expected_calls = []
        for idx in range(4):
            expected_calls.append(
                call("Time Wait For Batch (Train)", train_waits[idx], idx + 1)
            )
        for idx in range(2):
            expected_calls.append(
                call("Time Wait For Batch (Eval)", eval_waits[idx], idx + 1)
            )

        mock_logger.log.assert_has_calls(expected_calls, any_order=True)

    def test_with_predict(self) -> None:
        """The callback logs under the predict label when run through ``predict``."""
        mock_logger = MagicMock(spec=MetricLogger)
        predict_unit = DummyPredictUnit(input_dim=2)
        loader = generate_random_dataloader(num_samples=8, input_dim=2, batch_size=2)
        callback = TimeWaitForBatchLogger(logger=mock_logger, log_every_n_steps=1)

        predict(
            predict_unit,
            loader,
            max_steps_per_epoch=1,
            callbacks=[callback],
        )

        # The measured wait time is not deterministic, hence ANY for the value.
        expected_call = call("Time Wait For Batch (Predict)", ANY, 1)
        mock_logger.log.assert_has_calls([expected_call])

    def test_invalid_log_every_n_steps(self) -> None:
        """A ``log_every_n_steps`` of 0 is rejected with a ``ValueError``."""
        mock_logger = MagicMock(spec=MetricLogger)
        expected_message = "log_every_n_steps must be at least 1, got 0"
        with self.assertRaisesRegex(ValueError, expected_message):
            TimeWaitForBatchLogger(logger=mock_logger, log_every_n_steps=0)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
90 changes: 90 additions & 0 deletions
90
torchtnt/framework/callbacks/time_wait_for_batch_logger.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
from typing import cast, Union | ||
|
||
from pyre_extensions import none_throws | ||
from torch.utils.tensorboard import SummaryWriter | ||
|
||
from torchtnt.framework.callback import Callback | ||
from torchtnt.framework.state import State | ||
from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit | ||
from torchtnt.utils.distributed import rank_zero_fn | ||
from torchtnt.utils.loggers.logger import MetricLogger | ||
from torchtnt.utils.timer import TimerProtocol | ||
|
||
|
||
class TimeWaitForBatchLogger(Callback):
    """
    A callback which logs the time spent waiting for a batch as a scalar.

    At the end of each train, eval and predict step, the most recent
    ``data_wait_time`` entry recorded by that phase's iteration timer is
    written to the configured logger.

    Args:
        logger: Either a subclass of :class:`torchtnt.utils.loggers.logger.MetricLogger`
            or a :class:`torch.utils.tensorboard.SummaryWriter` instance.
        log_every_n_steps: log only on steps that are a multiple of this value
            (default 1, i.e. every step).

    Raises:
        ValueError: If ``log_every_n_steps`` is smaller than 1.
    """

    def __init__(
        self,
        logger: Union[MetricLogger, SummaryWriter],
        log_every_n_steps: int = 1,
    ) -> None:
        self._logger = logger
        if log_every_n_steps < 1:
            message = f"log_every_n_steps must be at least 1, got {log_every_n_steps}"
            raise ValueError(message)
        self._log_every_n_steps = log_every_n_steps

    # NOTE(review): rank_zero_fn presumably restricts execution to rank 0;
    # confirm against torchtnt.utils.distributed.
    @rank_zero_fn
    def _log_step_metrics(
        self,
        *,
        timer: TimerProtocol,
        label: str,
        step: int,
    ) -> None:
        """
        Log the latest ``data_wait_time`` duration from ``timer`` under ``label``.

        Nothing is logged when ``step`` is not a multiple of the configured
        interval, or when the timer holds no ``data_wait_time`` entries.
        """
        on_interval = step % self._log_every_n_steps == 0
        if not on_interval:
            return

        wait_times = timer.recorded_durations.get("data_wait_time")
        if not wait_times:
            # Key missing or list empty: there is no measurement to report.
            return

        latest_wait = wait_times[-1]
        writer = self._logger
        # SummaryWriter exposes add_scalar; any other logger is treated as a
        # MetricLogger and written to through log.
        if isinstance(writer, SummaryWriter):
            emit = writer.add_scalar
        else:
            emit = cast(MetricLogger, writer).log
        emit(label, latest_wait, step)

    def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
        """Log the batch wait time of the train step that just finished."""
        train_timer = none_throws(state.train_state).iteration_timer
        completed_steps = unit.train_progress.num_steps_completed
        self._log_step_metrics(
            timer=train_timer,
            label="Time Wait For Batch (Train)",
            step=completed_steps,
        )

    def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
        """Log the batch wait time of the eval step that just finished."""
        eval_timer = none_throws(state.eval_state).iteration_timer
        completed_steps = unit.eval_progress.num_steps_completed
        self._log_step_metrics(
            timer=eval_timer,
            label="Time Wait For Batch (Eval)",
            step=completed_steps,
        )

    def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
        """Log the batch wait time of the predict step that just finished."""
        predict_timer = none_throws(state.predict_state).iteration_timer
        completed_steps = unit.predict_progress.num_steps_completed
        self._log_step_metrics(
            timer=predict_timer,
            label="Time Wait For Batch (Predict)",
            step=completed_steps,
        )