
Commit b6e403b

minor upd

1 parent b7a3fdd

6 files changed: +12 −22 lines

ml-runs/0/meta.yaml (−6)

This file was deleted.

oml/inference/abstract.py (+1 −5)

```diff
@@ -52,11 +52,7 @@ def _inference(
     ids, outputs = unique_by_ids(ids=ids, data=outputs)
 
     assert len(outputs) == len(dataset), "Data was not collected correctly after DDP sync."
-    assert list(range(len(dataset))) == ids, (
-        list(range(len(dataset))),
-        ids,
-        "zzz",
-    )  # , "Data was not collected correctly after DDP sync."
+    assert list(range(len(dataset))) == ids, "Data was not collected correctly after DDP sync."
 
     return outputs
 
```
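The surviving assert captures the whole contract in one line: after `unique_by_ids`, the collected ids must be exactly `0 .. len(dataset) - 1`. A minimal hedged sketch of the failure mode it guards against (the per-rank ids below are made up for illustration):

```python
# Hedged illustration: a DDP-style gather returns ids shuffled and, due to
# sampler padding, possibly duplicated; deduplication must restore full coverage.
rank0_ids = [0, 2, 4]
rank1_ids = [1, 3, 3]                  # rank 1 padded by repeating id 3
gathered = rank0_ids + rank1_ids       # order after all_gather is arbitrary

deduped = sorted(set(gathered))        # stand-in for what unique_by_ids does to ids
assert deduped == list(range(5)), "Data was not collected correctly after DDP sync."
```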

oml/metrics/accumulation.py (+3 −3)

```diff
@@ -86,7 +86,7 @@ def update_data(self, data_dict: Dict[str, Any], indices: Optional[List[int]] =
             data_dict: We will accumulate data getting values via ``self.keys_to_accumulate``. All elements
                 of the dictionary have to have the same size.
             indices: Global indices of the elements in your batch of data. If provided, the accumulator
-                will remove all the accumulated duplicates and return the elements in sorted order.
+                will remove accumulated duplicates and return the elements in the sorted order after ``.sync()``.
                 Indices may be useful in DDP (because data is gathered shuffled, additionally you may also get
                 some duplicates due to padding). In the single device regime it's also useful if you accumulate
                 data in shuffled order.
@@ -125,7 +125,7 @@ def is_storage_full(self) -> bool:
 
     def sync(self) -> "Accumulator":
         """
-        The method drops duplicates if ids have been provided in ``self.update_data``.
+        The method drops duplicates and sort elements by indices if they have been provided in ``self.update_data()``.
         In DDP it also gathers data collected on several devices.
 
         """
@@ -157,7 +157,7 @@ def sync(self) -> "Accumulator":
             indices = None
 
         if not need_rebuilding:
-            # If we found no duplicates and there are no multiple devices, we may save time & memory on re-creating
+            # If indices were not provided & it's not DDP we may save time & memory avoiding re-building accumulator
             return self
 
         synced_accum = Accumulator(tuple(set(params["keys_to_accumulate"])))
```
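To make the documented `sync()` behaviour concrete, here is a self-contained sketch (not the library code) of dropping duplicates and reordering a storage dict by the provided indices; with no indices and no DDP, the storage is returned untouched, matching the early-return branch above:

```python
from typing import Any, Dict, List, Optional

def sync_sketch(storage: Dict[str, List[Any]],
                indices: Optional[List[int]]) -> Dict[str, List[Any]]:
    # No indices and a single device: nothing to rebuild, return as-is.
    if indices is None:
        return storage
    first_pos: Dict[int, int] = {}
    for pos, idx in enumerate(indices):
        first_pos.setdefault(idx, pos)          # keep the first copy of each index
    order = [first_pos[idx] for idx in sorted(first_pos)]
    return {key: [values[p] for p in order] for key, values in storage.items()}

storage = {"embeddings": ["e2", "e0", "e1", "e0"], "labels": [2, 0, 1, 0]}
print(sync_sketch(storage, indices=[2, 0, 1, 0]))
# -> {'embeddings': ['e0', 'e1', 'e2'], 'labels': [0, 1, 2]}
```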

oml/metrics/embeddings.py (+4 −4)

```diff
@@ -164,10 +164,10 @@ def setup(self, num_samples: int) -> None:  # type: ignore
     def update_data(self, data_dict: Dict[str, Any], indices: Optional[List[int]] = None) -> None:  # type: ignore
         """
         Args:
-            data_dict: Batch of data containing elements of the same size: ``bs``.
-            indices: Global indices of the elements in your batch of data withing the range ``(0, dataset_size - 1)``.
+            data_dict: Batch of data containing records of the same size: ``bs``.
+            indices: Global indices of the elements in your records withing the range of ``(0, dataset_size - 1)``.
             Indices are needed in DDP (because data is gathered shuffled, additionally you may also get
-            some duplicates due to padding). In the single device regime it's also useful if you accumulate
+            some duplicates due to padding). In the single device regime it's may be useful if you accumulate
             data in shuffled order.
 
         """
@@ -392,7 +392,7 @@ def sync(self) -> None:
         self.acc = self.acc.sync()
 
     def update_data(self, data_dict: Dict[str, Any], indices: List[int]) -> None:  # type: ignore
-        # indices are obligatory in DDP
+        # indices are obligatory in DDP, so we don't accumulate shuffled data
         return super().update_data(data_dict, indices)
 
 
```
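Since indices are obligatory in the second hunk, the dataset or loader must expose each record's global position. A hedged sketch of one common way to do that (the class and key names below are illustrative, not the library's API):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class DatasetWithIndices(Dataset):
    """Toy dataset that ships every record's global index along with the data."""

    def __init__(self, n: int):
        self.data = torch.arange(n, dtype=torch.float32)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int):
        return {"input": self.data[idx], "idx": idx}

loader = DataLoader(DatasetWithIndices(5), batch_size=2, shuffle=True)
for batch in loader:
    indices = batch["idx"].tolist()     # these would be passed to update_data(...)
    print(indices, batch["input"].tolist())
```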

oml/utils/misc_torch.py (+3 −3)

```diff
@@ -129,12 +129,12 @@ def _check_is_sequence(val: Any) -> bool:
 
 def unique_by_ids(ids: List[int], data: TData) -> Tuple[List[int], TData]:
     """
-    The function sort data by the corresponding ids and drops duplicates.
+    The function sort data by the corresponding indices and drops duplicates.
     Thus, if there are multiple occurrences of the same id, it takes the first one.
 
     Args:
-        ids: Indices of data records with the length of ``N``
-        data: Data records with the lengths of ``N``
+        ids: Indices of data with the length of ``N``
+        data: Data with the length of ``N``
 
     Returns:
         Unique data records with their ids in the sorted order without duplicates
```
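For reference, a self-contained sketch of the documented semantics (first occurrence wins, output sorted by id); this is an illustration, not the actual implementation:

```python
from typing import Dict, List, Tuple

import torch

def unique_by_ids_sketch(ids: List[int], data: torch.Tensor) -> Tuple[List[int], torch.Tensor]:
    first_pos: Dict[int, int] = {}
    for pos, idx in enumerate(ids):
        first_pos.setdefault(idx, pos)      # remember where each id first appeared
    sorted_ids = sorted(first_pos)
    return sorted_ids, data[[first_pos[i] for i in sorted_ids]]

ids, data = unique_by_ids_sketch([3, 1, 1, 2], torch.tensor([30.0, 10.0, 99.0, 20.0]))
print(ids, data)  # [1, 2, 3] tensor([10., 20., 30.])
```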

tests/test_oml/test_ddp/test_accumulator.py (+1 −1)

```diff
@@ -20,7 +20,7 @@ def test_ddp_accumulator(world_size: int, device: str, create_duplicate: bool) -
 @pytest.mark.parametrize("device", ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"])
 @pytest.mark.parametrize("create_duplicate", [True, False])
 def test_fake_ddp_accumulator(device: str, create_duplicate: bool) -> None:
-    # we expect the same behaviour outside DDP
+    # we expect the same behaviour without initializing DDP
     check_accumulator(rank=0, world_size=1, device=device, create_duplicate=create_duplicate)
 
 
```
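The `create_duplicate` flag targets exactly the padding scenario described in the docstrings above. A trivial hedged sketch of the property being exercised (`check_accumulator`'s internals are not shown in this diff):

```python
import pytest

@pytest.mark.parametrize("create_duplicate", [True, False])
def test_dedup_property(create_duplicate: bool) -> None:
    indices = [0, 1, 2] + ([2] if create_duplicate else [])
    assert sorted(set(indices)) == [0, 1, 2]   # duplicates must collapse to one copy
```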
