Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove MetricValCallbackDDP and samples_in_getitem #560

Merged
merged 3 commits into from
May 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/readme/examples_source/extractor/train_val_pl_ddp.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ from torch.optim import SGD

from oml.datasets import ImageQueryGalleryLabeledDataset, ImageLabeledDataset
from oml.lightning.modules.extractor import ExtractorModuleDDP
from oml.lightning.callbacks.metric import MetricValCallbackDDP
from oml.lightning.callbacks.metric import MetricValCallback
from oml.losses.triplet import TripletLossWithMiner
from oml.metrics.embeddings import EmbeddingMetrics
from oml.miners.inbatch_all_tri import AllTripletsMiner
Expand All @@ -35,7 +35,7 @@ train_loader = DataLoader(train_dataset, batch_sampler=batch_sampler)
# val
val_dataset = ImageQueryGalleryLabeledDataset(df_val)
val_loader = DataLoader(val_dataset, batch_size=4)
metric_callback = MetricValCallbackDDP(metric=EmbeddingMetrics(dataset=val_dataset)) # DDP specific
metric_callback = MetricValCallback(metric=EmbeddingMetrics(dataset=val_dataset))

# run
pl_model = ExtractorModuleDDP(extractor=extractor, criterion=criterion, optimizer=optimizer,
Expand Down
4 changes: 2 additions & 2 deletions docs/source/contents/ddp.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ ExtractorModuleDDP

.. automethod:: __init__

MetricValCallbackDDP
PairwiseModuleDDP
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: oml.lightning.callbacks.metric.MetricValCallbackDDP
.. autoclass:: oml.lightning.modules.pairwise_postprocessing.PairwiseModuleDDP
:undoc-members:
:show-inheritance:

Expand Down
71 changes: 19 additions & 52 deletions oml/lightning/callbacks/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,12 @@ def __init__(
metric: IBasicMetric,
log_images: bool = False,
loader_idx: int = 0,
samples_in_getitem: int = 1,
):
"""
Args:
metric: Metric
log_images: Set ``True`` if you want to have visual logging
loader_idx: Idx of the loader to calculate metric for
samples_in_getitem: Some of the datasets return several samples when calling ``__getitem__``,
so we need to handle it for the proper calculation. For most of the cases this value equals to 1,
but for the dataset which explicitly return triplets, this value must be equal to 3,
for a dataset of pairs it must be equal to 2.

"""

Expand All @@ -46,7 +41,6 @@ def __init__(
assert not log_images or (isinstance(metric, IMetricVisualisable) and metric.ready_to_visualize())

self.loader_idx = loader_idx
self.samples_in_getitem = samples_in_getitem

self._expected_samples = 0
self._collected_samples = 0
Expand All @@ -56,7 +50,11 @@ def _calc_expected_samples(self, trainer: pl.Trainer, dataloader_idx: int = 0) -
loaders = (
[trainer.val_dataloaders] if isinstance(trainer.val_dataloaders, DataLoader) else trainer.val_dataloaders
)
return self.samples_in_getitem * len(loaders[dataloader_idx].dataset)
len_dataset = len(loaders[dataloader_idx].dataset)
if trainer.world_size > 1:
# we use padding in DDP and sequential sampler for validation
len_dataset = ceil(len_dataset / trainer.world_size)
return len_dataset

def on_validation_batch_start(
self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
Expand Down Expand Up @@ -128,12 +126,23 @@ def _raise_computation_error(self) -> Exception:
raise ValueError(
f"Incorrect calculation for {self.metric.__class__.__name__} metric. "
f"Inconsistent number of samples, obtained: {self._collected_samples}, "
f"expected: {self._expected_samples}, "
f"'samples_in_getitem': {self.samples_in_getitem}.\n"
f"expected: {self._expected_samples}. "
f"Make sure that you don't use the 'overfit_batches' parameter in 'pl.Trainer' and "
f"you set 'drop_last=False'. The idea is that lengths of dataset and dataloader must match."
)

@staticmethod
def _check_loaders(trainer: "pl.Trainer") -> None:
if trainer.world_size > 1 and trainer.val_dataloaders is not None:
if not check_loaders_is_patched(trainer.val_dataloaders):
raise RuntimeError(err_message_loaders_is_not_patched)

def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._check_loaders(trainer)

def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._check_loaders(trainer)


err_message_loaders_is_not_patched = (
"\nExperiment is runned in DDP mode, but some of validation dataloaders is not patched. Metric callback will "
Expand All @@ -150,46 +159,4 @@ def _raise_computation_error(self) -> Exception:
f"5) Turn off the 'overfit_batches' parameter in 'pl.Trainer'."
)


class MetricValCallbackDDP(MetricValCallback):
"""
This is an extension to the regular callback that takes into account data reduction and padding
on the inference for each device in DDP setup

"""

metric: IBasicMetric

def __init__(self, metric: IBasicMetric, *args: Any, **kwargs: Any):
super().__init__(metric, *args, **kwargs)

def _calc_expected_samples(self, trainer: pl.Trainer, dataloader_idx: int = 0) -> int:
loaders = (
[trainer.val_dataloaders] if isinstance(trainer.val_dataloaders, DataLoader) else trainer.val_dataloaders
)
len_dataset = len(loaders[dataloader_idx].dataset)
if trainer.world_size > 1:
# we use padding in DDP and sequential sampler for validation
len_dataset = ceil(len_dataset / trainer.world_size)
return self.samples_in_getitem * len_dataset

def calc_and_log_metrics(self, pl_module: pl.LightningModule) -> None:
# TODO: optimize to avoid duplication of metrics on all devices.
# Note: if we calculate metric only on main device, we need to log (!!!) metric for all devices,
# because they need this metric for checkpointing
return super().calc_and_log_metrics(pl_module=pl_module)

@staticmethod
def _check_loaders(trainer: "pl.Trainer") -> None:
if trainer.world_size > 1 and trainer.val_dataloaders is not None:
if not check_loaders_is_patched(trainer.val_dataloaders):
raise RuntimeError(err_message_loaders_is_not_patched)

def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._check_loaders(trainer)

def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._check_loaders(trainer)


__all__ = ["MetricValCallback", "MetricValCallbackDDP"]
__all__ = ["MetricValCallback"]
8 changes: 2 additions & 6 deletions oml/lightning/pipelines/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from oml.const import TCfg
from oml.datasets.images import get_retrieval_images_datasets
from oml.lightning.callbacks.metric import MetricValCallback, MetricValCallbackDDP
from oml.lightning.callbacks.metric import MetricValCallback
from oml.lightning.modules.extractor import ExtractorModule, ExtractorModuleDDP
from oml.lightning.pipelines.parser import (
check_is_config_for_ddp,
Expand Down Expand Up @@ -109,11 +109,7 @@ def extractor_training_pipeline(cfg: TCfg) -> None:
**cfg.get("metric_args", {}),
)

metrics_clb_constructor = MetricValCallbackDDP if is_ddp else MetricValCallback
metrics_clb = metrics_clb_constructor(
metric=metrics_calc,
log_images=cfg.get("log_images", False),
)
metrics_clb = MetricValCallback(metric=metrics_calc, log_images=cfg.get("log_images", False))

trainer = pl.Trainer(
max_epochs=cfg["max_epochs"],
Expand Down
5 changes: 2 additions & 3 deletions oml/lightning/pipelines/train_postprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from oml.datasets.images import get_retrieval_images_datasets
from oml.inference import inference, inference_cached
from oml.interfaces.models import IPairwiseModel
from oml.lightning.callbacks.metric import MetricValCallback, MetricValCallbackDDP
from oml.lightning.callbacks.metric import MetricValCallback
from oml.lightning.modules.pairwise_postprocessing import (
PairwiseModule,
PairwiseModuleDDP,
Expand Down Expand Up @@ -165,8 +165,7 @@ def postprocessor_training_pipeline(cfg: DictConfig) -> None:
**cfg.get("metric_args", {}),
)

metrics_clb_constructor = MetricValCallbackDDP if is_ddp else MetricValCallback
metrics_clb = metrics_clb_constructor(metric=metrics_calc, log_images=cfg.get("log_images", True))
metrics_clb = MetricValCallback(metric=metrics_calc, log_images=cfg.get("log_images", True))

trainer = pl.Trainer(
max_epochs=cfg["max_epochs"],
Expand Down
8 changes: 2 additions & 6 deletions oml/lightning/pipelines/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from oml.const import TCfg
from oml.datasets.images import get_retrieval_images_datasets
from oml.lightning.callbacks.metric import MetricValCallback, MetricValCallbackDDP
from oml.lightning.callbacks.metric import MetricValCallback
from oml.lightning.modules.extractor import ExtractorModule, ExtractorModuleDDP
from oml.lightning.pipelines.parser import (
check_is_config_for_ddp,
Expand Down Expand Up @@ -73,11 +73,7 @@ def extractor_validation_pipeline(cfg: TCfg) -> Tuple[pl.Trainer, Dict[str, Any]
postprocessor=postprocessor,
**cfg.get("metric_args", {}),
)
metrics_clb_constructor = MetricValCallbackDDP if is_ddp else MetricValCallback
clb_metric = metrics_clb_constructor(
metric=metrics_calc,
log_images=False,
)
clb_metric = MetricValCallback(metric=metrics_calc, log_images=False)

trainer = pl.Trainer(callbacks=[clb_metric], precision=cfg.get("precision", 32), **trainer_engine_params)

Expand Down
29 changes: 7 additions & 22 deletions tests/test_integrations/test_lightning/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,25 +91,14 @@ def create_retrieval_dataloader(
return train_retrieval_loader


def create_retrieval_callback(
dataset: IQueryGalleryLabeledDataset, loader_idx: int, samples_in_getitem: int
) -> MetricValCallback:
def create_retrieval_callback(dataset: IQueryGalleryLabeledDataset, loader_idx: int) -> MetricValCallback:
metric = EmbeddingMetrics(dataset=dataset)
metric_callback = MetricValCallback(metric=metric, loader_idx=loader_idx, samples_in_getitem=samples_in_getitem)
metric_callback = MetricValCallback(metric=metric, loader_idx=loader_idx)
return metric_callback


@pytest.mark.parametrize(
"samples_in_getitem, is_error_expected, pipeline",
[
(1, False, "retrieval"),
(2, True, "retrieval"),
],
)
@pytest.mark.parametrize("num_dataloaders", [1, 2])
def test_lightning(
samples_in_getitem: int, is_error_expected: bool, num_dataloaders: int, pipeline: str, num_workers: int
) -> None:
def test_lightning(num_dataloaders: int, num_workers: int) -> None:
num_samples = 12
im_size = 6
n_labels = 2
Expand All @@ -128,7 +117,8 @@ def test_lightning(
]
callbacks = [
create_retrieval_callback(
dataset=val_dataloaders[k].dataset, loader_idx=k, samples_in_getitem=samples_in_getitem
dataset=val_dataloaders[k].dataset,
loader_idx=k,
)
for k in range(num_dataloaders)
]
Expand All @@ -143,10 +133,5 @@ def test_lightning(
num_sanity_val_steps=0,
)

if is_error_expected:
with pytest.raises(ValueError, match=callbacks[0].metric.__class__.__name__):
trainer.fit(model=lightning_module, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)
trainer.validate(model=lightning_module, dataloaders=val_dataloaders)
else:
trainer.fit(model=lightning_module, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)
trainer.validate(model=lightning_module, dataloaders=val_dataloaders)
trainer.fit(model=lightning_module, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders)
trainer.validate(model=lightning_module, dataloaders=val_dataloaders)
10 changes: 6 additions & 4 deletions tests/test_oml/test_metrics/test_embedding_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,10 @@ def run_retrieval_metrics(case) -> None: # type: ignore
compare_dicts_recursively(gt_metrics, metrics)

# the euclidean distance between any one-hots is always sqrt(2) or 0
distances = calc.retrieval_results.distances # type: ignore
assert (
torch.isclose(calc.retrieval_results.distances, torch.tensor([0.0])).any()
or torch.isclose(calc.retrieval_results.distances, torch.tensor([math.sqrt(2)])).any()
torch.isclose(distances, torch.tensor([0.0])).any()
or torch.isclose(distances, torch.tensor([math.sqrt(2)])).any()
)

assert calc.acc.collected_samples == num_samples
Expand Down Expand Up @@ -166,9 +167,10 @@ def run_across_epochs(case) -> None: # type: ignore
assert compare_dicts_recursively(metrics_all_epochs[0], metrics_all_epochs[-1])

# the euclidean distance between any one-hots is always sqrt(2) or 0
distances = calc.retrieval_results.distances # type: ignore
assert (
torch.isclose(calc.retrieval_results.distances, torch.tensor([0.0])).any()
or torch.isclose(calc.retrieval_results.distances, torch.tensor([math.sqrt(2)])).any()
torch.isclose(distances, torch.tensor([0.0])).any()
or torch.isclose(distances, torch.tensor([math.sqrt(2)])).any()
)

assert calc.acc.collected_samples == num_samples
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from oml.const import INDEX_KEY, INPUT_TENSORS_KEY, LABELS_KEY, TMP_PATH
from oml.ddp.utils import sync_dicts_ddp
from oml.lightning.callbacks.metric import MetricValCallbackDDP
from oml.lightning.callbacks.metric import MetricValCallback
from oml.lightning.modules.ddp import ModuleDDP
from oml.lightning.pipelines.parser import parse_engine_params_from_config
from oml.losses.arcface import ArcFaceLoss
Expand Down Expand Up @@ -138,7 +138,7 @@ def on_train_end(self) -> None:
torch.save(self.model, self.save_path_ckpt_pattern.format(experiment=self.exp_num))


class MetricValCallbackWithSaving(MetricValCallbackDDP):
class MetricValCallbackWithSaving(MetricValCallback):
"""
We add saving of metrics for later comparison
"""
Expand Down