
Commit ad2b48c

fsdp config dict fix, todo list, add torchdistx support

1 parent 9190ada · commit ad2b48c

File tree

3 files changed: +24, -3 lines

- TODO.md
- src/axolotl/utils/models.py
- src/axolotl/utils/trainer.py

TODO.md (+10)

@@ -0,0 +1,10 @@
+# todo list
+
+- [ ] Validation of parameters for combinations that won't work
+
+
+
+## things that are known not to work
+
+- FSDP offload and gradient_checkpointing - https://github.com/pytorch/pytorch/issues/82203
+- adamw_bnb_8bit doesn't play well with FSDP offload
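For the first TODO item, a minimal sketch of what that validation might look like, assuming a plain dict-like cfg with fsdp, gradient_checkpointing, and optimizer keys (the validate_config helper and the key handling are illustrative assumptions, not part of this commit), based on the two known-bad combinations listed above:

```python
def validate_config(cfg):
    """Reject option combinations that are known not to work together."""
    fsdp = cfg.get("fsdp") or ""
    # Normalize to a single string so "offload" can be detected in either form.
    fsdp = " ".join(fsdp) if isinstance(fsdp, list) else fsdp

    # FSDP CPU offload + gradient checkpointing currently breaks
    # (https://github.com/pytorch/pytorch/issues/82203).
    if "offload" in fsdp and cfg.get("gradient_checkpointing"):
        raise ValueError("FSDP offload cannot be combined with gradient_checkpointing")

    # bitsandbytes 8-bit AdamW does not play well with FSDP offload.
    if "offload" in fsdp and cfg.get("optimizer") == "adamw_bnb_8bit":
        raise ValueError("adamw_bnb_8bit is not supported with FSDP offload")
```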

src/axolotl/utils/models.py (+5)

@@ -179,6 +179,11 @@ def load_model(
             m.scales = m.scales.half()
             m.bias = m.bias.half()
 
+    if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) > 1:
+        model.is_parallelizable = True
+        model.model_parallel = True
+
+
     # TODO resume_from_checkpoint handling
     return model, tokenizer, lora_config
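The new block in load_model only marks the model as parallelizable when both conditions hold: more than one CUDA device is visible and the launcher reports more than one process via WORLD_SIZE (torchrun, for example, sets this variable). A standalone sketch of the same check, useful for confirming what a given launch command will trigger (the helper name is an assumption):

```python
import os
import torch

def running_multi_gpu() -> bool:
    """Mirror the check added in load_model: several visible GPUs
    plus a multi-process launch (WORLD_SIZE > 1)."""
    return torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) > 1
```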

src/axolotl/utils/trainer.py (+9, -3)

@@ -1,5 +1,7 @@
+import importlib
 import math
 import os
+import sys
 from pathlib import Path
 
 import bitsandbytes as bnb
@@ -35,9 +37,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     else:
         training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
     if cfg.fsdp:
-        training_arguments_kwargs["fsdp"] = cfg.fsdp.split(" ")
-        if cfg.fsdp_transformer_layer_cls_to_wrap:
-            training_arguments_kwargs["fsdp_transformer_layer_cls_to_wrap"] = cfg.fsdp_transformer_layer_cls_to_wrap
+        training_arguments_kwargs["fsdp"] = cfg.fsdp
+        if cfg.fsdp_config:
+            training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
 
 
     # deepspeed
@@ -73,6 +75,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
 
     trainer_kwargs = {}
 
+    if cfg.optimizer == "adamw_anyprecision":
+        if Path(cfg.torchdistx_path).exists():
+            sys.path.append(cfg.torchdistx_path)
+            torchdistx = importlib.import_module('torchdistx')
     if cfg.optimizer == "adam8bit" and not cfg.load_4bit and not "deepspeed" in training_arguments_kwargs:
         decay_parameters = get_parameter_names(model, [nn.LayerNorm])
         decay_parameters = [name for name in decay_parameters if "bias" not in name]
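The fsdp hunk follows the newer transformers API: cfg.fsdp is now handed to TrainingArguments as-is (rather than split from a space-separated string), and FSDP options such as which transformer layer class to wrap move into a single fsdp_config dict instead of the old standalone fsdp_transformer_layer_cls_to_wrap kwarg. A rough sketch of the resulting arguments (the concrete option values below are illustrative assumptions, not taken from this commit):

```python
from transformers import TrainingArguments

# Illustrative values only; setup_trainer builds these from cfg.fsdp
# and cfg.fsdp_config in the user's axolotl config.
training_args = TrainingArguments(
    output_dir="./out",
    fsdp="full_shard auto_wrap",
    fsdp_config={
        "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
        "fsdp_offload_params": False,
    },
)
```

The last hunk wires in optional torchdistx support for the adamw_anyprecision optimizer: a user-supplied cfg.torchdistx_path is appended to sys.path and the package is imported dynamically, so the dependency is only needed when that optimizer is selected. The same pattern in isolation (the helper name is assumed):

```python
import importlib
import sys
from pathlib import Path

def maybe_import_torchdistx(torchdistx_path: str):
    """Make a locally installed torchdistx importable, then import it lazily."""
    if Path(torchdistx_path).exists():
        sys.path.append(torchdistx_path)
        return importlib.import_module("torchdistx")
    return None
```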
