
Add TensorBoard support to the training script (step based) #1720


Open
wants to merge 14 commits into main
52 changes: 47 additions & 5 deletions train.py
@@ -64,7 +64,12 @@
    has_functorch = True
except ImportError as e:
    has_functorch = False

# check that TensorBoard support (torch.utils.tensorboard) is available
try:
    from torch.utils.tensorboard import SummaryWriter
    has_tensorboard = True
except ImportError as e:
    has_tensorboard = False

has_compile = hasattr(torch, 'compile')


@@ -347,8 +352,8 @@
                   help='use the multi-epochs-loader to save time at the beginning of every epoch')
group.add_argument('--log-wandb', action='store_true', default=False,
                   help='log training and validation metrics to wandb')
group.add_argument('--log-tensorboard', default='', type=str, metavar='PATH',
                   help='log training and validation metrics to TensorBoard')


def _parse_args():
    # Do we have a config file to parse?
    args_config, remaining = config_parser.parse_known_args()
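For context, a run using the new flag might look like this (the model and log directory here are illustrative, not from the PR); the same path is then handed to the tensorboard CLI to view the curves:

python train.py /data/imagenet --model resnet50 --log-tensorboard ./runs/resnet50
tensorboard --logdir ./runs/resnet50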
Expand Down Expand Up @@ -725,6 +730,18 @@ def main():
            _logger.warning(
                "You've requested to log metrics to wandb but package not found. "
                "Metrics not being logged to wandb, try `pip install wandb`")

    tensorboard_writer = None
    if utils.is_primary(args) and args.log_tensorboard:
        if has_tensorboard:
            tensorboard_writer = SummaryWriter(args.log_tensorboard)
        else:
            # keep training usable without the package; just warn once
            _logger.warning(
                "You've requested to log metrics to tensorboard but package not found. "
                "Metrics not being logged to tensorboard, try `pip install tensorboard`")

    # setup learning rate schedule and starting epoch
    updates_per_epoch = len(loader_train)
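The diff relies on the standard torch.utils.tensorboard API. A minimal standalone sketch of the lifecycle the PR follows (create the writer once, add step-indexed scalars, close at exit), with illustrative values:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('./runs/example')   # log directory is illustrative
for step in range(100):
    loss = 1.0 / (step + 1)                # stand-in for a real training loss
    writer.add_scalar('train/loss', loss, step)
writer.close()                             # flush pending events to disk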
@@ -770,6 +787,7 @@ def main():
                loss_scaler=loss_scaler,
                model_ema=model_ema,
                mixup_fn=mixup_fn,
                tensorboard_writer=tensorboard_writer,
            )

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
@@ -783,6 +801,8 @@
                validate_loss_fn,
                args,
                amp_autocast=amp_autocast,
                tensorboard_writer=tensorboard_writer,
                epoch=epoch,
            )

            if model_ema is not None and not args.model_ema_force_cpu:
@@ -796,6 +816,8 @@
                    args,
                    amp_autocast=amp_autocast,
                    log_suffix=' (EMA)',
                    tensorboard_writer=tensorboard_writer,
                    epoch=epoch,
                )
                eval_metrics = ema_eval_metrics

@@ -809,6 +831,7 @@
                    lr=sum(lrs) / len(lrs),
                    write_header=best_metric is None,
                    log_wandb=args.log_wandb and has_wandb,
                )

            if saver is not None:
@@ -825,6 +848,8 @@

    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))

    if tensorboard_writer is not None:
        tensorboard_writer.close()


def train_one_epoch(
@@ -841,7 +866,8 @@ def train_one_epoch(
        amp_autocast=suppress,
        loss_scaler=None,
        model_ema=None,
        mixup_fn=None,
        tensorboard_writer=None,
):
    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
        if args.prefetcher and loader.mixup_enabled:
@@ -903,6 +929,10 @@

        num_updates += 1
        batch_time_m.update(time.time() - end)

        # log step-based training metrics to TensorBoard if enabled
        if should_log_to_tensorboard(args) and tensorboard_writer is not None:
            tensorboard_writer.add_scalar('train/loss', losses_m.val, num_updates)
            tensorboard_writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], num_updates)

        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)
@@ -954,14 +984,21 @@ def train_one_epoch(
    return OrderedDict([('loss', losses_m.avg)])


def should_log_to_tensorboard(args):
    # log only from the primary process (avoids duplicate writes under distributed
    # training), and only when TensorBoard was both requested and is installed
    return args.log_tensorboard and utils.is_primary(args) and has_tensorboard


def validate(
        model,
        loader,
        loss_fn,
        args,
        device=torch.device('cuda'),
        amp_autocast=suppress,
        log_suffix='',
        tensorboard_writer=None,
        epoch=None,
):
    batch_time_m = utils.AverageMeter()
    losses_m = utils.AverageMeter()
@@ -1011,6 +1048,11 @@ def validate(

            batch_time_m.update(time.time() - end)
            end = time.time()

            if should_log_to_tensorboard(args) and tensorboard_writer is not None and epoch is not None:
                # step-based: len(loader) == last_idx + 1, so this global step keeps
                # increasing across epochs instead of repeating at epoch boundaries
                global_step = epoch * (last_idx + 1) + batch_idx
                tensorboard_writer.add_scalar('val/loss', losses_m.val, global_step)
                tensorboard_writer.add_scalar('val/acc1', top1_m.val, global_step)
                tensorboard_writer.add_scalar('val/acc5', top5_m.val, global_step)

            if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0):
                log_name = 'Test' + log_suffix
                _logger.info(
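A quick standalone check (illustrative, not part of the diff) of why the validation step should use len(loader) rather than last_idx as the per-epoch stride:

num_batches = 4                    # pretend len(loader) == 4
last_idx = num_batches - 1

# last_idx stride: the final step of one epoch equals the first step of the next
collide = [e * last_idx + b for e in range(2) for b in range(num_batches)]
print(collide)   # [0, 1, 2, 3, 3, 4, 5, 6]  -> step 3 is written twice

# len(loader) stride: strictly increasing across epochs
unique = [e * num_batches + b for e in range(2) for b in range(num_batches)]
print(unique)    # [0, 1, 2, 3, 4, 5, 6, 7]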