Skip to content

Commit 3eec6df

Browse files
Added state attributes for tqdm logger (#2162)
* added state attributes for tqdm logger * autopep8 fix * removed extra imports * fix flake8 Co-authored-by: Ishan-Kumar2 <[email protected]> Co-authored-by: vfdev <[email protected]>
1 parent 0f0920d commit 3eec6df

File tree

2 files changed

+69
-2
lines changed

2 files changed

+69
-2
lines changed

ignite/contrib/handlers/tqdm_logger.py

+31-2
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,19 @@ class ProgressBar(BaseLogger):
7777
# Progress bar will looks like
7878
# Epoch [2/50]: [64/128] 50%|█████ , loss=0.123 [06:17<12:34]
7979
80+
81+
Example where the State Attributes ``trainer.state.alpha`` and ``trainer.state.beta``
82+
are also logged along with the NLL and Accuracy after each iteration:
83+
84+
.. code-block:: python
85+
86+
pbar.attach(
87+
trainer,
88+
metric_names=["nll", "accuracy"],
89+
state_attributes=["alpha", "beta"],
90+
)
91+
92+
8093
Note:
8194
When adding attaching the progress bar to an engine, it is recommend that you replace
8295
every print operation in the engine's handlers triggered every iteration with
@@ -88,6 +101,9 @@ class ProgressBar(BaseLogger):
88101
Due to `tqdm notebook bugs <https://github.com/tqdm/tqdm/issues/594>`_, bar format may be needed to be set
89102
to an empty string value.
90103
104+
.. versionchanged:: 0.5.0
105+
`attach` now accepts an optional list of `state_attributes`
106+
91107
"""
92108

93109
_events_order = [
@@ -161,6 +177,7 @@ def attach( # type: ignore[override]
161177
output_transform: Optional[Callable] = None,
162178
event_name: Union[Events, CallableEventWithFilter] = Events.ITERATION_COMPLETED,
163179
closing_event_name: Union[Events, CallableEventWithFilter] = Events.EPOCH_COMPLETED,
180+
state_attributes: Optional[List[str]] = None,
164181
) -> None:
165182
"""
166183
Attaches the progress bar to an engine object.
@@ -176,6 +193,7 @@ def attach( # type: ignore[override]
176193
:class:`~ignite.engine.events.Events`.
177194
closing_event_name: event's name on which the progress bar is closed. Valid events are from
178195
:class:`~ignite.engine.events.Events`.
196+
state_attributes: list of attributes of the ``trainer.state`` to plot.
179197
180198
Note:
181199
Accepted output value types are numbers, 0d and 1d torch tensors and strings.
@@ -193,7 +211,13 @@ def attach( # type: ignore[override]
193211
if not self._compare_lt(event_name, closing_event_name):
194212
raise ValueError(f"Logging event {event_name} should be called before closing event {closing_event_name}")
195213

196-
log_handler = _OutputHandler(desc, metric_names, output_transform, closing_event_name=closing_event_name)
214+
log_handler = _OutputHandler(
215+
desc,
216+
metric_names,
217+
output_transform,
218+
closing_event_name=closing_event_name,
219+
state_attributes=state_attributes,
220+
)
197221

198222
super(ProgressBar, self).attach(engine, log_handler, event_name)
199223
engine.add_event_handler(closing_event_name, self._close)
@@ -215,6 +239,7 @@ def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> Callable:
215239
class _OutputHandler(BaseOutputHandler):
216240
"""Helper handler to log engine's output and/or metrics
217241
242+
pbar = ProgressBar()
218243
Args:
219244
description: progress bar description.
220245
metric_names: list of metric names to plot or a string "all" to plot all available
@@ -226,6 +251,7 @@ class _OutputHandler(BaseOutputHandler):
226251
closing_event_name: event's name on which the progress bar is closed. Valid events are from
227252
:class:`~ignite.engine.events.Events` or any `event_name` added by
228253
:meth:`~ignite.engine.engine.Engine.register_events`.
254+
state_attributes: list of attributes of the ``trainer.state`` to plot.
229255
230256
"""
231257

@@ -235,11 +261,14 @@ def __init__(
235261
metric_names: Optional[Union[str, List[str]]] = None,
236262
output_transform: Optional[Callable] = None,
237263
closing_event_name: Union[Events, CallableEventWithFilter] = Events.EPOCH_COMPLETED,
264+
state_attributes: Optional[List[str]] = None,
238265
):
239266
if metric_names is None and output_transform is None:
240267
# This helps to avoid 'Either metric_names or output_transform should be defined' of BaseOutputHandler
241268
metric_names = []
242-
super(_OutputHandler, self).__init__(description, metric_names, output_transform, global_step_transform=None)
269+
super(_OutputHandler, self).__init__(
270+
description, metric_names, output_transform, global_step_transform=None, state_attributes=state_attributes
271+
)
243272
self.closing_event_name = closing_event_name
244273

245274
@staticmethod

tests/ignite/contrib/handlers/test_tqdm_logger.py

+38
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,44 @@ def step(engine, batch):
209209
assert actual == expected
210210

211211

212+
def test_pbar_with_state_attrs(capsys):
213+
214+
n_iters = 2
215+
data = list(range(n_iters))
216+
loss_values = iter(range(n_iters))
217+
218+
def step(engine, batch):
219+
loss_value = next(loss_values)
220+
return loss_value
221+
222+
trainer = Engine(step)
223+
trainer.state.alpha = 3.899
224+
trainer.state.beta = torch.tensor(12.21)
225+
trainer.state.gamma = torch.tensor([21.0, 6.0])
226+
227+
RunningAverage(alpha=0.5, output_transform=lambda x: x).attach(trainer, "batchloss")
228+
229+
pbar = ProgressBar()
230+
pbar.attach(trainer, metric_names=["batchloss",], state_attributes=["alpha", "beta", "gamma"])
231+
232+
trainer.run(data=data, max_epochs=1)
233+
234+
captured = capsys.readouterr()
235+
err = captured.err.split("\r")
236+
err = list(map(lambda x: x.strip(), err))
237+
err = list(filter(None, err))
238+
actual = err[-1]
239+
if get_tqdm_version() < LooseVersion("4.49.0"):
240+
expected = (
241+
"Iteration: [1/2] 50%|█████ , batchloss=0.5, alpha=3.9, beta=12.2, gamma_0=21, gamma_1=6 [00:00<00:00]"
242+
)
243+
else:
244+
expected = (
245+
"Iteration: [1/2] 50%|█████ , batchloss=0.5, alpha=3.9, beta=12.2, gamma_0=21, gamma_1=6 [00:00<?]"
246+
)
247+
assert actual == expected
248+
249+
212250
def test_pbar_no_metric_names(capsys):
213251

214252
n_epochs = 2

0 commit comments

Comments
 (0)