1
1
import warnings
2
+ from argparse import Namespace
2
3
from pathlib import Path
3
- from typing import Any , Dict , List
4
+ from typing import Any , Dict , List , Mapping , Optional , Union
4
5
5
6
import matplotlib .pyplot as plt
6
7
import numpy as np
8
+ from lightning_fabric .utilities .logger import _flatten_dict
9
+ from lightning_fabric .utilities .rank_zero import rank_zero_only
7
10
from pytorch_lightning .loggers import (
11
+ Logger ,
8
12
MLFlowLogger ,
9
13
NeptuneLogger ,
10
14
TensorBoardLogger ,
@@ -31,6 +35,66 @@ def prepare_tags(cfg: TCfg) -> List[str]:
31
35
return tags
32
36
33
37
38
+ class ClearMLLogger (Logger ):
39
+ def __init__ (self , ** kwargs : Any ):
40
+ try :
41
+ from clearml import Task
42
+ except ImportError as e :
43
+ raise ModuleNotFoundError (
44
+ "This contrib module requires clearml to be installed. "
45
+ "You may install clearml using: \n pip install clearml \n "
46
+ ) from e
47
+
48
+ experiment_kwargs = {
49
+ k : v for k , v in kwargs .items () if k not in ("project_name" , "task_name" , "task_type" , "offline_mode" )
50
+ }
51
+
52
+ if kwargs .get ("offline_mode" , False ):
53
+ Task .set_offline (offline_mode = True )
54
+ warnings .warn ("ClearMLSaver: running in offline mode" )
55
+
56
+ # Try to retrieve current the ClearML Task before trying to create a new one
57
+ self .task = Task .current_task ()
58
+ if self .task is None :
59
+ self .task = Task .init (
60
+ project_name = kwargs .get ("project_name" ),
61
+ task_name = kwargs .get ("task_name" ),
62
+ task_type = kwargs .get ("task_type" , Task .TaskTypes .training ),
63
+ ** experiment_kwargs ,
64
+ )
65
+
66
+ self .logger = self .task .get_logger ()
67
+
68
+ @property
69
+ def name (self ) -> str :
70
+ return "ClearMLLogger"
71
+
72
+ @property
73
+ def version (self ) -> Union [int , str ]:
74
+ return self .task .id
75
+
76
+ @rank_zero_only
77
+ def finalize (self , status : str ) -> None :
78
+ self .logger .flush ()
79
+
80
+ @rank_zero_only
81
+ def log_hyperparams (self , params : Optional [Union [Dict [str , Any ], Namespace ]]) -> None :
82
+ if isinstance (params , Namespace ):
83
+ params = vars (params )
84
+
85
+ if params is None :
86
+ params = {}
87
+ params = _flatten_dict (params )
88
+
89
+ self .task .connect (params )
90
+
91
+ @rank_zero_only
92
+ def log_metrics (self , metrics : Mapping [str , float ], step : Optional [int ] = None ) -> None :
93
+ assert rank_zero_only .rank == 0 , "experiment tried to log from global_rank != 0" # type: ignore
94
+ for k , v in metrics .items ():
95
+ self .logger .report_scalar (title = k , series = k , iteration = step , value = v )
96
+
97
+
34
98
class NeptunePipelineLogger (NeptuneLogger , IPipelineLogger ):
35
99
def log_pipeline_info (self , cfg : TCfg ) -> None :
36
100
warnings .warn (
@@ -132,10 +196,44 @@ def log_figure(self, fig: plt.Figure, title: str, idx: int) -> None:
132
196
self .experiment .log_figure (figure = fig , artifact_file = f"{ title } .png" , run_id = self .run_id )
133
197
134
198
199
+ class ClearMLPipelineLogger (ClearMLLogger , IPipelineLogger ):
200
+ def log_pipeline_info (self , cfg : TCfg ) -> None :
201
+ # log config
202
+ self .log_hyperparams (prepare_config_to_logging (cfg ))
203
+
204
+ # log tags
205
+ self .task .add_tags (prepare_tags (cfg ))
206
+
207
+ # log transforms as files
208
+ names_files = save_transforms_as_files (cfg )
209
+ if names_files :
210
+ for name , transforms_file in names_files :
211
+ self .task .upload_artifact (name = name , artifact_object = transforms_file )
212
+
213
+ # log code
214
+ self .task .upload_artifact (name = "code" , artifact_object = OML_PATH )
215
+
216
+ # log dataframe
217
+ self .task .upload_artifact (
218
+ name = "dataset" ,
219
+ artifact_object = str (Path (cfg ["dataset_root" ]) / cfg ["dataframe_name" ]),
220
+ )
221
+
222
+ def log_figure (self , fig : plt .Figure , title : str , idx : int ) -> None :
223
+ self .logger .report_matplotlib_figure (
224
+ title = title ,
225
+ series = "" ,
226
+ figure = fig ,
227
+ iteration = idx ,
228
+ report_image = True ,
229
+ )
230
+
231
+
135
232
__all__ = [
136
233
"IPipelineLogger" ,
137
234
"TensorBoardPipelineLogger" ,
138
235
"WandBPipelineLogger" ,
139
236
"NeptunePipelineLogger" ,
140
237
"MLFlowPipelineLogger" ,
238
+ "ClearMLPipelineLogger" ,
141
239
]
0 commit comments