Skip to content

Commit 2409e14

Browse files
galrotem authored and facebook-github-bot committed
twfb logger (#790)
Summary: Pull Request resolved: #790 Adding a time-wait-for-batch logger callback. Reviewed By: JKSenthil Differential Revision: D56315489 fbshipit-source-id: 5fa9210114231c3c7d97d4872252cb8bf659b2d7
1 parent 35f9d92 commit 2409e14

File tree

4 files changed

+221
-0
lines changed

4 files changed

+221
-0
lines changed

docs/source/framework/callbacks.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ We offer several pre-written callbacks which are ready to be used out of the box
3232
SystemResourcesMonitor
3333
TensorBoardParameterMonitor
3434
TimeLimitInterrupter
35+
TimeWaitForBatchLogger
3536
IterationTimeLogger
3637
TorchSnapshotSaver
3738
TQDMProgressBar
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import unittest
9+
from unittest.mock import ANY, call, MagicMock
10+
11+
import torch
12+
from pyre_extensions import none_throws
13+
14+
from torch.utils.tensorboard import SummaryWriter
15+
from torchtnt.framework._callback_handler import CallbackHandler
16+
from torchtnt.framework._test_utils import (
17+
DummyAutoUnit,
18+
DummyPredictUnit,
19+
generate_random_dataloader,
20+
)
21+
from torchtnt.framework.callbacks.time_wait_for_batch_logger import (
22+
TimeWaitForBatchLogger,
23+
)
24+
from torchtnt.framework.predict import predict
25+
26+
from torchtnt.framework.state import EntryPoint, PhaseState, State
27+
from torchtnt.framework.train import _train_impl
28+
from torchtnt.utils.loggers.logger import MetricLogger
29+
from torchtnt.utils.timer import TimerProtocol
30+
31+
32+
class TimeWaitForBatchLoggerTest(unittest.TestCase):
    """Tests for the TimeWaitForBatchLogger callback."""

    def test_log_step_metrics(self) -> None:
        # Exercise both supported logger backends through _log_step_metrics.
        for spec in [MetricLogger, SummaryWriter]:
            with self.subTest(spec=spec):
                mock_logger = MagicMock(spec=spec)
                # MetricLogger exposes .log; SummaryWriter exposes .add_scalar.
                log_method = (
                    mock_logger.log if spec is MetricLogger else mock_logger.add_scalar
                )

                callback = TimeWaitForBatchLogger(
                    logger=mock_logger, log_every_n_steps=2
                )
                timer = MagicMock(spec=TimerProtocol)
                timer.recorded_durations = {"data_wait_time": [1, 2, 3]}

                # Step 1 is off-cadence for log_every_n_steps=2: nothing logged.
                callback._log_step_metrics(timer=timer, label="foo", step=1)
                log_method.assert_not_called()

                # Step 2 is on-cadence: the most recent duration is logged.
                callback._log_step_metrics(timer=timer, label="foo", step=2)
                log_method.assert_has_calls(
                    [
                        call(
                            "foo",
                            3,  # last element in the data wait time list
                            2,  # step
                        )
                    ],
                )

    def test_comparing_twfb_logging_time(self) -> None:
        dataloader = generate_random_dataloader(
            num_samples=8, input_dim=2, batch_size=2
        )
        state = State(
            entry_point=EntryPoint.FIT,
            train_state=PhaseState(
                dataloader=dataloader,
                max_epochs=2,
                max_steps_per_epoch=2,
            ),
            eval_state=PhaseState(
                dataloader=dataloader,
                max_steps_per_epoch=1,
                evaluate_every_n_epochs=1,
            ),
        )

        mock_logger = MagicMock(spec=MetricLogger)
        # We want to be able to compare the logging value to the state, so we need
        # to create state manually and call _train_impl. This would have been
        # similar to calling fit() and getting the state as a ret value.
        _train_impl(
            state,
            DummyAutoUnit(module=torch.nn.Linear(2, 2)),
            CallbackHandler(
                [TimeWaitForBatchLogger(logger=mock_logger, log_every_n_steps=1)]
            ),
        )

        train_durations = none_throws(
            state.train_state
        ).iteration_timer.recorded_durations["data_wait_time"]
        eval_durations = none_throws(
            state.eval_state
        ).iteration_timer.recorded_durations["data_wait_time"]

        # 2 epochs x 2 train steps, 2 epochs x 1 eval step.
        expected_train_calls = [
            call("Time Wait For Batch (Train)", train_durations[idx], idx + 1)
            for idx in range(4)
        ]
        expected_eval_calls = [
            call("Time Wait For Batch (Eval)", eval_durations[idx], idx + 1)
            for idx in range(2)
        ]

        mock_logger.log.assert_has_calls(
            expected_train_calls + expected_eval_calls,
            any_order=True,
        )

    def test_with_predict(self) -> None:
        mock_logger = MagicMock(spec=MetricLogger)
        predict(
            DummyPredictUnit(input_dim=2),
            generate_random_dataloader(num_samples=8, input_dim=2, batch_size=2),
            max_steps_per_epoch=1,
            callbacks=[TimeWaitForBatchLogger(logger=mock_logger, log_every_n_steps=1)],
        )
        # The exact duration is nondeterministic, so only the label/step are pinned.
        mock_logger.log.assert_has_calls(
            [
                call(
                    "Time Wait For Batch (Predict)",
                    ANY,
                    1,
                )
            ],
        )

    def test_invalid_log_every_n_steps(self) -> None:
        # Constructor must reject a non-positive logging cadence.
        with self.assertRaisesRegex(
            ValueError, "log_every_n_steps must be at least 1, got 0"
        ):
            TimeWaitForBatchLogger(
                logger=MagicMock(spec=MetricLogger), log_every_n_steps=0
            )

torchtnt/framework/callbacks/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .system_resources_monitor import SystemResourcesMonitor
2323
from .tensorboard_parameter_monitor import TensorBoardParameterMonitor
2424
from .time_limit_interrupter import TimeLimitInterrupter
25+
from .time_wait_for_batch_logger import TimeWaitForBatchLogger
2526
from .torch_compile import TorchCompile
2627
from .torchsnapshot_saver import TorchSnapshotSaver
2728
from .tqdm_progress_bar import TQDMProgressBar
@@ -43,6 +44,7 @@
4344
"SystemResourcesMonitor",
4445
"TensorBoardParameterMonitor",
4546
"TimeLimitInterrupter",
47+
"TimeWaitForBatchLogger",
4648
"TorchCompile",
4749
"TorchSnapshotSaver",
4850
"TQDMProgressBar",
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
from typing import cast, Union
9+
10+
from pyre_extensions import none_throws
11+
from torch.utils.tensorboard import SummaryWriter
12+
13+
from torchtnt.framework.callback import Callback
14+
from torchtnt.framework.state import State
15+
from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit
16+
from torchtnt.utils.distributed import rank_zero_fn
17+
from torchtnt.utils.loggers.logger import MetricLogger
18+
from torchtnt.utils.timer import TimerProtocol
19+
20+
21+
class TimeWaitForBatchLogger(Callback):
    """
    A callback which logs time wait for batch as scalars to a MetricLogger.

    Args:
        logger: Either a subclass of :class:`torchtnt.utils.loggers.logger.MetricLogger`
            or a :class:`torch.utils.tensorboard.SummaryWriter` instance.
        log_every_n_steps: an optional int to control the log frequency
    """

    def __init__(
        self,
        logger: Union[MetricLogger, SummaryWriter],
        log_every_n_steps: int = 1,
    ) -> None:
        self._logger = logger
        # Reject non-positive cadences up front so misconfiguration fails fast.
        if log_every_n_steps < 1:
            raise ValueError(
                f"log_every_n_steps must be at least 1, got {log_every_n_steps}"
            )
        self._log_every_n_steps = log_every_n_steps

    @rank_zero_fn
    def _log_step_metrics(
        self,
        *,
        timer: TimerProtocol,
        label: str,
        step: int,
    ) -> None:
        # Only log on the configured cadence.
        if step % self._log_every_n_steps != 0:
            return

        durations = timer.recorded_durations.get("data_wait_time")
        # Nothing recorded yet (or key absent) — skip quietly.
        if not durations:
            return

        latest_wait = durations[-1]
        if isinstance(self._logger, SummaryWriter):
            self._logger.add_scalar(label, latest_wait, step)
        else:
            # Narrow the Union for the type checker; at runtime this branch
            # is any non-SummaryWriter logger, i.e. a MetricLogger.
            cast(MetricLogger, self._logger).log(label, latest_wait, step)

    def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
        self._log_step_metrics(
            timer=none_throws(state.train_state).iteration_timer,
            label="Time Wait For Batch (Train)",
            step=unit.train_progress.num_steps_completed,
        )

    def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
        self._log_step_metrics(
            timer=none_throws(state.eval_state).iteration_timer,
            label="Time Wait For Batch (Eval)",
            step=unit.eval_progress.num_steps_completed,
        )

    def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
        self._log_step_metrics(
            timer=none_throws(state.predict_state).iteration_timer,
            label="Time Wait For Batch (Predict)",
            step=unit.predict_progress.num_steps_completed,
        )

0 commit comments

Comments
 (0)