@@ -1,16 +1,17 @@
-from typing import Any, Dict, List, Sequence, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
-from torch.distributed import get_world_size
+from torch import Tensor
 
-from oml.ddp.utils import is_ddp, sync_dicts_ddp
+from oml.ddp.utils import get_world_size_safe, sync_dicts_ddp
+from oml.utils.misc_torch import unique_by_ids
 
-TStorage = Dict[str, Union[torch.Tensor, np.ndarray, List[Any]]]
+TStorage = Dict[str, Union[Tensor, np.ndarray, List[Any]]]
 
 
 class Accumulator:
-    def __init__(self, keys_to_accumulate: Sequence[str]):
+    def __init__(self, keys_to_accumulate: Tuple[str, ...]):
         """
         Class for accumulating values of different types, for instance,
         torch.Tensor and numpy.array.
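The new unique_by_ids import carries the deduplication logic that sync() relies on further down in this diff. Judging only by how it is called there, the helper's assumed contract is sketched below; this is an illustration of that contract, not the actual oml.utils.misc_torch implementation, and the name unique_by_ids_sketch is made up for clarity.

from typing import Any, List, Tuple

def unique_by_ids_sketch(ids: List[int], data: List[Any]) -> Tuple[List[int], List[Any]]:
    # Assumed behavior: keep the first occurrence of each id, then return
    # both sequences sorted by id, so result[1] is the deduplicated,
    # reordered data.
    first_seen = {}
    for i, value in zip(ids, data):
        first_seen.setdefault(i, value)
    sorted_ids = sorted(first_seen)
    return sorted_ids, [first_seen[i] for i in sorted_ids]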
@@ -27,12 +28,15 @@ def __init__(self, keys_to_accumulate: Sequence[str]):
         self._collected_samples = 0
         self._storage: TStorage = dict()
 
+        self._indices_key = "__element_indices"  # internal key to keep track of the elements' order, if provided
+
     def refresh(self, num_samples: int) -> None:
         """
         This method refreshes the state.
 
         Args:
-            num_samples: The total number of elements you are going to collect (for memory allocation)
+            num_samples: The total number of elements you are going to collect (for memory allocation).
+
         """
         assert isinstance(num_samples, int) and num_samples > 0
         self.num_samples = num_samples  # type: ignore
@@ -75,20 +79,39 @@ def _put_in_storage(self, key: str, batch_value: Any) -> None:
         else:
             raise TypeError(f"Type '{type(batch_value)}' is not available for accumulating")
 
-    def update_data(self, data_dict: Dict[str, Any]) -> None:
+    def update_data(self, data_dict: Dict[str, Any], indices: Optional[List[int]] = None) -> None:
         """
         Args:
-            data_dict: We will accumulate data getting values via ``self.keys_to_accumulate``.
+            data_dict: We will accumulate data getting values via ``self.keys_to_accumulate``. All values
+                of the dictionary have to have the same length.
+            indices: Global indices of the elements in your batch of data. If provided, the accumulator
+                will remove accumulated duplicates and return the elements in sorted order after ``.sync()``.
+                Indices are useful in DDP, where the gathered data arrives shuffled and may contain
+                duplicates due to padding. On a single device they are also useful if you accumulate
+                data in shuffled order.
 
         """
-        bs_values = [len(data_dict[k]) for k in self.keys_to_accumulate]
+        keys = list(self.keys_to_accumulate)
+
+        if indices is None:
+            assert self._indices_key not in self.storage, "We are tracking ids, but they are not currently provided."
+        else:
+            assert isinstance(indices, List)
+            if (self.collected_samples > 0) and (self._indices_key not in self.storage):
+                raise RuntimeError("You provided ids now, but had not provided them before.")
+
+            keys += [self._indices_key]
+            data_dict[self._indices_key] = indices
+
+        bs_values = [len(data_dict[k]) for k in keys]
         bs = bs_values[0]
         assert all(bs == bs_value for bs_value in bs_values), f"Lengths of data are not equal, lengths: {bs_values}"
 
-        for k in self.keys_to_accumulate:
+        for k in keys:
             v = data_dict[k]
             self._allocate_memory_if_need(k, v)
             self._put_in_storage(k, v)
+
         self._collected_samples += bs
 
     @property
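To make the new indices behavior concrete, here is a minimal single-device usage sketch. The key name "embeddings" and the tensor shapes are illustrative; it also assumes get_world_size_safe() reports 1 outside DDP and that unique_by_ids accepts tensors, per the sync() hunk below.

import torch

acc = Accumulator(keys_to_accumulate=("embeddings",))
acc.refresh(num_samples=5)  # total rows that will be pushed, duplicates included

# Batches may arrive shuffled and may overlap (e.g. due to DDP-style padding);
# global indices let sync() drop the duplicate and restore the original order.
acc.update_data({"embeddings": torch.randn(3, 8)}, indices=[2, 0, 3])
acc.update_data({"embeddings": torch.randn(2, 8)}, indices=[1, 3])  # id 3 repeats

synced = acc.sync()  # 4 unique rows, ordered by indices 0, 1, 2, 3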
@@ -103,31 +126,47 @@ def is_storage_full(self) -> bool:
         return self.num_samples == self.collected_samples
 
     def sync(self) -> "Accumulator":
+        """
+        The method drops duplicates and sorts elements by indices if they have been provided in ``self.update_data()``.
+        In DDP it also gathers data collected on several devices.
+
+        """
         # TODO: add option to broadcast instead of sync to avoid duplicating data
         if not self.is_storage_full():
             raise ValueError("Only full storages could be synced")
 
-        if is_ddp():
-            world_size = get_world_size()
-            if world_size == 1:
-                return self
-            else:
-                params = {"num_samples": [self.num_samples], "keys_to_accumulate": self.keys_to_accumulate}
+        params = {"num_samples": [self.num_samples], "keys_to_accumulate": self.keys_to_accumulate}
+        storage = self._storage
 
-                gathered_params = sync_dicts_ddp(params, world_size=world_size, device="cpu")
-                gathered_storage = sync_dicts_ddp(self._storage, world_size=world_size, device="cpu")
+        world_size = get_world_size_safe()
+        need_rebuilding = False
 
-                assert set(gathered_params["keys_to_accumulate"]) == set(
-                    self.keys_to_accumulate
-                ), "Keys of accumulators should be the same on each device"
+        if world_size > 1:
+            params = sync_dicts_ddp(params, world_size=world_size, device="cpu")
+            storage = sync_dicts_ddp(self._storage, world_size=world_size, device="cpu")
+            need_rebuilding = True
 
-                synced_accum = Accumulator(list(set(gathered_params["keys_to_accumulate"])))
-                synced_accum.refresh(sum(gathered_params["num_samples"]))
-                synced_accum.update_data(gathered_storage)
+        assert set(params["keys_to_accumulate"]) == set(
+            self.keys_to_accumulate
+        ), "Keys of accumulators should be the same on each device"
 
-                return synced_accum
+        if self._indices_key in storage:
+            for key, data in storage.items():
+                storage[key] = unique_by_ids(storage[self._indices_key], data)[1]  # type: ignore
+            indices = storage[self._indices_key]
+            need_rebuilding = True
         else:
+            indices = None
+
+        if not need_rebuilding:
+            # If indices were not provided and it's not DDP, we save time and memory by not re-building the accumulator
             return self
 
+        synced_accum = Accumulator(tuple(set(params["keys_to_accumulate"])))
+        synced_accum.refresh(num_samples=len(storage[list(storage.keys())[0]]))
+        synced_accum.update_data(storage, indices=indices)
+
+        return synced_accum
+
 
 __all__ = ["TStorage", "Accumulator"]
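For context on the import swap at the top of the diff (is_ddp and torch.distributed.get_world_size replaced by get_world_size_safe): a helper with this name plausibly folds the old two-step check into one call, which is what lets sync() use a single world_size > 1 branch. The sketch below is an assumption about its behavior, not the actual oml.ddp.utils code.

import torch.distributed as dist

def get_world_size_safe() -> int:
    # Assumed behavior: report the real world size inside an initialized
    # process group, and fall back to 1 otherwise (e.g. on a single device).
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1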