
Commit ce27d8e

Removing duplicates and preserving order in Accumulator
1 parent 9a1305e commit ce27d8e

9 files changed: +280 -95 lines changed

oml/inference/abstract.py (+2 -6)

@@ -7,11 +7,7 @@
 
 from oml.ddp.patching import patch_dataloader_to_ddp
 from oml.ddp.utils import get_world_size_safe, is_ddp, sync_dicts_ddp
-from oml.utils.misc_torch import (
-    drop_duplicates_by_ids,
-    get_device,
-    temporary_setting_model_mode,
-)
+from oml.utils.misc_torch import get_device, temporary_setting_model_mode, unique_by_ids
 
 
 @torch.no_grad()
@@ -53,7 +49,7 @@ def _inference(
     data_synced = sync_dicts_ddp(data_to_sync, world_size=get_world_size_safe())
     outputs, ids = data_synced["outputs"], data_synced["ids"]
 
-    ids, outputs = drop_duplicates_by_ids(ids=ids, data=outputs, sort=True)
+    ids, outputs = unique_by_ids(ids=ids, data=outputs)
 
     assert len(outputs) == len(dataset), "Data was not collected correctly after DDP sync."
     assert list(range(len(dataset))) == ids, "Data was not collected correctly after DDP sync."

oml/lightning/callbacks/metric.py (+2 -2)

@@ -8,7 +8,7 @@
 from pytorch_lightning.utilities.types import STEP_OUTPUT
 from torch.utils.data import DataLoader
 
-from oml.const import LOG_IMAGE_FOLDER
+from oml.const import INDEX_KEY, LOG_IMAGE_FOLDER
 from oml.ddp.patching import check_loaders_is_patched, patch_dataloader_to_ddp
 from oml.interfaces.loggers import IFigureLogger
 from oml.interfaces.metrics import IBasicMetric, IMetricDDP, IMetricVisualisable
@@ -83,7 +83,7 @@ def on_validation_batch_end(
         if dataloader_idx == self.loader_idx:
             assert self._ready_to_accumulate
 
-            self.metric.update_data(outputs)
+            self.metric.update_data(outputs, indices=outputs[INDEX_KEY].tolist())
 
             self._collected_samples += len(outputs[list(outputs.keys())[0]])
             if self._collected_samples > self._expected_samples:
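
The callback change above simply forwards the batch's global dataset positions, stored under INDEX_KEY, into the metric. A minimal sketch of that data flow; the batch keys other than INDEX_KEY and the constant's value are illustrative assumptions, not OML's actual defaults:

    import torch

    INDEX_KEY = "idx"  # assumed value for illustration; OML defines the real constant in oml.const

    # what a validation batch output might look like
    outputs = {
        "embeddings": torch.randn(4, 16),
        "labels": torch.tensor([0, 0, 1, 1]),
        INDEX_KEY: torch.tensor([12, 5, 40, 3]),  # global dataset positions of this batch
    }

    # equivalent to the new call in on_validation_batch_end()
    indices = outputs[INDEX_KEY].tolist()  # -> [12, 5, 40, 3]
    # self.metric.update_data(outputs, indices=indices)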

oml/metrics/accumulation.py (+64 -25)

@@ -1,16 +1,17 @@
-from typing import Any, Dict, List, Sequence, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
-from torch.distributed import get_world_size
+from torch import Tensor
 
-from oml.ddp.utils import is_ddp, sync_dicts_ddp
+from oml.ddp.utils import get_world_size_safe, sync_dicts_ddp
+from oml.utils.misc_torch import unique_by_ids
 
-TStorage = Dict[str, Union[torch.Tensor, np.ndarray, List[Any]]]
+TStorage = Dict[str, Union[Tensor, np.ndarray, List[Any]]]
 
 
 class Accumulator:
-    def __init__(self, keys_to_accumulate: Sequence[str]):
+    def __init__(self, keys_to_accumulate: Tuple[str, ...]):
         """
         Class for accumulating values of different types, for instance,
         torch.Tensor and numpy.array.
@@ -27,12 +28,15 @@ def __init__(self, keys_to_accumulate: Sequence[str]):
         self._collected_samples = 0
         self._storage: TStorage = dict()
 
+        self._indices_key = "__element_indices"  # internal key to keep track of elements order if provided
+
     def refresh(self, num_samples: int) -> None:
         """
         This method refreshes the state.
 
         Args:
-            num_samples: The total number of elements you are going to collect (for memory allocation)
+            num_samples: The total number of elements you are going to collect (for memory allocation).
+
         """
         assert isinstance(num_samples, int) and num_samples > 0
         self.num_samples = num_samples  # type: ignore
@@ -75,20 +79,39 @@ def _put_in_storage(self, key: str, batch_value: Any) -> None:
         else:
             raise TypeError(f"Type '{type(batch_value)}' is not available for accumulating")
 
-    def update_data(self, data_dict: Dict[str, Any]) -> None:
+    def update_data(self, data_dict: Dict[str, Any], indices: Optional[List[int]] = None) -> None:
         """
         Args:
-            data_dict: We will accumulate data getting values via ``self.keys_to_accumulate``.
+            data_dict: We will accumulate data getting values via ``self.keys_to_accumulate``. All elements
+                of the dictionary have to have the same size.
+            indices: Global indices of the elements in your batch of data. If provided, the accumulator
+                will remove accumulated duplicates and return the elements in the sorted order after ``.sync()``.
+                Indices may be useful in DDP (because data is gathered shuffled, additionally you may also get
+                some duplicates due to padding). In the single device regime it's also useful if you accumulate
+                data in shuffled order.
 
         """
-        bs_values = [len(data_dict[k]) for k in self.keys_to_accumulate]
+        keys = list(self.keys_to_accumulate)
+
+        if indices is None:
+            assert self._indices_key not in self.storage, "We are tracking ids, but they are not currently provided."
+        else:
+            assert isinstance(indices, List)
+            if (self.collected_samples > 0) and (self._indices_key not in self.storage):
+                raise RuntimeError("You provided ids, but seems like you had not done it before.")
+
+            keys += [self._indices_key]
+            data_dict[self._indices_key] = indices
+
+        bs_values = [len(data_dict[k]) for k in keys]
         bs = bs_values[0]
         assert all(bs == bs_value for bs_value in bs_values), f"Lengths of data are not equal, lengths: {bs_values}"
 
-        for k in self.keys_to_accumulate:
+        for k in keys:
            v = data_dict[k]
            self._allocate_memory_if_need(k, v)
            self._put_in_storage(k, v)
+
        self._collected_samples += bs
 
     @property
@@ -103,31 +126,47 @@ def is_storage_full(self) -> bool:
         return self.num_samples == self.collected_samples
 
     def sync(self) -> "Accumulator":
+        """
+        The method drops duplicates and sort elements by indices if they have been provided in ``self.update_data()``.
+        In DDP it also gathers data collected on several devices.
+
+        """
         # TODO: add option to broadcast instead of sync to avoid duplicating data
         if not self.is_storage_full():
             raise ValueError("Only full storages could be synced")
 
-        if is_ddp():
-            world_size = get_world_size()
-            if world_size == 1:
-                return self
-            else:
-                params = {"num_samples": [self.num_samples], "keys_to_accumulate": self.keys_to_accumulate}
+        params = {"num_samples": [self.num_samples], "keys_to_accumulate": self.keys_to_accumulate}
+        storage = self._storage
 
-                gathered_params = sync_dicts_ddp(params, world_size=world_size, device="cpu")
-                gathered_storage = sync_dicts_ddp(self._storage, world_size=world_size, device="cpu")
+        world_size = get_world_size_safe()
+        need_rebuilding = False
 
-                assert set(gathered_params["keys_to_accumulate"]) == set(
-                    self.keys_to_accumulate
-                ), "Keys of accumulators should be the same on each device"
+        if world_size > 1:
+            params = sync_dicts_ddp(params, world_size=world_size, device="cpu")
+            storage = sync_dicts_ddp(self._storage, world_size=world_size, device="cpu")
+            need_rebuilding = True
 
-                synced_accum = Accumulator(list(set(gathered_params["keys_to_accumulate"])))
-                synced_accum.refresh(sum(gathered_params["num_samples"]))
-                synced_accum.update_data(gathered_storage)
+        assert set(params["keys_to_accumulate"]) == set(
+            self.keys_to_accumulate
+        ), "Keys of accumulators should be the same on each device"
 
-                return synced_accum
+        if self._indices_key in storage:
+            for key, data in storage.items():
+                storage[key] = unique_by_ids(storage[self._indices_key], data)[1]  # type: ignore
+            indices = storage[self._indices_key]
+            need_rebuilding = True
         else:
+            indices = None
+
+        if not need_rebuilding:
+            # If indices were not provided & it's not DDP we may save time & memory avoiding re-building accumulator
             return self
 
+        synced_accum = Accumulator(tuple(set(params["keys_to_accumulate"])))
+        synced_accum.refresh(num_samples=len(storage[list(storage.keys())[0]]))
+        synced_accum.update_data(storage, indices=indices)
+
+        return synced_accum
+
 
 __all__ = ["TStorage", "Accumulator"]
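
Taken together, the new ``indices`` argument and the reworked ``sync()`` mean a storage built from shuffled, partially duplicated batches comes back deduplicated and in dataset order, even without DDP. A minimal single-process sketch, assuming only the constructor and method signatures shown in this diff (tensor values are illustrative) and mirroring what the new test_fake_ddp_accumulator below checks:

    import torch

    from oml.metrics.accumulation import Accumulator

    acc = Accumulator(keys_to_accumulate=("embeddings",))
    acc.refresh(num_samples=5)  # 4 unique records plus 1 duplicate

    # batches arrive shuffled; index 2 shows up twice, as it would after DDP padding
    acc.update_data({"embeddings": torch.zeros(3, 8)}, indices=[2, 0, 3])
    acc.update_data({"embeddings": torch.ones(2, 8)}, indices=[1, 2])

    synced = acc.sync()  # single process, but duplicates are still dropped and order restored

    indices_synced = synced.storage[acc._indices_key]
    assert sorted(indices_synced) == [0, 1, 2, 3]    # duplicates gone, dataset order recoverable
    assert len(synced.storage["embeddings"]) == 4    # one row per unique index

In DDP the same ``sync()`` call additionally gathers the per-device storages first (the ``world_size > 1`` branch above) and then applies the same deduplication.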

oml/metrics/embeddings.py (+16 -2)

@@ -161,8 +161,18 @@ def setup(self, num_samples: int) -> None:  # type: ignore
 
         self.acc.refresh(num_samples=num_samples)
 
-    def update_data(self, data_dict: Dict[str, Any]) -> None:  # type: ignore
-        self.acc.update_data(data_dict=data_dict)
+    def update_data(self, data_dict: Dict[str, Any], indices: Optional[List[int]] = None) -> None:  # type: ignore
+        """
+        Args:
+            data_dict: Batch of data containing records of the same size: ``bs``.
+            indices: Global indices of the elements in your records within the range of ``(0, dataset_size - 1)``.
+                Indices are needed in DDP (because data is gathered shuffled, additionally you may also get
+                some duplicates due to padding). In the single device regime it's may be useful if you accumulate
+                data in shuffled order.
+
+        """
+        # todo 522: make indices non optional and add the test
+        self.acc.update_data(data_dict=data_dict, indices=indices)
 
     def _calc_matrices(self) -> None:
         embeddings = self.acc.storage[self.embeddings_key]
@@ -382,5 +392,9 @@ class EmbeddingMetricsDDP(EmbeddingMetrics, IMetricDDP):
     def sync(self) -> None:
         self.acc = self.acc.sync()
 
+    def update_data(self, data_dict: Dict[str, Any], indices: List[int]) -> None:  # type: ignore
+        # indices are obligatory in DDP
+        return super().update_data(data_dict, indices)
+
 
 __all__ = ["TMetricsDict_ByLabels", "EmbeddingMetrics", "EmbeddingMetricsDDP"]

oml/utils/misc_torch.py (+19 -19)

@@ -7,8 +7,6 @@
 import torch
 from torch import Tensor, cdist
 
-from oml.utils.misc import find_first_occurrences
-
 TSingleValues = Union[int, float, np.float_, np.int_, torch.Tensor]
 TSequenceValues = Union[List[float], Tuple[float, ...], np.ndarray, torch.Tensor]
 TOnlineValues = Union[TSingleValues, TSequenceValues]
@@ -126,31 +124,33 @@ def _check_is_sequence(val: Any) -> bool:
     return False
 
 
-def drop_duplicates_by_ids(ids: List[Hashable], data: Tensor, sort: bool = True) -> Tuple[List[Hashable], Tensor]:
+TData = Tuple[List[Any], Tensor, np.ndarray]
+
+
+def unique_by_ids(ids: List[int], data: TData) -> Tuple[List[int], TData]:
     """
-    The function returns rows of data that have unique ids.
-    Thus, if there are multiple occurrences of some id, it leaves the first one.
+    The function sort data by the corresponding indices and drops duplicates.
+    Thus, if there are multiple occurrences of the same id, it takes the first one.
 
     Args:
-        ids: Identifiers of data records with the length of ``N``
-        data: Tensor of data records in the shape of ``[N, *]``
-        sort: Set ``True`` to return unique records sorted by their ids
+        ids: Indices of data with the length of ``N``
+        data: Data with the length of ``N``
 
     Returns:
-        Unique data records with their ids
+        Unique data records with their ids in the sorted order without duplicates
 
     """
-    assert isinstance(ids, list)
-    ids_first = find_first_occurrences(ids)
-    ids = [ids[i] for i in ids_first]
-    data = data[ids_first]
+    assert len(ids) == len(data)
+    assert isinstance(ids, list) and len(ids) >= 1
+
+    ids_unq, positions_unq = np.unique(ids, return_index=True)
 
-    if sort:
-        ii_permute = torch.argsort(torch.tensor(ids))
-        ids = [ids[i] for i in ii_permute]
-        data = data[ii_permute]
+    if isinstance(data, (list, tuple)):
+        data = [data[i] for i in positions_unq]  # type: ignore
+    else:
+        data = data[positions_unq]
 
-    return ids, data
+    return ids_unq.tolist(), data
 
 
 @contextmanager
@@ -465,6 +465,6 @@ def _check_dimensions(self, n_components: int) -> None:
     "take_2d",
     "assign_2d",
     "PCA",
-    "drop_duplicates_by_ids",
+    "unique_by_ids",
     "normalise",
 ]
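
The semantics of the new helper in a nutshell: ids come back sorted, only the first occurrence of a duplicated id survives, and the rows of data are reordered accordingly (it relies on numpy.unique with return_index=True). A small check mirroring the implementation above; the values are made up:

    import torch

    from oml.utils.misc_torch import unique_by_ids

    ids = [3, 0, 3, 1]                                   # id 3 occurs twice
    data = torch.tensor([[30.0], [0.0], [31.0], [10.0]])

    ids_unq, data_unq = unique_by_ids(ids=ids, data=data)

    assert ids_unq == [0, 1, 3]                          # sorted, duplicates dropped
    assert data_unq.tolist() == [[0.0], [10.0], [30.0]]  # row for id=3 is its first occurrence

Unlike the old drop_duplicates_by_ids, sorting is no longer optional, which is exactly the behaviour the Accumulator relies on.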

tests/test_integrations/test_lightning/test_pipeline.py (+8 -1)

@@ -12,6 +12,7 @@
 
 from oml.const import (
     EMBEDDINGS_KEY,
+    INDEX_KEY,
     INPUT_TENSORS_KEY,
     IS_GALLERY_KEY,
     IS_QUERY_KEY,
@@ -31,7 +32,13 @@ def __init__(self, labels: List[int], im_size: int):
     def __getitem__(self, item: int) -> Dict[str, Any]:
         input_tensors = torch.rand((3, self.im_size, self.im_size))
         label = torch.tensor(self.labels[item]).long()
-        return {INPUT_TENSORS_KEY: input_tensors, LABELS_KEY: label, IS_QUERY_KEY: True, IS_GALLERY_KEY: True}
+        return {
+            INPUT_TENSORS_KEY: input_tensors,
+            LABELS_KEY: label,
+            IS_QUERY_KEY: True,
+            IS_GALLERY_KEY: True,
+            INDEX_KEY: item,
+        }
 
     def __len__(self) -> int:
         return len(self.labels)

tests/test_oml/test_ddp/test_accumulator.py (+41 -11)

@@ -12,31 +12,61 @@
 @pytest.mark.long
 @pytest.mark.parametrize("world_size", [1, 2, 3])
 @pytest.mark.parametrize("device", ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"])
-def test_ddp_accumulator(world_size: int, device: str) -> None:
-    run_in_ddp(world_size=world_size, fn=check_ddp_accumulator, args=(device,))
+@pytest.mark.parametrize("create_duplicate", [True, False])
+def test_ddp_accumulator(world_size: int, device: str, create_duplicate: bool) -> None:
+    run_in_ddp(world_size=world_size, fn=check_ddp_accumulator, args=(device, create_duplicate))
 
 
-def check_ddp_accumulator(rank: int, world_size: int, device: str) -> None:
+@pytest.mark.parametrize("device", ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"])
+@pytest.mark.parametrize("create_duplicate", [True, False])
+def test_fake_ddp_accumulator(device: str, create_duplicate: bool) -> None:
+    # we expect the same duplicate removing behaviour without initializing DDP
+    check_accumulator(rank=0, world_size=1, device=device, create_duplicate=create_duplicate)
+
+
+def check_ddp_accumulator(rank: int, world_size: int, device: str, create_duplicate: bool) -> None:
     init_ddp(rank, world_size)
+    check_accumulator(rank, world_size, device, create_duplicate)
 
+
+def check_accumulator(rank: int, world_size: int, device: str, create_duplicate: bool) -> None:
     value = rank + 1
+    size = value
+
+    indices = {0: [0], 1: [1, 2], 2: [3, 4, 5]}[rank]
+
+    if create_duplicate and (rank == 0):
+        # let's pretend we doubled our single record at the rank 0
+        size = 2
+        indices = [0, 0]
 
     data = {
-        "list": [value] * value,
-        "tensor_1d": value * torch.ones(value, device=device),
-        "tensor_3d": value * torch.ones((value, 2, 3), device=device),
-        "numpy_1d": value * np.ones(value),
-        "numpy_3d": value * np.ones((value, 4, 5)),
+        "list": [value] * size,
+        "tensor_1d": value * torch.ones(size, device=device),
+        "tensor_3d": value * torch.ones((size, 2, 3), device=device),
+        "numpy_1d": value * np.ones(size),
+        "numpy_3d": value * np.ones((size, 4, 5)),
     }
 
-    acc = Accumulator(keys_to_accumulate=list(data.keys()))
+    acc = Accumulator(keys_to_accumulate=tuple(data.keys()))
     acc.refresh(len(data["list"]))
-    acc.update_data(data)
+    acc.update_data(data, indices=indices)
+
+    acc_synced = acc.sync()
+    synced_data = acc_synced.storage
+    synced_num_samples = acc_synced.num_samples
 
-    synced_data = acc.sync().storage
+    assert acc_synced.is_storage_full()
 
     len_after_sync = sum(range(1, world_size + 1))
 
+    indices_synced = synced_data[acc._indices_key]
+
+    assert len_after_sync == synced_num_samples
+
+    assert len(indices_synced) == len(set(indices_synced))
+    assert sorted(indices_synced) == list(range(len_after_sync))
+
     assert len(synced_data["list"]) == len_after_sync
 
     assert synced_data["tensor_1d"].ndim == 1  # type: ignore
