Added AdamW 8-bit optimizer, which can save some GPU memory during training (use 'adamw8bit' for training.optimizer in the config). Also added the possibility to provide optimizer parameters in the config file using the keyword 'optimizer'.
ZFTurbo committed Jul 31, 2024
1 parent 71b3a90 commit ba951a4
Showing 2 changed files with 13 additions and 4 deletions.
requirements.txt: 1 change (1 addition, 0 deletions)
@@ -23,3 +23,4 @@ torch_audiomentations
 asteroid==0.7.0
 auraloss
 torchseg
+bitsandbytes
train.py: 16 changes (12 additions, 4 deletions)
@@ -395,15 +395,23 @@ def train_model(args):
     if 0:
         valid_multi_gpu(model, args, config, verbose=True)
 
+    optim_params = dict()
+    if 'optimizer' in config:
+        optim_params = dict(config['optimizer'])
+        print('Optimizer params from config:\n{}'.format(optim_params))
+
     if config.training.optimizer == 'adam':
-        optimizer = Adam(model.parameters(), lr=config.training.lr)
+        optimizer = Adam(model.parameters(), lr=config.training.lr, **optim_params)
     elif config.training.optimizer == 'adamw':
-        optimizer = AdamW(model.parameters(), lr=config.training.lr)
+        optimizer = AdamW(model.parameters(), lr=config.training.lr, **optim_params)
     elif config.training.optimizer == 'radam':
-        optimizer = RAdam(model.parameters(), lr=config.training.lr)
+        optimizer = RAdam(model.parameters(), lr=config.training.lr, **optim_params)
+    elif config.training.optimizer == 'adamw8bit':
+        import bitsandbytes as bnb
+        optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.training.lr, **optim_params)
     elif config.training.optimizer == 'sgd':
         print('Use SGD optimizer')
-        optimizer = SGD(model.parameters(), lr=config.training.lr, momentum=0.999)
+        optimizer = SGD(model.parameters(), lr=config.training.lr, **optim_params)
     else:
         print('Unknown optimizer: {}'.format(config.training.optimizer))
         exit()
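As a usage sketch (not part of the commit), here is how the two new config features fit together. The plain dict stands in for the repository's parsed config file, and the weight_decay value is a hypothetical example:

import torch
import bitsandbytes as bnb  # pulled in via the new requirements.txt entry

# Stand-in for the parsed config; the real project loads it from a config file.
config = {
    'training': {'optimizer': 'adamw8bit', 'lr': 1e-4},
    # Optional 'optimizer' section: extra kwargs forwarded to the optimizer constructor.
    'optimizer': {'weight_decay': 0.01},  # hypothetical value, not from the commit
}

model = torch.nn.Linear(8, 8)  # toy model for illustration

optim_params = dict(config['optimizer']) if 'optimizer' in config else dict()
# AdamW8bit stores optimizer state in 8 bits, which is where the GPU memory saving comes from.
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config['training']['lr'], **optim_params)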
