@@ -242,6 +242,12 @@ class AxolotlTrainingMixins:
242
242
"help" : "workaround to pass an alternate optimizer to the HF trainer"
243
243
},
244
244
)
245
+ alternate_lr_scheduler_type : Optional [str ] = field (
246
+ default = None ,
247
+ metadata = {
248
+ "help" : "workaround to pass an alternate lr scheduler to the HF trainer"
249
+ },
250
+ )
245
251
246
252
247
253
@dataclass
@@ -318,7 +324,23 @@ def create_scheduler(
318
324
# fmt: off
319
325
if self .lr_scheduler is None : # type: ignore # pylint: disable=access-member-before-definition
320
326
# fmt: on
321
- if use_cosine_quadratic :
327
+ if self .args .alternate_lr_scheduler_type == "one_cycle" :
328
+ num_warmup_steps = self .args .get_warmup_steps (num_training_steps )
329
+ pct_start = num_warmup_steps / num_training_steps
330
+ extra_lr_kwargs = {}
331
+ if "pct_start" not in self .args .lr_scheduler_kwargs :
332
+ extra_lr_kwargs ["pct_start" ] = pct_start
333
+ if "anneal_strategy" not in self .args .lr_scheduler_kwargs :
334
+ extra_lr_kwargs ["anneal_strategy" ] = "cos"
335
+
336
+ self .lr_scheduler = OneCycleLR (
337
+ optimizer ,
338
+ max_lr = self .args .learning_rate ,
339
+ total_steps = num_training_steps ,
340
+ ** extra_lr_kwargs ,
341
+ ** self .args .lr_scheduler_kwargs ,
342
+ )
343
+ elif use_cosine_quadratic :
322
344
if use_cosine_min_lr :
323
345
LOG .warning ("Both cosine quadratic warmup and min lr detected. Using quadratic warmup." )
324
346
@@ -876,37 +898,6 @@ def compute_loss(
876
898
return lm_loss
877
899
878
900
879
- class OneCycleLRSchedulerTrainer (AxolotlTrainer ):
880
- """
881
- Trainer subclass that uses the OneCycleLR scheduler
882
- """
883
-
884
- tag_names = ["axolotl" , "onecycle" ]
885
-
886
- def __init__ (self , * args , ** kwargs ):
887
- super ().__init__ (* args , ** kwargs )
888
- self .lr_scheduler = None
889
-
890
- def create_scheduler (
891
- self ,
892
- num_training_steps : int ,
893
- optimizer : Optional [torch .optim .Optimizer ] = None ,
894
- ):
895
- optimizer = self .optimizer if optimizer is None else optimizer
896
- num_warmup_steps = self .args .get_warmup_steps (num_training_steps )
897
- pct_start = num_warmup_steps / num_training_steps
898
-
899
- self .lr_scheduler = OneCycleLR (
900
- optimizer ,
901
- max_lr = self .args .learning_rate ,
902
- total_steps = num_training_steps ,
903
- pct_start = pct_start ,
904
- div_factor = 6 ,
905
- )
906
-
907
- return self .lr_scheduler
908
-
909
-
910
901
class ReLoRATrainer (AxolotlTrainer ):
911
902
"""
912
903
Trainer subclass that uses the OneCycleLR scheduler
@@ -1190,10 +1181,6 @@ def get_post_trainer_create_callbacks(self, trainer):
1190
1181
return callbacks
1191
1182
1192
1183
def _get_trainer_cls (self ):
1193
- if self .cfg .lr_scheduler == "one_cycle" and (
1194
- self .cfg .fsdp or self .cfg .adapter == "qlora"
1195
- ):
1196
- return OneCycleLRSchedulerTrainer
1197
1184
if self .cfg .relora_steps :
1198
1185
return ReLoRATrainer
1199
1186
if self .cfg .model_config_type == "mamba" :
@@ -1443,12 +1430,15 @@ def build(self, total_num_steps):
1443
1430
training_arguments_kwargs [
1444
1431
"loraplus_lr_embedding"
1445
1432
] = self .cfg .loraplus_lr_embedding
1446
- training_arguments_kwargs ["lr_scheduler_type" ] = (
1447
- self .cfg .lr_scheduler
1448
- if self .cfg .lr_scheduler
1449
- and self .cfg .lr_scheduler not in ("one_cycle" , "log_sweep" )
1450
- else "cosine"
1451
- )
1433
+ if self .cfg .lr_scheduler in ["one_cycle" , "log_sweep" ]:
1434
+ training_arguments_kwargs ["lr_scheduler_type" ] = "cosine"
1435
+ training_arguments_kwargs [
1436
+ "alternate_lr_scheduler_type"
1437
+ ] = self .cfg .lr_scheduler
1438
+ else :
1439
+ training_arguments_kwargs ["lr_scheduler_type" ] = (
1440
+ self .cfg .lr_scheduler if self .cfg .lr_scheduler else "cosine"
1441
+ )
1452
1442
training_arguments_kwargs ["lr_scheduler_kwargs" ] = (
1453
1443
self .cfg .lr_scheduler_kwargs if self .cfg .lr_scheduler_kwargs else {}
1454
1444
)
0 commit comments