@@ -27,17 +27,12 @@ def __init__(
         metric: IBasicMetric,
         log_images: bool = False,
         loader_idx: int = 0,
-        samples_in_getitem: int = 1,
     ):
         """
         Args:
             metric: Metric
             log_images: Set ``True`` if you want to have visual logging
             loader_idx: Idx of the loader to calculate metric for
-            samples_in_getitem: Some of the datasets return several samples when calling ``__getitem__``,
-                so we need to handle it for the proper calculation. For most of the cases this value equals to 1,
-                but for the dataset which explicitly return triplets, this value must be equal to 3,
-                for a dataset of pairs it must be equal to 2.
 
         """
 
@@ -46,7 +41,6 @@ def __init__(
         assert not log_images or (isinstance(metric, IMetricVisualisable) and metric.ready_to_visualize())
 
         self.loader_idx = loader_idx
-        self.samples_in_getitem = samples_in_getitem
 
         self._expected_samples = 0
         self._collected_samples = 0
@@ -56,7 +50,11 @@ def _calc_expected_samples(self, trainer: pl.Trainer, dataloader_idx: int = 0) -> int:
         loaders = (
             [trainer.val_dataloaders] if isinstance(trainer.val_dataloaders, DataLoader) else trainer.val_dataloaders
         )
-        return self.samples_in_getitem * len(loaders[dataloader_idx].dataset)
+        len_dataset = len(loaders[dataloader_idx].dataset)
+        if trainer.world_size > 1:
+            # we use padding in DDP and sequential sampler for validation
+            len_dataset = ceil(len_dataset / trainer.world_size)
+        return len_dataset
 
     def on_validation_batch_start(
         self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
@@ -128,12 +126,23 @@ def _raise_computation_error(self) -> Exception:
         raise ValueError(
             f"Incorrect calculation for {self.metric.__class__.__name__} metric. "
             f"Inconsistent number of samples, obtained: {self._collected_samples}, "
-            f"expected: {self._expected_samples}, "
-            f"'samples_in_getitem': {self.samples_in_getitem}.\n"
+            f"expected: {self._expected_samples}. "
             f"Make sure that you don't use the 'overfit_batches' parameter in 'pl.Trainer' and "
             f"you set 'drop_last=False'. The idea is that lengths of dataset and dataloader must match."
         )
 
+    @staticmethod
+    def _check_loaders(trainer: "pl.Trainer") -> None:
+        if trainer.world_size > 1 and trainer.val_dataloaders is not None:
+            if not check_loaders_is_patched(trainer.val_dataloaders):
+                raise RuntimeError(err_message_loaders_is_not_patched)
+
+    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        self._check_loaders(trainer)
+
+    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        self._check_loaders(trainer)
+
 
 err_message_loaders_is_not_patched = (
     "\nExperiment is runned in DDP mode, but some of validation dataloaders is not patched. Metric callback will "
@@ -150,46 +159,4 @@ def _raise_computation_error(self) -> Exception:
     f"5) Turn off the 'overfit_batches' parameter in 'pl.Trainer'."
 )
 
-
-class MetricValCallbackDDP(MetricValCallback):
-    """
-    This is an extension to the regular callback that takes into account data reduction and padding
-    on the inference for each device in DDP setup
-
-    """
-
-    metric: IBasicMetric
-
-    def __init__(self, metric: IBasicMetric, *args: Any, **kwargs: Any):
-        super().__init__(metric, *args, **kwargs)
-
-    def _calc_expected_samples(self, trainer: pl.Trainer, dataloader_idx: int = 0) -> int:
-        loaders = (
-            [trainer.val_dataloaders] if isinstance(trainer.val_dataloaders, DataLoader) else trainer.val_dataloaders
-        )
-        len_dataset = len(loaders[dataloader_idx].dataset)
-        if trainer.world_size > 1:
-            # we use padding in DDP and sequential sampler for validation
-            len_dataset = ceil(len_dataset / trainer.world_size)
-        return self.samples_in_getitem * len_dataset
-
-    def calc_and_log_metrics(self, pl_module: pl.LightningModule) -> None:
-        # TODO: optimize to avoid duplication of metrics on all devices.
-        # Note: if we calculate metric only on main device, we need to log (!!!) metric for all devices,
-        # because they need this metric for checkpointing
-        return super().calc_and_log_metrics(pl_module=pl_module)
-
-    @staticmethod
-    def _check_loaders(trainer: "pl.Trainer") -> None:
-        if trainer.world_size > 1 and trainer.val_dataloaders is not None:
-            if not check_loaders_is_patched(trainer.val_dataloaders):
-                raise RuntimeError(err_message_loaders_is_not_patched)
-
-    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        self._check_loaders(trainer)
-
-    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        self._check_loaders(trainer)
-
-
-__all__ = ["MetricValCallback", "MetricValCallbackDDP"]
+__all__ = ["MetricValCallback"]
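
The new `_calc_expected_samples` logic means that, under DDP with a sequential validation sampler and padding, each rank is expected to collect `ceil(len(dataset) / world_size)` samples rather than the full dataset length. A minimal sketch of that arithmetic follows; it is not code from the repository, just an illustration with a hypothetical helper name.

```python
from math import ceil


def expected_samples_per_rank(len_dataset: int, world_size: int) -> int:
    # Mirrors the logic added in the diff: a single process is expected to
    # see the whole dataset; under DDP the dataset is padded so that every
    # rank receives the same number of samples.
    if world_size > 1:
        return ceil(len_dataset / world_size)
    return len_dataset


# Example: 10 validation samples on 4 GPUs -> padded to 12, i.e. 3 per rank.
assert expected_samples_per_rank(10, 4) == 3
assert expected_samples_per_rank(10, 1) == 10
```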
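With `MetricValCallbackDDP` removed, the single `MetricValCallback` is meant to cover both single-device and DDP validation; in DDP mode it checks at `on_train_start` and `on_validation_start` that the validation dataloaders have been patched and raises `RuntimeError(err_message_loaders_is_not_patched)` otherwise. A hedged usage sketch under those assumptions is below; the `EmbeddingMetrics` metric and the import paths are illustrative and may differ in the actual package.

```python
import pytorch_lightning as pl

# Import paths below are assumptions for illustration.
from oml.lightning.callbacks.metric import MetricValCallback
from oml.metrics.embeddings import EmbeddingMetrics

# Any IBasicMetric implementation works; samples_in_getitem is no longer an argument.
metric = EmbeddingMetrics()
callback = MetricValCallback(metric=metric, log_images=False, loader_idx=0)

# The same callback is used on 1 GPU and in DDP; in the DDP case the validation
# dataloaders must be patched (e.g. by the library's DDP Lightning module),
# otherwise the callback raises the error message defined above.
trainer = pl.Trainer(devices=2, strategy="ddp", callbacks=[callback])
```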