Skip to content

Commit cd9fc85

Browse files
authored
Removed MetricValCallbackDDP and samples_in_getitem
* `MetricValCallback` is enough to handle DDP.
* `samples_in_getitem` is not used.
1 parent 8fabaa8 commit cd9fc85

File tree

9 files changed

+44
-99
lines changed

9 files changed

+44
-99
lines changed

docs/readme/examples_source/extractor/train_val_pl_ddp.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ from torch.optim import SGD
1111

1212
from oml.datasets import ImageQueryGalleryLabeledDataset, ImageLabeledDataset
1313
from oml.lightning.modules.extractor import ExtractorModuleDDP
14-
from oml.lightning.callbacks.metric import MetricValCallbackDDP
14+
from oml.lightning.callbacks.metric import MetricValCallback
1515
from oml.losses.triplet import TripletLossWithMiner
1616
from oml.metrics.embeddings import EmbeddingMetrics
1717
from oml.miners.inbatch_all_tri import AllTripletsMiner
@@ -35,7 +35,7 @@ train_loader = DataLoader(train_dataset, batch_sampler=batch_sampler)
3535
# val
3636
val_dataset = ImageQueryGalleryLabeledDataset(df_val)
3737
val_loader = DataLoader(val_dataset, batch_size=4)
38-
metric_callback = MetricValCallbackDDP(metric=EmbeddingMetrics(dataset=val_dataset)) # DDP specific
38+
metric_callback = MetricValCallback(metric=EmbeddingMetrics(dataset=val_dataset))
3939

4040
# run
4141
pl_model = ExtractorModuleDDP(extractor=extractor, criterion=criterion, optimizer=optimizer,

docs/source/contents/ddp.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ ExtractorModuleDDP
2626

2727
.. automethod:: __init__
2828

29-
MetricValCallbackDDP
29+
PairwiseModuleDDP
3030
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
31-
.. autoclass:: oml.lightning.callbacks.metric.MetricValCallbackDDP
31+
.. autoclass:: oml.lightning.modules.pairwise_postprocessing.PairwiseModuleDDP
3232
:undoc-members:
3333
:show-inheritance:
3434

oml/lightning/callbacks/metric.py

+19-52
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,12 @@ def __init__(
2727
metric: IBasicMetric,
2828
log_images: bool = False,
2929
loader_idx: int = 0,
30-
samples_in_getitem: int = 1,
3130
):
3231
"""
3332
Args:
3433
metric: Metric
3534
log_images: Set ``True`` if you want to have visual logging
3635
loader_idx: Idx of the loader to calculate metric for
37-
samples_in_getitem: Some of the datasets return several samples when calling ``__getitem__``,
38-
so we need to handle it for the proper calculation. For most of the cases this value equals to 1,
39-
but for the dataset which explicitly return triplets, this value must be equal to 3,
40-
for a dataset of pairs it must be equal to 2.
4136
4237
"""
4338

@@ -46,7 +41,6 @@ def __init__(
4641
assert not log_images or (isinstance(metric, IMetricVisualisable) and metric.ready_to_visualize())
4742

4843
self.loader_idx = loader_idx
49-
self.samples_in_getitem = samples_in_getitem
5044

5145
self._expected_samples = 0
5246
self._collected_samples = 0
@@ -56,7 +50,11 @@ def _calc_expected_samples(self, trainer: pl.Trainer, dataloader_idx: int = 0) -
5650
loaders = (
5751
[trainer.val_dataloaders] if isinstance(trainer.val_dataloaders, DataLoader) else trainer.val_dataloaders
5852
)
59-
return self.samples_in_getitem * len(loaders[dataloader_idx].dataset)
53+
len_dataset = len(loaders[dataloader_idx].dataset)
54+
if trainer.world_size > 1:
55+
# we use padding in DDP and sequential sampler for validation
56+
len_dataset = ceil(len_dataset / trainer.world_size)
57+
return len_dataset
6058

6159
def on_validation_batch_start(
6260
self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
@@ -128,12 +126,23 @@ def _raise_computation_error(self) -> Exception:
128126
raise ValueError(
129127
f"Incorrect calculation for {self.metric.__class__.__name__} metric. "
130128
f"Inconsistent number of samples, obtained: {self._collected_samples}, "
131-
f"expected: {self._expected_samples}, "
132-
f"'samples_in_getitem': {self.samples_in_getitem}.\n"
129+
f"expected: {self._expected_samples}. "
133130
f"Make sure that you don't use the 'overfit_batches' parameter in 'pl.Trainer' and "
134131
f"you set 'drop_last=False'. The idea is that lengths of dataset and dataloader must match."
135132
)
136133

134+
@staticmethod
135+
def _check_loaders(trainer: "pl.Trainer") -> None:
136+
if trainer.world_size > 1 and trainer.val_dataloaders is not None:
137+
if not check_loaders_is_patched(trainer.val_dataloaders):
138+
raise RuntimeError(err_message_loaders_is_not_patched)
139+
140+
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
141+
self._check_loaders(trainer)
142+
143+
def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
144+
self._check_loaders(trainer)
145+
137146

138147
err_message_loaders_is_not_patched = (
139148
"\nExperiment is runned in DDP mode, but some of validation dataloaders is not patched. Metric callback will "
@@ -150,46 +159,4 @@ def _raise_computation_error(self) -> Exception:
150159
f"5) Turn off the 'overfit_batches' parameter in 'pl.Trainer'."
151160
)
152161

153-
154-
class MetricValCallbackDDP(MetricValCallback):
155-
"""
156-
This is an extension to the regular callback that takes into account data reduction and padding
157-
on the inference for each device in DDP setup
158-
159-
"""
160-
161-
metric: IBasicMetric
162-
163-
def __init__(self, metric: IBasicMetric, *args: Any, **kwargs: Any):
164-
super().__init__(metric, *args, **kwargs)
165-
166-
def _calc_expected_samples(self, trainer: pl.Trainer, dataloader_idx: int = 0) -> int:
167-
loaders = (
168-
[trainer.val_dataloaders] if isinstance(trainer.val_dataloaders, DataLoader) else trainer.val_dataloaders
169-
)
170-
len_dataset = len(loaders[dataloader_idx].dataset)
171-
if trainer.world_size > 1:
172-
# we use padding in DDP and sequential sampler for validation
173-
len_dataset = ceil(len_dataset / trainer.world_size)
174-
return self.samples_in_getitem * len_dataset
175-
176-
def calc_and_log_metrics(self, pl_module: pl.LightningModule) -> None:
177-
# TODO: optimize to avoid duplication of metrics on all devices.
178-
# Note: if we calculate metric only on main device, we need to log (!!!) metric for all devices,
179-
# because they need this metric for checkpointing
180-
return super().calc_and_log_metrics(pl_module=pl_module)
181-
182-
@staticmethod
183-
def _check_loaders(trainer: "pl.Trainer") -> None:
184-
if trainer.world_size > 1 and trainer.val_dataloaders is not None:
185-
if not check_loaders_is_patched(trainer.val_dataloaders):
186-
raise RuntimeError(err_message_loaders_is_not_patched)
187-
188-
def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
189-
self._check_loaders(trainer)
190-
191-
def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
192-
self._check_loaders(trainer)
193-
194-
195-
__all__ = ["MetricValCallback", "MetricValCallbackDDP"]
162+
__all__ = ["MetricValCallback"]

oml/lightning/pipelines/train.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from oml.const import TCfg
99
from oml.datasets.images import get_retrieval_images_datasets
10-
from oml.lightning.callbacks.metric import MetricValCallback, MetricValCallbackDDP
10+
from oml.lightning.callbacks.metric import MetricValCallback
1111
from oml.lightning.modules.extractor import ExtractorModule, ExtractorModuleDDP
1212
from oml.lightning.pipelines.parser import (
1313
check_is_config_for_ddp,
@@ -109,11 +109,7 @@ def extractor_training_pipeline(cfg: TCfg) -> None:
109109
**cfg.get("metric_args", {}),
110110
)
111111

112-
metrics_clb_constructor = MetricValCallbackDDP if is_ddp else MetricValCallback
113-
metrics_clb = metrics_clb_constructor(
114-
metric=metrics_calc,
115-
log_images=cfg.get("log_images", False),
116-
)
112+
metrics_clb = MetricValCallback(metric=metrics_calc, log_images=cfg.get("log_images", False))
117113

118114
trainer = pl.Trainer(
119115
max_epochs=cfg["max_epochs"],

oml/lightning/pipelines/train_postprocessor.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from oml.datasets.images import get_retrieval_images_datasets
1515
from oml.inference import inference, inference_cached
1616
from oml.interfaces.models import IPairwiseModel
17-
from oml.lightning.callbacks.metric import MetricValCallback, MetricValCallbackDDP
17+
from oml.lightning.callbacks.metric import MetricValCallback
1818
from oml.lightning.modules.pairwise_postprocessing import (
1919
PairwiseModule,
2020
PairwiseModuleDDP,
@@ -165,8 +165,7 @@ def postprocessor_training_pipeline(cfg: DictConfig) -> None:
165165
**cfg.get("metric_args", {}),
166166
)
167167

168-
metrics_clb_constructor = MetricValCallbackDDP if is_ddp else MetricValCallback
169-
metrics_clb = metrics_clb_constructor(metric=metrics_calc, log_images=cfg.get("log_images", True))
168+
metrics_clb = MetricValCallback(metric=metrics_calc, log_images=cfg.get("log_images", True))
170169

171170
trainer = pl.Trainer(
172171
max_epochs=cfg["max_epochs"],

oml/lightning/pipelines/validate.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from oml.const import TCfg
99
from oml.datasets.images import get_retrieval_images_datasets
10-
from oml.lightning.callbacks.metric import MetricValCallback, MetricValCallbackDDP
10+
from oml.lightning.callbacks.metric import MetricValCallback
1111
from oml.lightning.modules.extractor import ExtractorModule, ExtractorModuleDDP
1212
from oml.lightning.pipelines.parser import (
1313
check_is_config_for_ddp,
@@ -73,11 +73,7 @@ def extractor_validation_pipeline(cfg: TCfg) -> Tuple[pl.Trainer, Dict[str, Any]
7373
postprocessor=postprocessor,
7474
**cfg.get("metric_args", {}),
7575
)
76-
metrics_clb_constructor = MetricValCallbackDDP if is_ddp else MetricValCallback
77-
clb_metric = metrics_clb_constructor(
78-
metric=metrics_calc,
79-
log_images=False,
80-
)
76+
clb_metric = MetricValCallback(metric=metrics_calc, log_images=False)
8177

8278
trainer = pl.Trainer(callbacks=[clb_metric], precision=cfg.get("precision", 32), **trainer_engine_params)
8379

tests/test_integrations/test_lightning/test_pipeline.py

+7-22
Original file line numberDiff line numberDiff line change
@@ -91,25 +91,14 @@ def create_retrieval_dataloader(
9191
return train_retrieval_loader
9292

9393

94-
def create_retrieval_callback(
95-
dataset: IQueryGalleryLabeledDataset, loader_idx: int, samples_in_getitem: int
96-
) -> MetricValCallback:
94+
def create_retrieval_callback(dataset: IQueryGalleryLabeledDataset, loader_idx: int) -> MetricValCallback:
9795
metric = EmbeddingMetrics(dataset=dataset)
98-
metric_callback = MetricValCallback(metric=metric, loader_idx=loader_idx, samples_in_getitem=samples_in_getitem)
96+
metric_callback = MetricValCallback(metric=metric, loader_idx=loader_idx)
9997
return metric_callback
10098

10199

102-
@pytest.mark.parametrize(
103-
"samples_in_getitem, is_error_expected, pipeline",
104-
[
105-
(1, False, "retrieval"),
106-
(2, True, "retrieval"),
107-
],
108-
)
109100
@pytest.mark.parametrize("num_dataloaders", [1, 2])
110-
def test_lightning(
111-
samples_in_getitem: int, is_error_expected: bool, num_dataloaders: int, pipeline: str, num_workers: int
112-
) -> None:
101+
def test_lightning(num_dataloaders: int, num_workers: int) -> None:
113102
num_samples = 12
114103
im_size = 6
115104
n_labels = 2
@@ -128,7 +117,8 @@ def test_lightning(
128117
]
129118
callbacks = [
130119
create_retrieval_callback(
131-
dataset=val_dataloaders[k].dataset, loader_idx=k, samples_in_getitem=samples_in_getitem
120+
dataset=val_dataloaders[k].dataset,
121+
loader_idx=k,
132122
)
133123
for k in range(num_dataloaders)
134124
]
@@ -143,10 +133,5 @@ def test_lightning(
143133
num_sanity_val_steps=0,
144134
)
145135

146-
if is_error_expected:
147-
with pytest.raises(ValueError, match=callbacks[0].metric.__class__.__name__):
148-
trainer.fit(model=lightning_module, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)
149-
trainer.validate(model=lightning_module, dataloaders=val_dataloaders)
150-
else:
151-
trainer.fit(model=lightning_module, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)
152-
trainer.validate(model=lightning_module, dataloaders=val_dataloaders)
136+
trainer.fit(model=lightning_module, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)
137+
trainer.validate(model=lightning_module, dataloaders=val_dataloaders)

tests/test_oml/test_metrics/test_embedding_metrics.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,10 @@ def run_retrieval_metrics(case) -> None: # type: ignore
129129
compare_dicts_recursively(gt_metrics, metrics)
130130

131131
# the euclidean distance between any one-hots is always sqrt(2) or 0
132+
distances = calc.retrieval_results.distances # type: ignore
132133
assert (
133-
torch.isclose(calc.retrieval_results.distances, torch.tensor([0.0])).any()
134-
or torch.isclose(calc.retrieval_results.distances, torch.tensor([math.sqrt(2)])).any()
134+
torch.isclose(distances, torch.tensor([0.0])).any()
135+
or torch.isclose(distances, torch.tensor([math.sqrt(2)])).any()
135136
)
136137

137138
assert calc.acc.collected_samples == num_samples
@@ -166,9 +167,10 @@ def run_across_epochs(case) -> None: # type: ignore
166167
assert compare_dicts_recursively(metrics_all_epochs[0], metrics_all_epochs[-1])
167168

168169
# the euclidean distance between any one-hots is always sqrt(2) or 0
170+
distances = calc.retrieval_results.distances # type: ignore
169171
assert (
170-
torch.isclose(calc.retrieval_results.distances, torch.tensor([0.0])).any()
171-
or torch.isclose(calc.retrieval_results.distances, torch.tensor([math.sqrt(2)])).any()
172+
torch.isclose(distances, torch.tensor([0.0])).any()
173+
or torch.isclose(distances, torch.tensor([math.sqrt(2)])).any()
172174
)
173175

174176
assert calc.acc.collected_samples == num_samples

tests/test_runs/test_ddp_cases/run_retrieval_experiment_ddp.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from oml.const import INDEX_KEY, INPUT_TENSORS_KEY, LABELS_KEY, TMP_PATH
1414
from oml.ddp.utils import sync_dicts_ddp
15-
from oml.lightning.callbacks.metric import MetricValCallbackDDP
15+
from oml.lightning.callbacks.metric import MetricValCallback
1616
from oml.lightning.modules.ddp import ModuleDDP
1717
from oml.lightning.pipelines.parser import parse_engine_params_from_config
1818
from oml.losses.arcface import ArcFaceLoss
@@ -138,7 +138,7 @@ def on_train_end(self) -> None:
138138
torch.save(self.model, self.save_path_ckpt_pattern.format(experiment=self.exp_num))
139139

140140

141-
class MetricValCallbackWithSaving(MetricValCallbackDDP):
141+
class MetricValCallbackWithSaving(MetricValCallback):
142142
"""
143143
We add saving of metrics for later comparison
144144
"""

0 commit comments

Comments (0)