diff --git a/train.py b/train.py
index 816f4ae804..86f567d2cc 100755
--- a/train.py
+++ b/train.py
@@ -64,7 +64,12 @@
     has_functorch = True
 except ImportError as e:
     has_functorch = False
-
+# check for tensorboard install
+try:
+    from torch.utils.tensorboard import SummaryWriter
+    has_tensorboard = True
+except ImportError as e:
+    has_tensorboard = False
 
 has_compile = hasattr(torch, 'compile')
 
@@ -347,8 +352,8 @@
                    help='use the multi-epochs-loader to save time at the beginning of every epoch')
 group.add_argument('--log-wandb', action='store_true', default=False,
                    help='log training and validation metrics to wandb')
-
-
+group.add_argument('--log-tensorboard', default='', type=str, metavar='PATH',
+                   help='log training and validation metrics to TensorBoard')
 def _parse_args():
     # Do we have a config file to parse?
     args_config, remaining = config_parser.parse_known_args()
@@ -725,6 +730,18 @@
             _logger.warning(
                 "You've requested to log metrics to wandb but package not found. "
                 "Metrics not being logged to wandb, try `pip install wandb`")
+    tensorboard_writer = None
+    if args.log_tensorboard and utils.is_primary(args):
+        if has_tensorboard:
+            tensorboard_writer = SummaryWriter(args.log_tensorboard)
+
+
+        else:
+            _logger.warning(
+                "You've requested to log metrics to tensorboard but package not found. "
+                "Metrics not being logged to tensorboard, try `pip install tensorboard`")
+
+
 
     # setup learning rate schedule and starting epoch
     updates_per_epoch = len(loader_train)
@@ -770,6 +787,7 @@
                 loss_scaler=loss_scaler,
                 model_ema=model_ema,
                 mixup_fn=mixup_fn,
+                tensorboard_writer=tensorboard_writer,
             )
 
             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
@@ -783,6 +801,8 @@
                 validate_loss_fn,
                 args,
                 amp_autocast=amp_autocast,
+                tensorboard_writer=tensorboard_writer,
+                epoch=epoch,
             )
 
             if model_ema is not None and not args.model_ema_force_cpu:
@@ -796,6 +816,8 @@
                     args,
                     amp_autocast=amp_autocast,
                     log_suffix=' (EMA)',
+                    tensorboard_writer=tensorboard_writer,
+                    epoch=epoch,
                 )
                 eval_metrics = ema_eval_metrics
 
@@ -809,6 +831,7 @@
                     lr=sum(lrs) / len(lrs),
                     write_header=best_metric is None,
                     log_wandb=args.log_wandb and has_wandb,
+
                 )
 
             if saver is not None:
@@ -825,6 +848,8 @@
 
     if best_metric is not None:
         _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
+    if should_log_to_tensorboard(args) and tensorboard_writer is not None:
+        tensorboard_writer.close()
 
 
 def train_one_epoch(
@@ -841,7 +866,8 @@
         amp_autocast=suppress,
         loss_scaler=None,
         model_ema=None,
-        mixup_fn=None
+        mixup_fn=None,
+        tensorboard_writer=None,
 ):
     if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
         if args.prefetcher and loader.mixup_enabled:
@@ -903,6 +929,10 @@
 
         num_updates += 1
         batch_time_m.update(time.time() - end)
+        # write to tensorboard if enabled
+        if should_log_to_tensorboard(args) and tensorboard_writer is not None:
+            tensorboard_writer.add_scalar('train/loss', losses_m.val, num_updates)
+            tensorboard_writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], num_updates)
         if last_batch or batch_idx % args.log_interval == 0:
             lrl = [param_group['lr'] for param_group in optimizer.param_groups]
             lr = sum(lrl) / len(lrl)
@@ -954,6 +984,10 @@
     return OrderedDict([('loss', losses_m.avg)])
 
 
+def should_log_to_tensorboard(args):
+    return args.log_tensorboard and utils.is_primary(args) and has_tensorboard
+
+
 def validate(
         model,
         loader,
@@ -961,7 +995,10 @@
         args,
         device=torch.device('cuda'),
         amp_autocast=suppress,
-        log_suffix=''
+        log_suffix='',
+        tensorboard_writer=None,
+        epoch=None,
+
 ):
     batch_time_m = utils.AverageMeter()
     losses_m = utils.AverageMeter()
@@ -1011,6 +1048,11 @@
 
             batch_time_m.update(time.time() - end)
             end = time.time()
+            if should_log_to_tensorboard(args) and tensorboard_writer is not None and epoch is not None:
+                # step by the global validation batch index
+                tensorboard_writer.add_scalar('val/loss', losses_m.val, epoch * (last_idx + 1) + batch_idx)
+                tensorboard_writer.add_scalar('val/acc1', top1_m.val, epoch * (last_idx + 1) + batch_idx)
+                tensorboard_writer.add_scalar('val/acc5', top5_m.val, epoch * (last_idx + 1) + batch_idx)
             if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0):
                 log_name = 'Test' + log_suffix
                 _logger.info(
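
Note (illustration, not part of the patch): the change above only wires torch's existing SummaryWriter into train.py. Below is a minimal standalone sketch of the API pattern it relies on; the log directory and the loop values are placeholders standing in for args.log_tensorboard, losses_m.val and num_updates.

    # sketch only: './output/tb' and the dummy loss are placeholders
    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter('./output/tb')             # created once, on the primary process
    for step in range(100):
        loss = 1.0 / (step + 1)                       # dummy metric value
        writer.add_scalar('train/loss', loss, step)   # tag, scalar value, global step
    writer.close()                                    # flush event files, as main() does at exit

Event files written this way, or by train.py when --log-tensorboard PATH is passed, can be viewed with: tensorboard --logdir PATH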