|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | +# All rights reserved. |
| 4 | +# |
| 5 | +# This source code is licensed under the BSD-style license found in the |
| 6 | +# LICENSE file in the root directory of this source tree. |
| 7 | + |
| 8 | +import unittest |
| 9 | +from unittest.mock import ANY, call, MagicMock |
| 10 | + |
| 11 | +import torch |
| 12 | +from pyre_extensions import none_throws |
| 13 | + |
| 14 | +from torch.utils.tensorboard import SummaryWriter |
| 15 | +from torchtnt.framework._callback_handler import CallbackHandler |
| 16 | +from torchtnt.framework._test_utils import ( |
| 17 | + DummyAutoUnit, |
| 18 | + DummyPredictUnit, |
| 19 | + generate_random_dataloader, |
| 20 | +) |
| 21 | +from torchtnt.framework.callbacks.time_wait_for_batch_logger import ( |
| 22 | + TimeWaitForBatchLogger, |
| 23 | +) |
| 24 | +from torchtnt.framework.predict import predict |
| 25 | + |
| 26 | +from torchtnt.framework.state import EntryPoint, PhaseState, State |
| 27 | +from torchtnt.framework.train import _train_impl |
| 28 | +from torchtnt.utils.loggers.logger import MetricLogger |
| 29 | +from torchtnt.utils.timer import TimerProtocol |
| 30 | + |
| 31 | + |
class TimeWaitForBatchLoggerTest(unittest.TestCase):
    """Unit tests for the ``TimeWaitForBatchLogger`` callback."""

    def test_log_step_metrics(self) -> None:
        """Exercise ``_log_step_metrics`` with two differently spec'd loggers.

        With ``log_every_n_steps=2`` the test expects no logging call at
        step 1 and exactly the asserted call at step 2, carrying the last
        entry of the mocked ``data_wait_time`` durations.
        """
        # Pair each mock spec with the name of the method expected to be called.
        spec_to_method_name = (
            (MetricLogger, "log"),
            (SummaryWriter, "add_scalar"),
        )
        for spec, method_name in spec_to_method_name:
            with self.subTest(spec=spec):
                logger = MagicMock(spec=spec)
                log_method = getattr(logger, method_name)

                callback = TimeWaitForBatchLogger(logger=logger, log_every_n_steps=2)
                fake_timer = MagicMock(spec=TimerProtocol)
                fake_timer.recorded_durations = {"data_wait_time": [1, 2, 3]}

                # Step 1 is not a multiple of log_every_n_steps: nothing logged.
                callback._log_step_metrics(timer=fake_timer, label="foo", step=1)
                log_method.assert_not_called()

                # Step 2: positional args are (label, last wait time, step).
                callback._log_step_metrics(timer=fake_timer, label="foo", step=2)
                expected = [call("foo", 3, 2)]
                log_method.assert_has_calls(expected)

    def test_comparing_twfb_logging_time(self) -> None:
        """Compare logged wait times with the durations recorded on the state.

        Runs ``_train_impl`` on a manually built FIT state and checks that the
        logger received the first four train durations and the first two eval
        durations, each paired with a 1-based step number.
        """

        def expected_calls(label, durations, count):
            # One call per step: (label, recorded duration, 1-based step).
            return [call(label, durations[idx], idx + 1) for idx in range(count)]

        dataloader = generate_random_dataloader(
            num_samples=8, input_dim=2, batch_size=2
        )
        train_phase = PhaseState(
            dataloader=dataloader,
            max_epochs=2,
            max_steps_per_epoch=2,
        )
        eval_phase = PhaseState(
            dataloader=dataloader,
            max_steps_per_epoch=1,
            evaluate_every_n_epochs=1,
        )
        state = State(
            entry_point=EntryPoint.FIT,
            train_state=train_phase,
            eval_state=eval_phase,
        )

        logger = MagicMock(spec=MetricLogger)
        # The state is built by hand and passed to _train_impl so the recorded
        # durations can be read back from it afterwards and compared with what
        # was logged (similar to calling fit() and getting the state returned).
        unit = DummyAutoUnit(module=torch.nn.Linear(2, 2))
        handler = CallbackHandler(
            [TimeWaitForBatchLogger(logger=logger, log_every_n_steps=1)]
        )
        _train_impl(state, unit, handler)

        train_wait_times = none_throws(
            state.train_state
        ).iteration_timer.recorded_durations["data_wait_time"]
        eval_wait_times = none_throws(
            state.eval_state
        ).iteration_timer.recorded_durations["data_wait_time"]

        all_expected = expected_calls(
            "Time Wait For Batch (Train)", train_wait_times, 4
        ) + expected_calls("Time Wait For Batch (Eval)", eval_wait_times, 2)

        # Train and eval calls are interleaved, so ordering is not enforced.
        logger.log.assert_has_calls(all_expected, any_order=True)

    def test_with_predict(self) -> None:
        """Check that a predict run logs under the ``(Predict)`` label at step 1."""
        logger = MagicMock(spec=MetricLogger)
        unit = DummyPredictUnit(input_dim=2)
        dataloader = generate_random_dataloader(
            num_samples=8, input_dim=2, batch_size=2
        )
        callback = TimeWaitForBatchLogger(logger=logger, log_every_n_steps=1)

        predict(
            unit,
            dataloader,
            max_steps_per_epoch=1,
            callbacks=[callback],
        )

        # The measured wait time is not known in advance, hence ANY.
        expected = [call("Time Wait For Batch (Predict)", ANY, 1)]
        logger.log.assert_has_calls(expected)

    def test_invalid_log_every_n_steps(self) -> None:
        """Check that ``log_every_n_steps=0`` is rejected with a ``ValueError``."""
        message_pattern = "log_every_n_steps must be at least 1, got 0"
        with self.assertRaisesRegex(ValueError, message_pattern):
            TimeWaitForBatchLogger(
                logger=MagicMock(spec=MetricLogger), log_every_n_steps=0
            )
0 commit comments