@@ -68,6 +68,53 @@ def test_defaults(self, minimal_cfg):
68
68
assert cfg .train_on_inputs is False
69
69
assert cfg .weight_decay is None
70
70
71
+ def test_zero3_qlora_use_reentrant_false (self , minimal_cfg ):
72
+ test_cfg = DictDefault (
73
+ {
74
+ "deepspeed" : "deepspeed_configs/zero3_bf16.json" ,
75
+ "gradient_checkpointing" : True ,
76
+ "gradient_checkpointing_kwargs" : {"use_reentrant" : False },
77
+ "load_in_4bit" : True ,
78
+ "adapter" : "qlora" ,
79
+ }
80
+ | minimal_cfg
81
+ )
82
+
83
+ with self ._caplog .at_level (logging .WARNING ):
84
+ validate_config (test_cfg )
85
+ assert (
86
+ "qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values"
87
+ in self ._caplog .records [0 ].message
88
+ )
89
+
90
+ def test_deepspeed_empty (self , minimal_cfg ):
91
+ test_cfg = DictDefault (
92
+ {
93
+ "deepspeed" : "" ,
94
+ "gradient_checkpointing" : True ,
95
+ "gradient_checkpointing_kwargs" : {"use_reentrant" : False },
96
+ "load_in_4bit" : True ,
97
+ "adapter" : "qlora" ,
98
+ }
99
+ | minimal_cfg
100
+ )
101
+
102
+ _ = validate_config (test_cfg )
103
+
104
+ def test_deepspeed_not_set (self , minimal_cfg ):
105
+ test_cfg = DictDefault (
106
+ {
107
+ "deepspeed" : None ,
108
+ "gradient_checkpointing" : True ,
109
+ "gradient_checkpointing_kwargs" : {"use_reentrant" : False },
110
+ "load_in_4bit" : True ,
111
+ "adapter" : "qlora" ,
112
+ }
113
+ | minimal_cfg
114
+ )
115
+
116
+ _ = validate_config (test_cfg )
117
+
71
118
def test_datasets_min_length (self ):
72
119
cfg = DictDefault (
73
120
{
0 commit comments