Skip to content

Commit 1a16e74

Browse files
authored
True epoch stats (#248)
* stats q
1 parent fe33cfd commit 1a16e74

File tree

1 file changed

+72
-16
lines changed

1 file changed

+72
-16
lines changed

ml4cvd/tensor_generators.py

Lines changed: 72 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,13 @@ def __init__(
8181
self.augment = augment
8282
self.run_on_main_thread = num_workers == 0
8383
self.q = None
84+
self.stats_q = None
8485
self._started = False
8586
self.workers = []
8687
self.worker_instances = []
8788
self.batch_size, self.input_maps, self.output_maps, self.num_workers, self.cache_size, self.weights, self.name, self.keep_paths = \
8889
batch_size, input_maps, output_maps, num_workers, cache_size, weights, name, keep_paths
90+
self.true_epochs = 0
8991
if num_workers == 0:
9092
num_workers = 1 # The one worker is the main thread
9193
if weights is None:
@@ -117,11 +119,14 @@ def __init__(
117119

118120
def _init_workers(self):
119121
self.q = Queue(min(self.batch_size, TENSOR_GENERATOR_MAX_Q_SIZE))
122+
self.stats_q = Queue(len(self.worker_instances))
120123
self._started = True
121124
for i, (path_iter, iter_len) in enumerate(zip(self.path_iters, self.true_epoch_lens)):
122125
name = f'{self.name}_{i}'
123126
worker_instance = _MultiModalMultiTaskWorker(
124127
self.q,
128+
self.stats_q,
129+
self.num_workers,
125130
self.input_maps, self.output_maps,
126131
path_iter, iter_len,
127132
self.batch_function, self.batch_size, self.keep_paths, self.batch_function_kwargs,
@@ -153,8 +158,48 @@ def __next__(self) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Option
153158
if self.run_on_main_thread:
154159
return next(self.worker_instances[0])
155160
else:
161+
if self.stats_q.qsize() == self.num_workers:
162+
self.aggregate_and_print_stats()
156163
return self.q.get(TENSOR_GENERATOR_TIMEOUT)
157164

165+
def aggregate_and_print_stats(self):
166+
stats = Counter()
167+
self.true_epochs += 1
168+
while self.stats_q.qsize() != 0:
169+
stats += self.stats_q.get()
170+
171+
error_info = '\n\t\t'.join([
172+
f'[{error}] - {count:.0f}'
173+
for error, count in sorted(stats.items(), key=lambda x: x[1], reverse=True) if 'Error' in error
174+
])
175+
176+
info_string = '\n\t'.join([
177+
f"Generator looped & shuffled over {sum(self.true_epoch_lens)} paths. Epoch: {self.true_epochs:.0f}",
178+
f"{stats['Tensors presented']/self.true_epochs:0.0f} tensors were presented.",
179+
f"{stats['skipped_paths']} paths were skipped because they previously failed.",
180+
f"The following errors occurred:\n\t\t{error_info}",
181+
])
182+
logging.info(f"\n!>~~~~~~~~~~~~ {self.name} completed true epoch {self.true_epochs} ~~~~~~~~~~~~<!\nAggregated information string:\n\t{info_string}")
183+
eps = 1e-7
184+
for tm in self.input_maps + self.output_maps:
185+
if self.true_epochs != 1:
186+
break
187+
if tm.is_categorical() and tm.axes() == 1:
188+
n = stats[f'{tm.name}_n'] + eps
189+
message = f'Categorical \n{tm.name} has {n:.0f} total examples.'
190+
for channel, index in tm.channel_map.items():
191+
examples = stats[f'{tm.name}_index_{index:.0f}']
192+
message = f'{message}\n\tLabel {channel} {examples} examples, {100 * (examples / n):0.2f}% of total.'
193+
logging.info(message)
194+
elif tm.is_continuous() and tm.axes() == 1:
195+
sum_squared = stats[f'{tm.name}_sum_squared']
196+
n = stats[f'{tm.name}_n'] + eps
197+
n_sum = stats[f'{tm.name}_sum']
198+
mean = n_sum / n
199+
std = np.sqrt((sum_squared/n)-(mean*mean))
200+
logging.info(f'Continuous value \n{tm.name} Mean:{mean:0.2f} Standard Deviation:{std:0.2f} '
201+
f"Maximum:{stats[f'{tm.name}_max']:0.2f} Minimum:{stats[f'{tm.name}_min']:0.2f}")
202+
158203
def kill_workers(self):
159204
if self._started and not self.run_on_main_thread:
160205
for worker in self.workers:
@@ -242,6 +287,8 @@ class _MultiModalMultiTaskWorker:
242287
def __init__(
243288
self,
244289
q: Queue,
290+
stats_q: Queue,
291+
num_workers: int,
245292
input_maps: List[TensorMap], output_maps: List[TensorMap],
246293
path_iter: PathIterator, true_epoch_len: int,
247294
batch_function: BatchFunction, batch_size: int, return_paths: bool, batch_func_kwargs: Dict,
@@ -250,6 +297,8 @@ def __init__(
250297
augment: bool,
251298
):
252299
self.q = q
300+
self.stats_q = stats_q
301+
self.num_workers = num_workers
253302
self.input_maps = input_maps
254303
self.output_maps = output_maps
255304
self.path_iter = path_iter
@@ -281,6 +330,7 @@ def _handle_tm(self, tm: TensorMap, is_input: bool, path: Path) -> h5py.File:
281330
batch[name][idx] = self.dependents[tm]
282331
if tm.cacheable:
283332
self.cache[path, name] = self.dependents[tm]
333+
self._collect_stats(tm, self.dependents[tm])
284334
return self.hd5
285335
if (path, name) in self.cache:
286336
batch[name][idx] = self.cache[path, name]
@@ -291,8 +341,24 @@ def _handle_tm(self, tm: TensorMap, is_input: bool, path: Path) -> h5py.File:
291341
batch[name][idx] = tensor
292342
if tm.cacheable:
293343
self.cache[path, name] = tensor
344+
self._collect_stats(tm, tensor)
294345
return self.hd5
295346

347+
def _collect_stats(self, tm, tensor):
348+
if tm.is_categorical() and tm.axes() == 1:
349+
self.epoch_stats[f'{tm.name}_index_{np.argmax(tensor):.0f}'] += 1
350+
self.epoch_stats[f'{tm.name}_n'] += 1
351+
if tm.is_continuous() and tm.axes() == 1:
352+
self.epoch_stats[f'{tm.name}_n'] += 1
353+
rescaled = tm.rescale(tensor)[0]
354+
if 0.0 == self.epoch_stats[f'{tm.name}_max'] == self.epoch_stats[f'{tm.name}_min']:
355+
self.epoch_stats[f'{tm.name}_max'] = min(0, rescaled)
356+
self.epoch_stats[f'{tm.name}_min'] = max(0, rescaled)
357+
self.epoch_stats[f'{tm.name}_max'] = max(rescaled, self.epoch_stats[f'{tm.name}_max'])
358+
self.epoch_stats[f'{tm.name}_min'] = min(rescaled, self.epoch_stats[f'{tm.name}_min'])
359+
self.epoch_stats[f'{tm.name}_sum'] += rescaled
360+
self.epoch_stats[f'{tm.name}_sum_squared'] += rescaled * rescaled
361+
296362
def _handle_tensor_path(self, path: Path) -> None:
297363
hd5 = None
298364
if path in self.cache.failed_paths:
@@ -307,6 +373,7 @@ def _handle_tensor_path(self, path: Path) -> None:
307373
hd5 = self._handle_tm(tm, False, path)
308374
self.paths_in_batch.append(path)
309375
self.stats['Tensors presented'] += 1
376+
self.epoch_stats['Tensors presented'] += 1
310377
self.stats['batch_index'] += 1
311378
except (IndexError, KeyError, ValueError, OSError, RuntimeError) as e:
312379
error_name = type(e).__name__
@@ -320,23 +387,12 @@ def _handle_tensor_path(self, path: Path) -> None:
320387

321388
def _on_epoch_end(self):
322389
self.stats['epochs'] += 1
323-
for k in self.stats:
324-
logging.debug(f"{k}: {self.stats[k]}")
325-
error_info = '\n\t\t'.join([
326-
f'[{error}] - {count}'
327-
for error, count in sorted(self.epoch_stats.items(), key=lambda x: x[1], reverse=True)
328-
])
329-
info_string = '\n\t'.join([
330-
f"The following errors occurred:\n\t\t{error_info}",
331-
f"Generator looped & shuffled over {self.true_epoch_len} paths.",
332-
f"{int(self.stats['Tensors presented']/self.stats['epochs'])} tensors were presented.",
333-
f"{self.epoch_stats['skipped_paths']} paths were skipped because they previously failed.",
334-
str(self.cache),
335-
f"{(time.time() - self.start):.2f} seconds elapsed.",
336-
])
337-
logging.info(f"Worker {self.name} - In true epoch {self.stats['epochs']}:\n\t{info_string}")
390+
self.epoch_stats['epochs'] = self.stats['epochs']
391+
while self.stats_q.qsize() == self.num_workers:
392+
continue
393+
self.stats_q.put(self.epoch_stats)
338394
if self.stats['Tensors presented'] == 0:
339-
raise ValueError(f"Completed an epoch but did not find any tensors to yield")
395+
logging.error(f"Completed an epoch but did not find any tensors to yield")
340396
if 'test' in self.name:
341397
logging.warning(f'Test worker {self.name} completed a full epoch. Test results may be double counting samples.')
342398
self.start = time.time()

0 commit comments

Comments
 (0)