@@ -77,6 +77,19 @@ class ProgressBar(BaseLogger):
7777 # Progress bar will looks like
7878 # Epoch [2/50]: [64/128] 50%|█████ , loss=0.123 [06:17<12:34]
7979
80+
81+ Example where the State Attributes ``trainer.state.alpha`` and ``trainer.state.beta``
82+ are also logged along with the NLL and Accuracy after each iteration:
83+
84+ .. code-block:: python
85+
86+ pbar.attach(
87+ trainer,
88+ metric_names=["nll", "accuracy"],
89+ state_attributes=["alpha", "beta"],
90+ )
91+
92+
8093 Note:
8194 When adding attaching the progress bar to an engine, it is recommend that you replace
8295 every print operation in the engine's handlers triggered every iteration with
@@ -88,6 +101,9 @@ class ProgressBar(BaseLogger):
88101 Due to `tqdm notebook bugs <https://github.com/tqdm/tqdm/issues/594>`_, bar format may be needed to be set
89102 to an empty string value.
90103
104+ .. versionchanged:: 0.5.0
105+ `attach` now accepts an optional list of `state_attributes`
106+
91107 """
92108
93109 _events_order = [
@@ -161,6 +177,7 @@ def attach( # type: ignore[override]
161177 output_transform : Optional [Callable ] = None ,
162178 event_name : Union [Events , CallableEventWithFilter ] = Events .ITERATION_COMPLETED ,
163179 closing_event_name : Union [Events , CallableEventWithFilter ] = Events .EPOCH_COMPLETED ,
180+ state_attributes : Optional [List [str ]] = None ,
164181 ) -> None :
165182 """
166183 Attaches the progress bar to an engine object.
@@ -176,6 +193,7 @@ def attach( # type: ignore[override]
176193 :class:`~ignite.engine.events.Events`.
177194 closing_event_name: event's name on which the progress bar is closed. Valid events are from
178195 :class:`~ignite.engine.events.Events`.
196+ state_attributes: list of attributes of the ``trainer.state`` to plot.
179197
180198 Note:
181199 Accepted output value types are numbers, 0d and 1d torch tensors and strings.
@@ -193,7 +211,13 @@ def attach( # type: ignore[override]
193211 if not self ._compare_lt (event_name , closing_event_name ):
194212 raise ValueError (f"Logging event { event_name } should be called before closing event { closing_event_name } " )
195213
196- log_handler = _OutputHandler (desc , metric_names , output_transform , closing_event_name = closing_event_name )
214+ log_handler = _OutputHandler (
215+ desc ,
216+ metric_names ,
217+ output_transform ,
218+ closing_event_name = closing_event_name ,
219+ state_attributes = state_attributes ,
220+ )
197221
198222 super (ProgressBar , self ).attach (engine , log_handler , event_name )
199223 engine .add_event_handler (closing_event_name , self ._close )
@@ -215,6 +239,7 @@ def _create_opt_params_handler(self, *args: Any, **kwargs: Any) -> Callable:
215239class _OutputHandler (BaseOutputHandler ):
216240 """Helper handler to log engine's output and/or metrics
217241
242+ pbar = ProgressBar()
218243 Args:
219244 description: progress bar description.
220245 metric_names: list of metric names to plot or a string "all" to plot all available
@@ -226,6 +251,7 @@ class _OutputHandler(BaseOutputHandler):
226251 closing_event_name: event's name on which the progress bar is closed. Valid events are from
227252 :class:`~ignite.engine.events.Events` or any `event_name` added by
228253 :meth:`~ignite.engine.engine.Engine.register_events`.
254+ state_attributes: list of attributes of the ``trainer.state`` to plot.
229255
230256 """
231257
@@ -235,11 +261,14 @@ def __init__(
235261 metric_names : Optional [Union [str , List [str ]]] = None ,
236262 output_transform : Optional [Callable ] = None ,
237263 closing_event_name : Union [Events , CallableEventWithFilter ] = Events .EPOCH_COMPLETED ,
264+ state_attributes : Optional [List [str ]] = None ,
238265 ):
239266 if metric_names is None and output_transform is None :
240267 # This helps to avoid 'Either metric_names or output_transform should be defined' of BaseOutputHandler
241268 metric_names = []
242- super (_OutputHandler , self ).__init__ (description , metric_names , output_transform , global_step_transform = None )
269+ super (_OutputHandler , self ).__init__ (
270+ description , metric_names , output_transform , global_step_transform = None , state_attributes = state_attributes
271+ )
243272 self .closing_event_name = closing_event_name
244273
245274 @staticmethod
0 commit comments