@@ -27,17 +27,12 @@ def __init__(
         metric: IBasicMetric,
         log_images: bool = False,
         loader_idx: int = 0,
-        samples_in_getitem: int = 1,
     ):
         """
         Args:
            metric: Metric
            log_images: Set ``True`` if you want to have visual logging
            loader_idx: Idx of the loader to calculate metric for
-           samples_in_getitem: Some of the datasets return several samples when calling ``__getitem__``,
-               so we need to handle it for the proper calculation. For most of the cases this value equals to 1,
-               but for the dataset which explicitly return triplets, this value must be equal to 3,
-               for a dataset of pairs it must be equal to 2.
 
 
         """
@@ -46,7 +41,6 @@ def __init__(
         assert not log_images or (isinstance(metric, IMetricVisualisable) and metric.ready_to_visualize())
 
         self.loader_idx = loader_idx
-        self.samples_in_getitem = samples_in_getitem
 
         self._expected_samples = 0
         self._collected_samples = 0
@@ -56,7 +50,11 @@ def _calc_expected_samples(self, trainer: pl.Trainer, dataloader_idx: int = 0) -
         loaders = (
             [trainer.val_dataloaders] if isinstance(trainer.val_dataloaders, DataLoader) else trainer.val_dataloaders
         )
-        return self.samples_in_getitem * len(loaders[dataloader_idx].dataset)
+        len_dataset = len(loaders[dataloader_idx].dataset)
+        if trainer.world_size > 1:
+            # we use padding in DDP and sequential sampler for validation
+            len_dataset = ceil(len_dataset / trainer.world_size)
+        return len_dataset
 
     def on_validation_batch_start(
         self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0
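A minimal sketch of the arithmetic the hunk above implements (the helper name and the sample numbers are illustrative, not part of this diff): under DDP the validation set is padded and split by a sequential sampler, so each rank expects ceil(len(dataset) / world_size) samples instead of the full dataset length.

from math import ceil

def expected_samples_per_rank(len_dataset: int, world_size: int) -> int:
    # single process: every sample is seen exactly once
    if world_size <= 1:
        return len_dataset
    # DDP: the padded sequential sampler splits the dataset evenly across ranks
    return ceil(len_dataset / world_size)

assert expected_samples_per_rank(10, 1) == 10
assert expected_samples_per_rank(10, 4) == 3  # 10 samples padded to 12, i.e. 3 per rank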
@@ -128,12 +126,23 @@ def _raise_computation_error(self) -> Exception:
         raise ValueError(
             f"Incorrect calculation for {self.metric.__class__.__name__} metric. "
             f"Inconsistent number of samples, obtained: {self._collected_samples}, "
-            f"expected: {self._expected_samples}, "
-            f"'samples_in_getitem': {self.samples_in_getitem}.\n"
+            f"expected: {self._expected_samples}. "
             f"Make sure that you don't use the 'overfit_batches' parameter in 'pl.Trainer' and "
             f"you set 'drop_last=False'. The idea is that lengths of dataset and dataloader must match."
         )
 
+    @staticmethod
+    def _check_loaders(trainer: "pl.Trainer") -> None:
+        if trainer.world_size > 1 and trainer.val_dataloaders is not None:
+            if not check_loaders_is_patched(trainer.val_dataloaders):
+                raise RuntimeError(err_message_loaders_is_not_patched)
+
+    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        self._check_loaders(trainer)
+
+    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        self._check_loaders(trainer)
+
 
 err_message_loaders_is_not_patched = (
     "\nExperiment is runned in DDP mode, but some of validation dataloaders is not patched. Metric callback will "
@@ -150,46 +159,4 @@ def _raise_computation_error(self) -> Exception:
     f"5) Turn off the 'overfit_batches' parameter in 'pl.Trainer'."
 )
 
-
-class MetricValCallbackDDP(MetricValCallback):
-    """
-    This is an extension to the regular callback that takes into account data reduction and padding
-    on the inference for each device in DDP setup
-
-    """
-
-    metric: IBasicMetric
-
-    def __init__(self, metric: IBasicMetric, *args: Any, **kwargs: Any):
-        super().__init__(metric, *args, **kwargs)
-
-    def _calc_expected_samples(self, trainer: pl.Trainer, dataloader_idx: int = 0) -> int:
-        loaders = (
-            [trainer.val_dataloaders] if isinstance(trainer.val_dataloaders, DataLoader) else trainer.val_dataloaders
-        )
-        len_dataset = len(loaders[dataloader_idx].dataset)
-        if trainer.world_size > 1:
-            # we use padding in DDP and sequential sampler for validation
-            len_dataset = ceil(len_dataset / trainer.world_size)
-        return self.samples_in_getitem * len_dataset
-
-    def calc_and_log_metrics(self, pl_module: pl.LightningModule) -> None:
-        # TODO: optimize to avoid duplication of metrics on all devices.
-        # Note: if we calculate metric only on main device, we need to log (!!!) metric for all devices,
-        # because they need this metric for checkpointing
-        return super().calc_and_log_metrics(pl_module=pl_module)
-
-    @staticmethod
-    def _check_loaders(trainer: "pl.Trainer") -> None:
-        if trainer.world_size > 1 and trainer.val_dataloaders is not None:
-            if not check_loaders_is_patched(trainer.val_dataloaders):
-                raise RuntimeError(err_message_loaders_is_not_patched)
-
-    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        self._check_loaders(trainer)
-
-    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        self._check_loaders(trainer)
-
-
-__all__ = ["MetricValCallback", "MetricValCallbackDDP"]
+__all__ = ["MetricValCallback"]
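Migration note, sketched under the assumption that downstream code imports the callback from this module (the exact import path is not shown in the diff): the dedicated DDP class is removed, and the single MetricValCallback now covers both single-device and DDP runs.

# before this change (hypothetical downstream usage)
# from oml.lightning.callbacks.metric import MetricValCallbackDDP
# callback = MetricValCallbackDDP(metric=metric, samples_in_getitem=1)

# after this change
# from oml.lightning.callbacks.metric import MetricValCallback
# callback = MetricValCallback(metric=metric)  # DDP padding and loader checks are built in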