Skip to content

Commit 151abb7

Browse files
authoredNov 21, 2024
fix None-type not iterable error when deepspeed is left blank w/ use_… (#2087)
* fix None-type not iterable error when deepspeed is left blank w/ use_reentrant: false and qlora * added unit test[skip e2e] * corrected test case[skip e2e] * assert warning message [skip e2e] * assert warning message [skip e2e] * corrected test cases [skip e2e] * lint
1 parent bf416bd commit 151abb7

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed
 

‎src/axolotl/utils/config/models/input/v0_4_1/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1314,6 +1314,7 @@ def warn_qlora_zero3_w_use_reentrant(cls, data):
13141314
and data.get("gradient_checkpointing_kwargs", {})
13151315
and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
13161316
is False
1317+
and data.get("deepspeed", "") is not None
13171318
and "zero3" in data.get("deepspeed", "")
13181319
):
13191320
# may result in:

‎tests/test_validation.py

+47
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,53 @@ def test_defaults(self, minimal_cfg):
6868
assert cfg.train_on_inputs is False
6969
assert cfg.weight_decay is None
7070

71+
def test_zero3_qlora_use_reentrant_false(self, minimal_cfg):
72+
test_cfg = DictDefault(
73+
{
74+
"deepspeed": "deepspeed_configs/zero3_bf16.json",
75+
"gradient_checkpointing": True,
76+
"gradient_checkpointing_kwargs": {"use_reentrant": False},
77+
"load_in_4bit": True,
78+
"adapter": "qlora",
79+
}
80+
| minimal_cfg
81+
)
82+
83+
with self._caplog.at_level(logging.WARNING):
84+
validate_config(test_cfg)
85+
assert (
86+
"qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values"
87+
in self._caplog.records[0].message
88+
)
89+
90+
def test_deepspeed_empty(self, minimal_cfg):
91+
test_cfg = DictDefault(
92+
{
93+
"deepspeed": "",
94+
"gradient_checkpointing": True,
95+
"gradient_checkpointing_kwargs": {"use_reentrant": False},
96+
"load_in_4bit": True,
97+
"adapter": "qlora",
98+
}
99+
| minimal_cfg
100+
)
101+
102+
_ = validate_config(test_cfg)
103+
104+
def test_deepspeed_not_set(self, minimal_cfg):
105+
test_cfg = DictDefault(
106+
{
107+
"deepspeed": None,
108+
"gradient_checkpointing": True,
109+
"gradient_checkpointing_kwargs": {"use_reentrant": False},
110+
"load_in_4bit": True,
111+
"adapter": "qlora",
112+
}
113+
| minimal_cfg
114+
)
115+
116+
_ = validate_config(test_cfg)
117+
71118
def test_datasets_min_length(self):
72119
cfg = DictDefault(
73120
{

0 commit comments

Comments
 (0)
Please sign in to comment.