Commit 6489a93

upd

1 parent: 52b2437

6 files changed: +26 -16 lines

docs/readme/examples_source/extractor/train_val_pl.md (+4 -4)

@@ -16,7 +16,7 @@ from oml.miners.inbatch_all_tri import AllTripletsMiner
 from oml.models import ViTExtractor
 from oml.samplers.balance import BalanceSampler
 from oml.utils.download_mock_dataset import download_mock_dataset
-from pytorch_lightning.loggers import NeptuneLogger, TensorBoardLogger, WandbLogger
+from oml.lightning.pipelines.logging import NeptunePipelineLogger, TensorBoardPipelineLogger, WandBPipelineLogger

 dataset_root = "mock_dataset/"
 df_train, df_val = download_mock_dataset(dataset_root)
@@ -37,15 +37,15 @@ val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4)
 metric_callback = MetricValCallback(metric=EmbeddingMetrics(extra_keys=[train_dataset.paths_key,]), log_images=True)

 # 1) Logging with Tensorboard
-logger = TensorBoardLogger(".")
+logger = TensorBoardPipelineLogger(".")

 # 2) Logging with Neptune
-# logger = NeptuneLogger(api_key="", project="", log_model_checkpoints=False)
+# logger = NeptunePipelineLogger(api_key="", project="", log_model_checkpoints=False)

 # 3) Logging with Weights and Biases
 # import os
 # os.environ["WANDB_API_KEY"] = ""
-# logger = WandbLogger(project="test_project", log_model=False)
+# logger = WandBPipelineLogger(project="test_project", log_model=False)

 # run
 pl_model = ExtractorModule(extractor, criterion, optimizer)
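The example stops just before training starts. A plausible continuation, assuming the standard PyTorch Lightning Trainer API and that train_loader is built analogously to val_loader (neither detail is part of this commit):

import pytorch_lightning as pl

# Illustrative wiring only: the epoch count and loader names are assumptions.
trainer = pl.Trainer(
    max_epochs=1,
    callbacks=[metric_callback],  # logs validation metrics and figures
    logger=logger,                # any of the *PipelineLogger options above
)
trainer.fit(pl_model, train_dataloaders=train_loader, val_dataloaders=val_loader)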

oml/interfaces/loggers.py (+5 -3)

@@ -6,11 +6,13 @@
 from oml.const import TCfg


-class IPipelineLogger(LightningLogger):
+class IFigureLogger:
     @abstractmethod
-    def log_experiment_info(self, cfg: TCfg) -> None:
+    def log_figure(self, fig: plt.Figure, title: str, idx: int) -> None:
         raise NotImplementedError()

+
+class IPipelineLogger(LightningLogger, IFigureLogger):
     @abstractmethod
-    def log_figure(self, fig: plt.Figure, title: str, idx: int) -> None:
+    def log_pipeline_info(self, cfg: TCfg) -> None:
         raise NotImplementedError()
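The split pulls figure logging out of IPipelineLogger, so a class can satisfy the figure contract without being a full Lightning logger. A minimal sketch of such an implementor (LocalFigureLogger is a hypothetical name, not part of OML):

from pathlib import Path

import matplotlib.pyplot as plt

from oml.interfaces.loggers import IFigureLogger


class LocalFigureLogger(IFigureLogger):  # hypothetical, for illustration only
    """Saves figures to disk instead of sending them to an experiment tracker."""

    def __init__(self, root: str = "figures"):
        self.root = Path(root)
        self.root.mkdir(parents=True, exist_ok=True)

    def log_figure(self, fig: plt.Figure, title: str, idx: int) -> None:
        # flatten nested titles like "log_image/..." into file-safe names
        fig.savefig(self.root / f"{title.replace('/', '_')}_{idx}.png")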

oml/lightning/callbacks/metric.py (+9 -4)

@@ -1,3 +1,4 @@
+import warnings
 from math import ceil
 from typing import Any, Optional

@@ -9,7 +10,7 @@

 from oml.const import LOG_IMAGE_FOLDER
 from oml.ddp.patching import check_loaders_is_patched, patch_dataloader_to_ddp
-from oml.interfaces.loggers import IPipelineLogger
+from oml.interfaces.loggers import IFigureLogger
 from oml.interfaces.metrics import IBasicMetric, IMetricDDP, IMetricVisualisable
 from oml.lightning.modules.ddp import ModuleDDP
 from oml.utils.misc import flatten_dict
@@ -100,12 +101,16 @@ def _log_images(self, pl_module: pl.LightningModule) -> None:
         if not isinstance(self.metric, IMetricVisualisable):
             return

+        if not isinstance(pl_module.logger, IFigureLogger):
+            warnings.warn(
+                f"Unexpected logger {pl_module.logger}. Figures have not been saved. "
+                f"Please, use a child of {IFigureLogger}."
+            )
+            return
+
         for fig, metric_log_str in zip(*self.metric.visualize()):
             log_str = f"{LOG_IMAGE_FOLDER}/{metric_log_str}"
-
-            assert isinstance(pl_module.logger, IPipelineLogger)
             pl_module.logger.log_figure(fig=fig, title=log_str, idx=pl_module.current_epoch)
-
             plt.close(fig=fig)

     def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
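With this change an incompatible logger no longer trips an assert: the callback warns and skips figure logging while metrics continue as usual. A minimal sketch of the new control flow (PlainLogger is a hypothetical stand-in, not from the commit):

import warnings

from oml.interfaces.loggers import IFigureLogger


class PlainLogger:  # hypothetical: a logger without figure support
    pass


logger = PlainLogger()
if not isinstance(logger, IFigureLogger):
    # mirrors the new branch in MetricValCallback._log_images
    warnings.warn(f"Unexpected logger {logger}. Figures have not been saved.")
else:
    ...  # figures would be sent through logger.log_figure(...) here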

oml/lightning/pipelines/logging.py (+6 -3)

@@ -27,7 +27,7 @@ def prepare_tags(cfg: TCfg) -> List[str]:


 class NeptunePipelineLogger(NeptuneLogger, IPipelineLogger):
-    def log_experiment_info(self, cfg: TCfg) -> None:
+    def log_pipeline_info(self, cfg: TCfg) -> None:
         warnings.warn(
             "Unfortunately, in the case of using Neptune, you may experience that long experiments are "
             "stacked and not responding. It's not an issue on OML's side, so, we cannot fix it."
@@ -57,7 +57,7 @@ def log_figure(self, fig: plt.Figure, title: str, idx: int) -> None:


 class WandBPipelineLogger(WandbLogger, IPipelineLogger):
-    def log_experiment_info(self, cfg: TCfg) -> None:
+    def log_pipeline_info(self, cfg: TCfg) -> None:
         # this is the optional dependency
         import wandb

@@ -90,9 +90,12 @@ def log_figure(self, fig: plt.Figure, title: str, idx: int) -> None:


 class TensorBoardPipelineLogger(TensorBoardLogger, IPipelineLogger):
-    def log_experiment_info(self, cfg: TCfg) -> None:
+    def log_pipeline_info(self, cfg: TCfg) -> None:
         pass

     def log_figure(self, fig: plt.Figure, title: str, idx: int) -> None:
         fig_img = figure_to_nparray(fig)
         self.experiment.add_image(title, np.transpose(fig_img, (2, 0, 1)), idx)
+
+
+__all__ = ["IPipelineLogger", "TensorBoardPipelineLogger", "WandBPipelineLogger", "NeptunePipelineLogger"]
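A quick smoke test of the renamed API, assuming a working OML installation (a sketch; the figure content is arbitrary):

import matplotlib.pyplot as plt

from oml.lightning.pipelines.logging import TensorBoardPipelineLogger

logger = TensorBoardPipelineLogger(".")  # writes TensorBoard event files to the current dir
logger.log_pipeline_info(cfg={})         # no-op for TensorBoard, per the diff above

fig = plt.figure()
plt.plot([0, 1], [1, 0])
logger.log_figure(fig=fig, title="demo/figure", idx=0)
plt.close(fig)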

oml/lightning/pipelines/train.py (+1 -1)

@@ -72,7 +72,7 @@ def extractor_training_pipeline(cfg: TCfg) -> None:
     pprint(cfg)

     logger = parse_logger_from_config(cfg)
-    logger.log_experiment_info(cfg)
+    logger.log_pipeline_info(cfg)

     trainer_engine_params = parse_engine_params_from_config(cfg)
     is_ddp = check_is_config_for_ddp(trainer_engine_params)

oml/lightning/pipelines/train_postprocessor.py (+1 -1)

@@ -120,7 +120,7 @@ def postprocessor_training_pipeline(cfg: DictConfig) -> None:
     pprint(cfg)

     logger = parse_logger_from_config(cfg)
-    logger.log_experiment_info(cfg)
+    logger.log_pipeline_info(cfg)

     trainer_engine_params = parse_engine_params_from_config(cfg)
     is_ddp = check_is_config_for_ddp(trainer_engine_params)
