@@ -81,11 +81,13 @@ def __init__(
81
81
self .augment = augment
82
82
self .run_on_main_thread = num_workers == 0
83
83
self .q = None
84
+ self .stats_q = None
84
85
self ._started = False
85
86
self .workers = []
86
87
self .worker_instances = []
87
88
self .batch_size , self .input_maps , self .output_maps , self .num_workers , self .cache_size , self .weights , self .name , self .keep_paths = \
88
89
batch_size , input_maps , output_maps , num_workers , cache_size , weights , name , keep_paths
90
+ self .true_epochs = 0
89
91
if num_workers == 0 :
90
92
num_workers = 1 # The one worker is the main thread
91
93
if weights is None :
@@ -117,11 +119,14 @@ def __init__(
117
119
118
120
def _init_workers (self ):
119
121
self .q = Queue (min (self .batch_size , TENSOR_GENERATOR_MAX_Q_SIZE ))
122
+ self .stats_q = Queue (len (self .worker_instances ))
120
123
self ._started = True
121
124
for i , (path_iter , iter_len ) in enumerate (zip (self .path_iters , self .true_epoch_lens )):
122
125
name = f'{ self .name } _{ i } '
123
126
worker_instance = _MultiModalMultiTaskWorker (
124
127
self .q ,
128
+ self .stats_q ,
129
+ self .num_workers ,
125
130
self .input_maps , self .output_maps ,
126
131
path_iter , iter_len ,
127
132
self .batch_function , self .batch_size , self .keep_paths , self .batch_function_kwargs ,
@@ -153,8 +158,48 @@ def __next__(self) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Option
153
158
if self .run_on_main_thread :
154
159
return next (self .worker_instances [0 ])
155
160
else :
161
+ if self .stats_q .qsize () == self .num_workers :
162
+ self .aggregate_and_print_stats ()
156
163
return self .q .get (TENSOR_GENERATOR_TIMEOUT )
157
164
165
    def aggregate_and_print_stats(self):
        """Drain every worker's per-epoch statistics from stats_q, merge them, and log a summary.

        Called once stats_q holds one Counter per worker; each call counts one
        "true epoch" over the full set of paths. After the first true epoch it
        also logs per-TensorMap label distributions (categorical) and
        mean/std/min/max (continuous) derived from the streamed sums.
        """
        stats = Counter()
        self.true_epochs += 1
        # Merge each worker's Counter key-wise; += sums matching counts.
        while self.stats_q.qsize() != 0:
            stats += self.stats_q.get()

        # Keep only error-type entries, most frequent first.
        error_info = '\n\t\t'.join([
            f'[{error}] - {count:.0f}'
            for error, count in sorted(stats.items(), key=lambda x: x[1], reverse=True) if 'Error' in error
        ])

        info_string = '\n\t'.join([
            f"Generator looped & shuffled over {sum(self.true_epoch_lens)} paths. Epoch: {self.true_epochs:.0f}",
            f"{stats['Tensors presented']/self.true_epochs:0.0f} tensors were presented.",
            f"{stats['skipped_paths']} paths were skipped because they previously failed.",
            f"The following errors occurred:\n\t\t{error_info}",
        ])
        logging.info(f"\n!>~~~~~~~~~~~~ {self.name} completed true epoch {self.true_epochs} ~~~~~~~~~~~~<!\nAggregated information string:\n\t{info_string}")
        eps = 1e-7  # guards the divisions below when a map saw no examples (n would be 0)
        for tm in self.input_maps + self.output_maps:
            # Per-TensorMap distributions are only logged after the first true epoch.
            if self.true_epochs != 1:
                break
            if tm.is_categorical() and tm.axes() == 1:
                n = stats[f'{tm.name}_n'] + eps
                message = f'Categorical \n{tm.name} has {n:.0f} total examples.'
                for channel, index in tm.channel_map.items():
                    examples = stats[f'{tm.name}_index_{index:.0f}']
                    message = f'{message}\n\tLabel {channel} {examples} examples, {100 * (examples / n):0.2f}% of total.'
                logging.info(message)
            elif tm.is_continuous() and tm.axes() == 1:
                # Streaming mean / std recovered from n, sum and sum of squares:
                # std = sqrt(E[x^2] - E[x]^2).
                sum_squared = stats[f'{tm.name}_sum_squared']
                n = stats[f'{tm.name}_n'] + eps
                n_sum = stats[f'{tm.name}_sum']
                mean = n_sum / n
                std = np.sqrt((sum_squared / n) - (mean * mean))
                logging.info(f'Continuous value \n{tm.name} Mean:{mean:0.2f} Standard Deviation:{std:0.2f} '
                             f"Maximum:{stats[f'{tm.name}_max']:0.2f} Minimum:{stats[f'{tm.name}_min']:0.2f}")
158
203
def kill_workers (self ):
159
204
if self ._started and not self .run_on_main_thread :
160
205
for worker in self .workers :
@@ -242,6 +287,8 @@ class _MultiModalMultiTaskWorker:
242
287
def __init__ (
243
288
self ,
244
289
q : Queue ,
290
+ stats_q : Queue ,
291
+ num_workers : int ,
245
292
input_maps : List [TensorMap ], output_maps : List [TensorMap ],
246
293
path_iter : PathIterator , true_epoch_len : int ,
247
294
batch_function : BatchFunction , batch_size : int , return_paths : bool , batch_func_kwargs : Dict ,
@@ -250,6 +297,8 @@ def __init__(
250
297
augment : bool ,
251
298
):
252
299
self .q = q
300
+ self .stats_q = stats_q
301
+ self .num_workers = num_workers
253
302
self .input_maps = input_maps
254
303
self .output_maps = output_maps
255
304
self .path_iter = path_iter
@@ -281,6 +330,7 @@ def _handle_tm(self, tm: TensorMap, is_input: bool, path: Path) -> h5py.File:
281
330
batch [name ][idx ] = self .dependents [tm ]
282
331
if tm .cacheable :
283
332
self .cache [path , name ] = self .dependents [tm ]
333
+ self ._collect_stats (tm , self .dependents [tm ])
284
334
return self .hd5
285
335
if (path , name ) in self .cache :
286
336
batch [name ][idx ] = self .cache [path , name ]
@@ -291,8 +341,24 @@ def _handle_tm(self, tm: TensorMap, is_input: bool, path: Path) -> h5py.File:
291
341
batch [name ][idx ] = tensor
292
342
if tm .cacheable :
293
343
self .cache [path , name ] = tensor
344
+ self ._collect_stats (tm , tensor )
294
345
return self .hd5
295
346
347
    def _collect_stats(self, tm: TensorMap, tensor: np.ndarray) -> None:
        """Accumulate per-TensorMap statistics for this epoch into self.epoch_stats.

        Categorical 1-D maps: counts total examples and per-class hits
        (argmax of the one-hot label vector).
        Continuous 1-D maps: tracks n, min, max, sum and sum of squares of the
        rescaled value so the aggregator can derive mean and std.
        """
        if tm.is_categorical() and tm.axes() == 1:
            # argmax picks the hot index of the one-hot label vector.
            self.epoch_stats[f'{tm.name}_index_{np.argmax(tensor):.0f}'] += 1
            self.epoch_stats[f'{tm.name}_n'] += 1
        if tm.is_continuous() and tm.axes() == 1:
            self.epoch_stats[f'{tm.name}_n'] += 1
            rescaled = tm.rescale(tensor)[0]
            # Counter entries default to 0, so max == min == 0.0 is treated as an
            # "uninitialized" sentinel: seed max/min so the first real value wins
            # both comparisons below regardless of its sign.
            # NOTE(review): a stream whose running max and min are both genuinely
            # 0.0 re-triggers this seeding, though the result is unchanged.
            if 0.0 == self.epoch_stats[f'{tm.name}_max'] == self.epoch_stats[f'{tm.name}_min']:
                self.epoch_stats[f'{tm.name}_max'] = min(0, rescaled)
                self.epoch_stats[f'{tm.name}_min'] = max(0, rescaled)
            self.epoch_stats[f'{tm.name}_max'] = max(rescaled, self.epoch_stats[f'{tm.name}_max'])
            self.epoch_stats[f'{tm.name}_min'] = min(rescaled, self.epoch_stats[f'{tm.name}_min'])
            self.epoch_stats[f'{tm.name}_sum'] += rescaled
            self.epoch_stats[f'{tm.name}_sum_squared'] += rescaled * rescaled
296
362
def _handle_tensor_path (self , path : Path ) -> None :
297
363
hd5 = None
298
364
if path in self .cache .failed_paths :
@@ -307,6 +373,7 @@ def _handle_tensor_path(self, path: Path) -> None:
307
373
hd5 = self ._handle_tm (tm , False , path )
308
374
self .paths_in_batch .append (path )
309
375
self .stats ['Tensors presented' ] += 1
376
+ self .epoch_stats ['Tensors presented' ] += 1
310
377
self .stats ['batch_index' ] += 1
311
378
except (IndexError , KeyError , ValueError , OSError , RuntimeError ) as e :
312
379
error_name = type (e ).__name__
@@ -320,23 +387,12 @@ def _handle_tensor_path(self, path: Path) -> None:
320
387
321
388
def _on_epoch_end (self ):
322
389
self .stats ['epochs' ] += 1
323
- for k in self .stats :
324
- logging .debug (f"{ k } : { self .stats [k ]} " )
325
- error_info = '\n \t \t ' .join ([
326
- f'[{ error } ] - { count } '
327
- for error , count in sorted (self .epoch_stats .items (), key = lambda x : x [1 ], reverse = True )
328
- ])
329
- info_string = '\n \t ' .join ([
330
- f"The following errors occurred:\n \t \t { error_info } " ,
331
- f"Generator looped & shuffled over { self .true_epoch_len } paths." ,
332
- f"{ int (self .stats ['Tensors presented' ]/ self .stats ['epochs' ])} tensors were presented." ,
333
- f"{ self .epoch_stats ['skipped_paths' ]} paths were skipped because they previously failed." ,
334
- str (self .cache ),
335
- f"{ (time .time () - self .start ):.2f} seconds elapsed." ,
336
- ])
337
- logging .info (f"Worker { self .name } - In true epoch { self .stats ['epochs' ]} :\n \t { info_string } " )
390
+ self .epoch_stats ['epochs' ] = self .stats ['epochs' ]
391
+ while self .stats_q .qsize () == self .num_workers :
392
+ continue
393
+ self .stats_q .put (self .epoch_stats )
338
394
if self .stats ['Tensors presented' ] == 0 :
339
- raise ValueError (f"Completed an epoch but did not find any tensors to yield" )
395
+ logging . error (f"Completed an epoch but did not find any tensors to yield" )
340
396
if 'test' in self .name :
341
397
logging .warning (f'Test worker { self .name } completed a full epoch. Test results may be double counting samples.' )
342
398
self .start = time .time ()
0 commit comments