Commit ecdda00: One cycle lr (#1803)

* refactor one_cycle lr scheduler so it's reusable in more situations
* fix validation for lr_scheduler
* default to cosine anneal strategy
* one cycle lr expects cos

1 parent b7665c2, commit ecdda00
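For context, the scheduler this commit refactors around is PyTorch's built-in OneCycleLR. The sketch below is not part of the diff; the model, learning rate, and step counts are made-up illustration values. It shows the same knobs the new code sets: a pct_start derived from the warmup fraction and the cosine anneal strategy the commit defaults to.

```python
# Minimal standalone sketch of torch's OneCycleLR (illustration values only).
import torch
from torch.optim.lr_scheduler import OneCycleLR

model = torch.nn.Linear(8, 2)                      # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)

num_training_steps = 1000                          # total optimizer steps
num_warmup_steps = 100                             # e.g. derived from a warmup ratio
pct_start = num_warmup_steps / num_training_steps  # fraction of the cycle spent ramping up

scheduler = OneCycleLR(
    optimizer,
    max_lr=2e-4,                    # peak LR reached at the end of warmup
    total_steps=num_training_steps,
    pct_start=pct_start,
    anneal_strategy="cos",          # the default this commit settles on
)

for _ in range(num_training_steps):
    optimizer.step()
    scheduler.step()                # one scheduler step per optimizer step
```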

File tree (3 files changed: +35, -43 lines)

  .pre-commit-config.yaml
  src/axolotl/core/trainer_builder.py
  src/axolotl/utils/config/models/input/v0_4_1/__init__.py

.pre-commit-config.yaml

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,8 @@ repos:
       - id: check-yaml
       - id: end-of-file-fixer
       - id: trailing-whitespace
+      - id: no-commit-to-branch
+        args: ['--branch', 'main']
   - repo: https://github.com/psf/black
     rev: 23.3.0
     hooks:

src/axolotl/core/trainer_builder.py

Lines changed: 32 additions & 42 deletions
@@ -242,6 +242,12 @@ class AxolotlTrainingMixins:
             "help": "workaround to pass an alternate optimizer to the HF trainer"
         },
     )
+    alternate_lr_scheduler_type: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "workaround to pass an alternate lr scheduler to the HF trainer"
+        },
+    )


 @dataclass
@@ -318,7 +324,23 @@ def create_scheduler(
         # fmt: off
         if self.lr_scheduler is None:  # type: ignore # pylint: disable=access-member-before-definition
             # fmt: on
-            if use_cosine_quadratic:
+            if self.args.alternate_lr_scheduler_type == "one_cycle":
+                num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
+                pct_start = num_warmup_steps / num_training_steps
+                extra_lr_kwargs = {}
+                if "pct_start" not in self.args.lr_scheduler_kwargs:
+                    extra_lr_kwargs["pct_start"] = pct_start
+                if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
+                    extra_lr_kwargs["anneal_strategy"] = "cos"
+
+                self.lr_scheduler = OneCycleLR(
+                    optimizer,
+                    max_lr=self.args.learning_rate,
+                    total_steps=num_training_steps,
+                    **extra_lr_kwargs,
+                    **self.args.lr_scheduler_kwargs,
+                )
+            elif use_cosine_quadratic:
                 if use_cosine_min_lr:
                     LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
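The new one_cycle branch above fills in pct_start and anneal_strategy only when the user has not already supplied them in lr_scheduler_kwargs, then passes everything through to OneCycleLR. A standalone sketch of that precedence rule (the helper name and example values are hypothetical, not from the diff):

```python
# Sketch of the kwargs precedence the new branch implements: derived defaults
# are added only for keys the user's lr_scheduler_kwargs does not set.
def build_one_cycle_kwargs(num_warmup_steps, num_training_steps, lr_scheduler_kwargs):
    extra_lr_kwargs = {}
    if "pct_start" not in lr_scheduler_kwargs:
        extra_lr_kwargs["pct_start"] = num_warmup_steps / num_training_steps
    if "anneal_strategy" not in lr_scheduler_kwargs:
        extra_lr_kwargs["anneal_strategy"] = "cos"
    # user-supplied kwargs are forwarded unchanged alongside the derived defaults
    return {**extra_lr_kwargs, **lr_scheduler_kwargs}

# Nothing specified: both defaults are derived.
print(build_one_cycle_kwargs(100, 1000, {}))
# -> {'pct_start': 0.1, 'anneal_strategy': 'cos'}

# pct_start given explicitly: only anneal_strategy is filled in.
print(build_one_cycle_kwargs(100, 1000, {"pct_start": 0.3, "div_factor": 6}))
# -> {'anneal_strategy': 'cos', 'pct_start': 0.3, 'div_factor': 6}
```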

@@ -876,37 +898,6 @@ def compute_loss(
         return lm_loss


-class OneCycleLRSchedulerTrainer(AxolotlTrainer):
-    """
-    Trainer subclass that uses the OneCycleLR scheduler
-    """
-
-    tag_names = ["axolotl", "onecycle"]
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.lr_scheduler = None
-
-    def create_scheduler(
-        self,
-        num_training_steps: int,
-        optimizer: Optional[torch.optim.Optimizer] = None,
-    ):
-        optimizer = self.optimizer if optimizer is None else optimizer
-        num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
-        pct_start = num_warmup_steps / num_training_steps
-
-        self.lr_scheduler = OneCycleLR(
-            optimizer,
-            max_lr=self.args.learning_rate,
-            total_steps=num_training_steps,
-            pct_start=pct_start,
-            div_factor=6,
-        )
-
-        return self.lr_scheduler
-
-
 class ReLoRATrainer(AxolotlTrainer):
     """
     Trainer subclass that uses the OneCycleLR scheduler
@@ -1190,10 +1181,6 @@ def get_post_trainer_create_callbacks(self, trainer):
         return callbacks

     def _get_trainer_cls(self):
-        if self.cfg.lr_scheduler == "one_cycle" and (
-            self.cfg.fsdp or self.cfg.adapter == "qlora"
-        ):
-            return OneCycleLRSchedulerTrainer
         if self.cfg.relora_steps:
             return ReLoRATrainer
         if self.cfg.model_config_type == "mamba":
@@ -1443,12 +1430,15 @@ def build(self, total_num_steps):
             training_arguments_kwargs[
                 "loraplus_lr_embedding"
             ] = self.cfg.loraplus_lr_embedding
-        training_arguments_kwargs["lr_scheduler_type"] = (
-            self.cfg.lr_scheduler
-            if self.cfg.lr_scheduler
-            and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
-            else "cosine"
-        )
+        if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
+            training_arguments_kwargs["lr_scheduler_type"] = "cosine"
+            training_arguments_kwargs[
+                "alternate_lr_scheduler_type"
+            ] = self.cfg.lr_scheduler
+        else:
+            training_arguments_kwargs["lr_scheduler_type"] = (
+                self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
+            )
         training_arguments_kwargs["lr_scheduler_kwargs"] = (
             self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
         )
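The build() change above routes scheduler names that HF's Trainer does not recognize into the new alternate_lr_scheduler_type argument while keeping lr_scheduler_type on a value HF accepts; the overridden create_scheduler then builds the OneCycleLR itself, so the placeholder "cosine" is effectively bypassed. A standalone sketch of the routing (function name and example values are hypothetical, not from the diff):

```python
# Sketch of how a config's lr_scheduler value is split between the HF-known
# lr_scheduler_type and the new alternate_lr_scheduler_type workaround field.
def route_lr_scheduler(cfg_lr_scheduler):
    training_arguments_kwargs = {}
    if cfg_lr_scheduler in ["one_cycle", "log_sweep"]:
        training_arguments_kwargs["lr_scheduler_type"] = "cosine"
        training_arguments_kwargs["alternate_lr_scheduler_type"] = cfg_lr_scheduler
    else:
        training_arguments_kwargs["lr_scheduler_type"] = cfg_lr_scheduler or "cosine"
    return training_arguments_kwargs

print(route_lr_scheduler("one_cycle"))
# -> {'lr_scheduler_type': 'cosine', 'alternate_lr_scheduler_type': 'one_cycle'}
print(route_lr_scheduler("linear"))
# -> {'lr_scheduler_type': 'linear'}
print(route_lr_scheduler(None))
# -> {'lr_scheduler_type': 'cosine'}
```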

src/axolotl/utils/config/models/input/v0_4_1/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -378,7 +378,7 @@ class HyperparametersConfig(BaseModel):
         },
     )
     torchdistx_path: Optional[str] = None
-    lr_scheduler: Optional[SchedulerType] = "cosine"
+    lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"]]] = "cosine"
     lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
     lr_quadratic_warmup: Optional[bool] = None
     cosine_min_lr_ratio: Optional[float] = None
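This one-line change is the "fix validation for lr_scheduler" part of the commit: transformers' SchedulerType enum has no "one_cycle" member, so the previous annotation rejected the value at config-parse time. A minimal pydantic sketch of why widening the annotation with Literal["one_cycle"] lets the value through (the enum below is a stand-in, not the real SchedulerType):

```python
# Stand-in enum and model to illustrate the Union[SchedulerType, Literal] fix.
from enum import Enum
from typing import Literal, Optional, Union

from pydantic import BaseModel


class SchedulerType(str, Enum):      # stand-in for transformers' SchedulerType
    LINEAR = "linear"
    COSINE = "cosine"


class HyperparametersConfig(BaseModel):
    lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"]]] = "cosine"


print(HyperparametersConfig(lr_scheduler="cosine").lr_scheduler)     # accepted via the enum
print(HyperparametersConfig(lr_scheduler="one_cycle").lr_scheduler)  # accepted via the Literal
```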
