Skip to content

Commit e5dfddc

Browse files
committed
filter the args
1 parent eecf3b9 commit e5dfddc

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

Diff for: train.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -735,7 +735,9 @@ def main():
735735
if has_tensorboard:
736736
tensorboard_writer = SummaryWriter(args.log_tensorboard)
737737
#write Hyperparameters to tensorboard
738-
tensorboard_writer.add_hparams(vars(args), {})
738+
#get all args keys that are one of int, float, str, bool, or torch.Tensor
739+
hparams = {k: v for k, v in vars(args).items() if type(v) in [int, float, str, bool, torch.Tensor]}
740+
tensorboard_writer.add_hparams(hparams, {})
739741

740742
else:
741743
_logger.warning(

0 commit comments

Comments
 (0)