Skip to content

Commit 327cec5

Browse files
authored
Added state attributes for MLFlow Logger (#2160)
* added state attributes for MLFlow Logger * autopep8 fix * added state attributes in args Co-authored-by: Ishan-Kumar2 <[email protected]>
1 parent f2b812d commit 327cec5

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

ignite/contrib/handlers/mlflow_logger.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,20 @@ def global_step_transform(*args, **kwargs):
180180
global_step_transform=global_step_transform
181181
)
182182
183+
Another example where the State Attributes ``trainer.state.alpha`` and ``trainer.state.beta``
184+
are also logged along with the NLL and Accuracy after each iteration:
185+
186+
.. code-block:: python
187+
188+
mlflow_logger.attach_output_handler(
189+
trainer,
190+
event_name=Events.ITERATION_COMPLETED,
191+
tag="training",
192+
metrics=["nll", "accuracy"],
193+
state_attributes=["alpha", "beta"],
194+
)
195+
196+
183197
Args:
184198
tag: common title for all produced plots. For example, 'training'
185199
metric_names: list of metric names to plot or a string "all" to plot all available
@@ -193,6 +207,7 @@ def global_step_transform(*args, **kwargs):
193207
Default is None, global_step based on attached engine. If provided,
194208
uses function output as global_step. To setup global step from another engine, please use
195209
:meth:`~ignite.contrib.handlers.mlflow_logger.global_step_from_engine`.
210+
state_attributes: list of attributes of the ``trainer.state`` to plot.
196211
197212
Note:
198213
@@ -203,6 +218,8 @@ def global_step_transform(*args, **kwargs):
203218
def global_step_transform(engine, event_name):
204219
return engine.state.get_event_attrib_value(event_name)
205220
221+
.. versionchanged:: 0.5.0
222+
accepts an optional list of `state_attributes`
206223
"""
207224

208225
def __init__(
@@ -211,8 +228,11 @@ def __init__(
211228
metric_names: Optional[Union[str, List[str]]] = None,
212229
output_transform: Optional[Callable] = None,
213230
global_step_transform: Optional[Callable] = None,
231+
state_attributes: Optional[List[str]] = None,
214232
) -> None:
215-
super(OutputHandler, self).__init__(tag, metric_names, output_transform, global_step_transform)
233+
super(OutputHandler, self).__init__(
234+
tag, metric_names, output_transform, global_step_transform, state_attributes
235+
)
216236

217237
def __call__(self, engine: Engine, logger: MLflowLogger, event_name: Union[str, Events]) -> None:
218238

tests/ignite/contrib/handlers/test_mlflow_logger.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,25 @@ def test_output_handler_with_global_step_from_engine():
188188
)
189189

190190

191+
def test_output_handler_state_attrs():
192+
wrapper = OutputHandler("tag", state_attributes=["alpha", "beta", "gamma"])
193+
mock_logger = MagicMock(spec=MLflowLogger)
194+
mock_logger.log_metrics = MagicMock()
195+
196+
mock_engine = MagicMock()
197+
mock_engine.state = State()
198+
mock_engine.state.iteration = 5
199+
mock_engine.state.alpha = 3.899
200+
mock_engine.state.beta = torch.tensor(12.21)
201+
mock_engine.state.gamma = torch.tensor([21.0, 6.0])
202+
203+
wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
204+
205+
mock_logger.log_metrics.assert_called_once_with(
206+
{"tag alpha": 3.899, "tag beta": torch.tensor(12.21).item(), "tag gamma 0": 21.0, "tag gamma 1": 6.0,}, step=5,
207+
)
208+
209+
191210
def test_optimizer_params_handler_wrong_setup():
192211

193212
with pytest.raises(TypeError):

0 commit comments

Comments
 (0)