@@ -77,6 +77,19 @@ class ProgressBar(BaseLogger):
77
77
# Progress bar will looks like
78
78
# Epoch [2/50]: [64/128] 50%|█████ , loss=0.123 [06:17<12:34]
79
79
80
+
81
+ Example where the State Attributes ``trainer.state.alpha`` and ``trainer.state.beta``
82
+ are also logged along with the NLL and Accuracy after each iteration:
83
+
84
+ .. code-block:: python
85
+
86
+ pbar.attach(
87
+ trainer,
88
+ metric_names=["nll", "accuracy"],
89
+ state_attributes=["alpha", "beta"],
90
+ )
91
+
92
+
80
93
Note:
81
94
When adding attaching the progress bar to an engine, it is recommend that you replace
82
95
every print operation in the engine's handlers triggered every iteration with
@@ -88,6 +101,9 @@ class ProgressBar(BaseLogger):
88
101
Due to `tqdm notebook bugs <https://github.com/tqdm/tqdm/issues/594>`_, bar format may be needed to be set
89
102
to an empty string value.
90
103
104
+ .. versionchanged:: 0.5.0
105
+ `attach` now accepts an optional list of `state_attributes`
106
+
91
107
"""
92
108
93
109
_events_order = [
@@ -161,6 +177,7 @@ def attach( # type: ignore[override]
161
177
output_transform : Optional [Callable ] = None ,
162
178
event_name : Union [Events , CallableEventWithFilter ] = Events .ITERATION_COMPLETED ,
163
179
closing_event_name : Union [Events , CallableEventWithFilter ] = Events .EPOCH_COMPLETED ,
180
+ state_attributes : Optional [List [str ]] = None ,
164
181
) -> None :
165
182
"""
166
183
Attaches the progress bar to an engine object.
@@ -176,6 +193,7 @@ def attach( # type: ignore[override]
176
193
:class:`~ignite.engine.events.Events`.
177
194
closing_event_name: event's name on which the progress bar is closed. Valid events are from
178
195
:class:`~ignite.engine.events.Events`.
196
+ state_attributes: list of attributes of the ``trainer.state`` to plot.
179
197
180
198
Note:
181
199
Accepted output value types are numbers, 0d and 1d torch tensors and strings.
@@ -193,7 +211,13 @@ def attach( # type: ignore[override]
193
211
if not self ._compare_lt (event_name , closing_event_name ):
194
212
raise ValueError (f"Logging event { event_name } should be called before closing event { closing_event_name } " )
195
213
196
- log_handler = _OutputHandler (desc , metric_names , output_transform , closing_event_name = closing_event_name )
214
+ log_handler = _OutputHandler (
215
+ desc ,
216
+ metric_names ,
217
+ output_transform ,
218
+ closing_event_name = closing_event_name ,
219
+ state_attributes = state_attributes ,
220
+ )
197
221
198
222
super (ProgressBar , self ).attach (engine , log_handler , event_name )
199
223
engine .add_event_handler (closing_event_name , self ._close )
@@ -215,6 +239,7 @@ def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> Callable:
215
239
class _OutputHandler (BaseOutputHandler ):
216
240
"""Helper handler to log engine's output and/or metrics
217
241
242
+ pbar = ProgressBar()
218
243
Args:
219
244
description: progress bar description.
220
245
metric_names: list of metric names to plot or a string "all" to plot all available
@@ -226,6 +251,7 @@ class _OutputHandler(BaseOutputHandler):
226
251
closing_event_name: event's name on which the progress bar is closed. Valid events are from
227
252
:class:`~ignite.engine.events.Events` or any `event_name` added by
228
253
:meth:`~ignite.engine.engine.Engine.register_events`.
254
+ state_attributes: list of attributes of the ``trainer.state`` to plot.
229
255
230
256
"""
231
257
@@ -235,11 +261,14 @@ def __init__(
235
261
metric_names : Optional [Union [str , List [str ]]] = None ,
236
262
output_transform : Optional [Callable ] = None ,
237
263
closing_event_name : Union [Events , CallableEventWithFilter ] = Events .EPOCH_COMPLETED ,
264
+ state_attributes : Optional [List [str ]] = None ,
238
265
):
239
266
if metric_names is None and output_transform is None :
240
267
# This helps to avoid 'Either metric_names or output_transform should be defined' of BaseOutputHandler
241
268
metric_names = []
242
- super (_OutputHandler , self ).__init__ (description , metric_names , output_transform , global_step_transform = None )
269
+ super (_OutputHandler , self ).__init__ (
270
+ description , metric_names , output_transform , global_step_transform = None , state_attributes = state_attributes
271
+ )
243
272
self .closing_event_name = closing_event_name
244
273
245
274
@staticmethod
0 commit comments