From 2409e1450d5c25600b6dad6cb3ed266fa4d015de Mon Sep 17 00:00:00 2001
From: Gal Rotem
Date: Thu, 18 Apr 2024 17:40:31 -0700
Subject: [PATCH] twfb logger (#790)

Summary:
Pull Request resolved: https://github.com/pytorch/tnt/pull/790

Adding a time-wait-for-batch logger callback.

Reviewed By: JKSenthil

Differential Revision: D56315489

fbshipit-source-id: 5fa9210114231c3c7d97d4872252cb8bf659b2d7
---
 docs/source/framework/callbacks.rst          |   1 +
 .../test_time_wait_for_batch_logger.py       | 128 ++++++++++++++++++
 torchtnt/framework/callbacks/__init__.py     |   2 +
 .../callbacks/time_wait_for_batch_logger.py  |  90 ++++++++++++
 4 files changed, 221 insertions(+)
 create mode 100644 tests/framework/callbacks/test_time_wait_for_batch_logger.py
 create mode 100644 torchtnt/framework/callbacks/time_wait_for_batch_logger.py

diff --git a/docs/source/framework/callbacks.rst b/docs/source/framework/callbacks.rst
index e3ed20043a..15c0ace0de 100644
--- a/docs/source/framework/callbacks.rst
+++ b/docs/source/framework/callbacks.rst
@@ -32,6 +32,7 @@ We offer several pre-written callbacks which are ready to be used out of the box
     SystemResourcesMonitor
     TensorBoardParameterMonitor
     TimeLimitInterrupter
+    TimeWaitForBatchLogger
     IterationTimeLogger
     TorchSnapshotSaver
     TQDMProgressBar
diff --git a/tests/framework/callbacks/test_time_wait_for_batch_logger.py b/tests/framework/callbacks/test_time_wait_for_batch_logger.py
new file mode 100644
index 0000000000..1f0f2918a9
--- /dev/null
+++ b/tests/framework/callbacks/test_time_wait_for_batch_logger.py
@@ -0,0 +1,128 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+from unittest.mock import ANY, call, MagicMock
+
+import torch
+from pyre_extensions import none_throws
+
+from torch.utils.tensorboard import SummaryWriter
+from torchtnt.framework._callback_handler import CallbackHandler
+from torchtnt.framework._test_utils import (
+    DummyAutoUnit,
+    DummyPredictUnit,
+    generate_random_dataloader,
+)
+from torchtnt.framework.callbacks.time_wait_for_batch_logger import (
+    TimeWaitForBatchLogger,
+)
+from torchtnt.framework.predict import predict
+
+from torchtnt.framework.state import EntryPoint, PhaseState, State
+from torchtnt.framework.train import _train_impl
+from torchtnt.utils.loggers.logger import MetricLogger
+from torchtnt.utils.timer import TimerProtocol
+
+
+class TimeWaitForBatchLoggerTest(unittest.TestCase):
+    def test_log_step_metrics(self) -> None:
+        for spec in [MetricLogger, SummaryWriter]:
+            with self.subTest(spec=spec):
+                logger = MagicMock(spec=spec)
+                log_method = logger.log if spec is MetricLogger else logger.add_scalar
+
+                twfb_logger = TimeWaitForBatchLogger(logger=logger, log_every_n_steps=2)
+                timer = MagicMock(spec=TimerProtocol)
+                timer.recorded_durations = {"data_wait_time": [1, 2, 3]}
+                twfb_logger._log_step_metrics(timer=timer, label="foo", step=1)
+                log_method.assert_not_called()
+                twfb_logger._log_step_metrics(timer=timer, label="foo", step=2)
+                log_method.assert_has_calls(
+                    [
+                        call(
+                            "foo",
+                            3,  # last element in the data wait time list
+                            2,  # step
+                        )
+                    ],
+                )
+
+    def test_comparing_twfb_logging_time(self) -> None:
+        dataloader = generate_random_dataloader(
+            num_samples=8, input_dim=2, batch_size=2
+        )
+        state = State(
+            entry_point=EntryPoint.FIT,
+            train_state=PhaseState(
+                dataloader=dataloader,
+                max_epochs=2,
+                max_steps_per_epoch=2,
+            ),
+            eval_state=PhaseState(
+                dataloader=dataloader,
+                max_steps_per_epoch=1,
+                evaluate_every_n_epochs=1,
+            ),
+        )
+
+        logger = MagicMock(spec=MetricLogger)
+        # We want to compare the logged values against the state, so we create the state
+        # manually and call _train_impl directly; this is equivalent to calling fit() and
+        # getting the final state back as a return value.
+        _train_impl(
+            state,
+            DummyAutoUnit(module=torch.nn.Linear(2, 2)),
+            CallbackHandler(
+                [TimeWaitForBatchLogger(logger=logger, log_every_n_steps=1)]
+            ),
+        )
+        train_twfb_durations = none_throws(
+            state.train_state
+        ).iteration_timer.recorded_durations["data_wait_time"]
+        eval_twfb_durations = none_throws(
+            state.eval_state
+        ).iteration_timer.recorded_durations["data_wait_time"]
+
+        expected_training_iteration_time_calls = [
+            call("Time Wait For Batch (Train)", train_twfb_durations[i], i + 1)
+            for i in range(4)
+        ]
+        expected_eval_iteration_time_calls = [
+            call("Time Wait For Batch (Eval)", eval_twfb_durations[i], i + 1)
+            for i in range(2)
+        ]
+
+        logger.log.assert_has_calls(
+            expected_training_iteration_time_calls + expected_eval_iteration_time_calls,
+            any_order=True,
+        )
+
+    def test_with_predict(self) -> None:
+        logger = MagicMock(spec=MetricLogger)
+        predict(
+            DummyPredictUnit(input_dim=2),
+            generate_random_dataloader(num_samples=8, input_dim=2, batch_size=2),
+            max_steps_per_epoch=1,
+            callbacks=[TimeWaitForBatchLogger(logger=logger, log_every_n_steps=1)],
+        )
+        logger.log.assert_has_calls(
+            [
+                call(
+                    "Time Wait For Batch (Predict)",
+                    ANY,
+                    1,
+                )
+            ],
+        )
+
+    def test_invalid_log_every_n_steps(self) -> None:
+        with self.assertRaisesRegex(
+            ValueError, "log_every_n_steps must be at least 1, got 0"
+        ):
+            TimeWaitForBatchLogger(
+                logger=MagicMock(spec=MetricLogger), log_every_n_steps=0
+            )
diff --git a/torchtnt/framework/callbacks/__init__.py b/torchtnt/framework/callbacks/__init__.py
index 38a62e0de6..29c9996f69 100644
--- a/torchtnt/framework/callbacks/__init__.py
+++ b/torchtnt/framework/callbacks/__init__.py
@@ -22,6 +22,7 @@
 from .system_resources_monitor import SystemResourcesMonitor
 from .tensorboard_parameter_monitor import TensorBoardParameterMonitor
 from .time_limit_interrupter import TimeLimitInterrupter
+from .time_wait_for_batch_logger import TimeWaitForBatchLogger
 from .torch_compile import TorchCompile
 from .torchsnapshot_saver import TorchSnapshotSaver
 from .tqdm_progress_bar import TQDMProgressBar
@@ -43,6 +44,7 @@
     "SystemResourcesMonitor",
     "TensorBoardParameterMonitor",
     "TimeLimitInterrupter",
+    "TimeWaitForBatchLogger",
     "TorchCompile",
     "TorchSnapshotSaver",
     "TQDMProgressBar",
diff --git a/torchtnt/framework/callbacks/time_wait_for_batch_logger.py b/torchtnt/framework/callbacks/time_wait_for_batch_logger.py
new file mode 100644
index 0000000000..040efc835b
--- /dev/null
+++ b/torchtnt/framework/callbacks/time_wait_for_batch_logger.py
@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import cast, Union
+
+from pyre_extensions import none_throws
+from torch.utils.tensorboard import SummaryWriter
+
+from torchtnt.framework.callback import Callback
+from torchtnt.framework.state import State
+from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit
+from torchtnt.utils.distributed import rank_zero_fn
+from torchtnt.utils.loggers.logger import MetricLogger
+from torchtnt.utils.timer import TimerProtocol
+
+
+class TimeWaitForBatchLogger(Callback):
+    """
+    A callback which logs the time spent waiting for the next batch from the dataloader
+    (the ``data_wait_time`` recorded by the iteration timer) as scalars to a MetricLogger.
+
+    Args:
+        logger: Either a subclass of :class:`torchtnt.utils.loggers.logger.MetricLogger`
+            or a :class:`torch.utils.tensorboard.SummaryWriter` instance.
+        log_every_n_steps: the interval, in steps, between logging calls. Must be at
+            least 1. Defaults to 1 (log on every step).
+    """
+
+    def __init__(
+        self,
+        logger: Union[MetricLogger, SummaryWriter],
+        log_every_n_steps: int = 1,
+    ) -> None:
+        self._logger = logger
+        if log_every_n_steps < 1:
+            raise ValueError(
+                f"log_every_n_steps must be at least 1, got {log_every_n_steps}"
+            )
+        self._log_every_n_steps = log_every_n_steps
+
+    @rank_zero_fn
+    def _log_step_metrics(
+        self,
+        *,
+        timer: TimerProtocol,
+        label: str,
+        step: int,
+    ) -> None:
+        if step % self._log_every_n_steps != 0:
+            return
+
+        data_wait_time_list = timer.recorded_durations.get("data_wait_time")
+        if not data_wait_time_list:
+            # no data wait time has been recorded for this phase yet
+            return
+
+        # log only the most recently recorded data wait duration
+        if isinstance(self._logger, SummaryWriter):
+            self._logger.add_scalar(
+                label,
+                data_wait_time_list[-1],
+                step,
+            )
+        else:
+            cast(MetricLogger, self._logger).log(
+                label,
+                data_wait_time_list[-1],
+                step,
+            )
+
+    def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
+        self._log_step_metrics(
+            timer=none_throws(state.train_state).iteration_timer,
+            label="Time Wait For Batch (Train)",
+            step=unit.train_progress.num_steps_completed,
+        )
+
+    def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
+        self._log_step_metrics(
+            timer=none_throws(state.eval_state).iteration_timer,
+            label="Time Wait For Batch (Eval)",
+            step=unit.eval_progress.num_steps_completed,
+        )
+
+    def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
+        self._log_step_metrics(
+            timer=none_throws(state.predict_state).iteration_timer,
+            label="Time Wait For Batch (Predict)",
+            step=unit.predict_progress.num_steps_completed,
+        )
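
For reference, a minimal usage sketch of the new callback (not part of the patch). It reuses the toy unit and dataloader from this PR's own test utilities so the snippet is self-contained; in real code you would pass your own TrainUnit and dataloader, and a MetricLogger subclass can stand in for the SummaryWriter.

    import torch
    from torch.utils.tensorboard import SummaryWriter

    from torchtnt.framework._test_utils import DummyAutoUnit, generate_random_dataloader
    from torchtnt.framework.callbacks import TimeWaitForBatchLogger
    from torchtnt.framework.train import train

    # Toy unit and dataloader, borrowed from the test helpers used above.
    unit = DummyAutoUnit(module=torch.nn.Linear(2, 2))
    dataloader = generate_random_dataloader(num_samples=8, input_dim=2, batch_size=2)

    # Logs the most recent "data_wait_time" duration every 2 train steps under
    # the label "Time Wait For Batch (Train)".
    twfb_logger = TimeWaitForBatchLogger(
        logger=SummaryWriter(log_dir="/tmp/twfb_example"),  # or any MetricLogger subclass
        log_every_n_steps=2,
    )

    train(unit, dataloader, max_epochs=1, callbacks=[twfb_logger])

With a SummaryWriter the callback writes via add_scalar; with a MetricLogger it calls log, matching the isinstance branch in _log_step_metrics.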