-
Notifications
You must be signed in to change notification settings - Fork 309
/
Copy pathtest_checkpoint.py
287 lines (249 loc) · 9.73 KB
/
test_checkpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import shutil
import tempfile
import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from unittest import mock
import torch
from torchtitan.components.checkpoint import CheckpointManager
def fake_dcp_save(state, checkpoint_id):
    """Stand-in for ``dcp.save`` that writes one plain ``torch.save`` file.

    Calls ``state_dict()`` on every value of *state*, then stores the resulting
    ``{name: state_dict}`` mapping at ``<checkpoint_id>/state.pt``. The
    directory is created if it does not exist yet.
    """
    snapshot = {}
    for name, stateful in state.items():
        snapshot[name] = stateful.state_dict()
    os.makedirs(checkpoint_id, exist_ok=True)
    target_file = os.path.join(checkpoint_id, "state.pt")
    torch.save(snapshot, target_file)
def fake_dcp_load(state, checkpoint_id):
    """Stand-in for ``dcp.load`` that reads nothing from disk.

    It only tags the ``"trainer"`` entry of *state* with the sentinel attribute
    ``dcp_load_is_called = 7312``, so a test can tell the load path ran.
    *checkpoint_id* is accepted for signature compatibility and ignored.
    """
    trainer = state["trainer"]
    setattr(trainer, "dcp_load_is_called", 7312)
def fake_async_save(state, checkpoint_id, process_group):
    """Stand-in for ``dcp.async_save`` built on ``fake_dcp_save``.

    The save runs on a one-worker thread pool. Leaving the ``with`` block shuts
    the pool down and waits, so the file is already written when this returns.

    Returns a ``mock.Mock`` whose ``result`` is itself a Mock that forwards to
    the real future's ``result``. Tests can therefore check whether, and how
    often, the caller waited on it. *process_group* is accepted for signature
    compatibility and ignored.
    """
    with ThreadPoolExecutor(max_workers=1) as pool:
        pending = pool.submit(fake_dcp_save, state, checkpoint_id)
    handle = mock.Mock()
    handle.result = mock.Mock(side_effect=pending.result)
    return handle
def fake_get_model_state_dict(model, *args, **kwargs):
    """Stand-in for ``get_model_state_dict`` that defers to ``model.state_dict()``.

    Any extra positional or keyword options are accepted so the call signature
    matches the patched function, and are then ignored.
    """
    del args, kwargs  # accepted only for signature compatibility
    return model.state_dict()
@dataclass
class DummyCheckpointConfig:
    """Minimal stand-in for the ``checkpoint`` section of the job config.

    Every option has a default, so a bare ``DummyCheckpointConfig()`` describes
    a synchronous checkpoint every 10 steps with no retention limit
    (``keep_latest_k = 0``).
    """

    enable_checkpoint: bool = True
    folder: str = "dummy_folder"
    interval: int = 10
    async_mode: str = "disabled"
    keep_latest_k: int = 0
    model_weights_only: bool = False
    export_dtype: str = "float32"
    # A real dataclass field with a factory: an unannotated ``= []`` would be
    # a single class-level list shared (and mutable) across every instance,
    # and could not be set through the constructor.
    exclude_from_loading: list = field(default_factory=list)
@dataclass
class DummyJob:
    """Stand-in for the ``job`` section of the job config.

    Only ``dump_folder`` is modelled. In these tests it is the root directory
    that the checkpoint folder is joined onto.
    """

    dump_folder: str = field(default="dummy_folder")
@dataclass
class DummyExperimental:
    """Stand-in for the ``experimental`` section of the job config.

    The defaults describe a single fault-tolerance replica group (replica id 0,
    group size 1). The values are annotated so they are real dataclass fields,
    like the sibling config stand-ins, and can be overridden per instance via
    the constructor.
    """

    ft_replica_id: int = 0
    ft_group_size: int = 1
@dataclass
class DummyJobConfig:
    """Top-level stand-in for the job config passed to ``CheckpointManager``.

    Each section defaults to a freshly constructed dummy, so instances never
    share section objects.
    """

    checkpoint: DummyCheckpointConfig = field(
        default_factory=DummyCheckpointConfig,
    )
    job: DummyJob = field(
        default_factory=DummyJob,
    )
    experimental: DummyExperimental = field(
        default_factory=DummyExperimental,
    )

    # Unannotated on purpose: a class-level attribute, not a constructor field.
    ft_manager = None
# Dummy instances to supply as constructor arguments.
def _stateful_mock(payload):
    """Return a Mock whose ``state_dict()`` yields a new copy of *payload* per call."""
    stub = mock.Mock()
    stub.state_dict = mock.Mock(side_effect=lambda: dict(payload))
    return stub


dummy_dataloader = _stateful_mock({"dataloader": 1})
dummy_model_parts = [_stateful_mock({"model": 2})]
dummy_optimizers = _stateful_mock({"optimizer": 3})
dummy_lr_schedulers = _stateful_mock({"lr_scheduler": 4})
class TestCheckpointManager(unittest.TestCase):
    """Unit tests for ``CheckpointManager`` with checkpoint I/O faked out.

    Each test patches the ``dcp`` entry points and ``get_model_state_dict`` in
    ``torchtitan.components.checkpoint`` with this module's ``fake_*`` helpers.
    A saved checkpoint is therefore a single ``state.pt`` file per step inside
    a per-test temporary directory.
    """

    def setUp(self):
        # A fresh scratch directory per test keeps checkpoints from leaking
        # between tests.
        self.temp_dir = tempfile.mkdtemp()
        self.dummy_job = DummyJob(dump_folder=self.temp_dir)
        self.job_config = DummyJobConfig(job=self.dummy_job)
        self.checkpoint_folder = os.path.join(
            self.dummy_job.dump_folder, self.job_config.checkpoint.folder
        )
        os.makedirs(self.checkpoint_folder, exist_ok=True)
        self.trainer_state = mock.Mock()
        self.trainer_state.state_dict = mock.Mock(
            side_effect=lambda: {"my_state": 765}
        )

    def tearDown(self):
        # Remove the temporary directory after each test.
        shutil.rmtree(self.temp_dir)

    def _create_manager(self, job_config):
        """Return a ``CheckpointManager`` over the module-level dummy components.

        The mock trainer from ``setUp`` is registered as the extra
        ``"trainer"`` state. Call this inside the test body, while the test's
        ``mock.patch`` decorators are active.
        """
        return CheckpointManager(
            dummy_dataloader,
            dummy_model_parts,
            dummy_optimizers,
            dummy_lr_schedulers,
            {"trainer": self.trainer_state},
            job_config,
        )

    def _wait_for_purge(self, manager, timeout=60.0, poll_interval=0.1):
        """Wait until *manager*'s purge queue drains, failing after *timeout* seconds.

        The wait is bounded so that a stuck purge worker fails the test instead
        of hanging the whole run.
        """
        deadline = time.monotonic() + timeout
        while not manager.purge_queue.empty():
            if time.monotonic() >= deadline:
                self.fail(f"Purge queue did not drain within {timeout} seconds.")
            time.sleep(poll_interval)
        # An empty queue only means every path has been dequeued; the last
        # deletion may still be in flight, so allow a short grace period and
        # flush filesystem buffers before the caller inspects the disk.
        time.sleep(1)
        os.sync()

    @mock.patch(
        "torchtitan.components.checkpoint.get_model_state_dict",
        side_effect=fake_get_model_state_dict,
    )
    @mock.patch("torchtitan.components.checkpoint.dcp.save", side_effect=fake_dcp_save)
    def test_save(self, *_):
        """Test that calling save() writes a checkpoint file to disk."""
        manager = self._create_manager(DummyJobConfig(job=self.dummy_job))
        step = 20
        manager.save(curr_step=step, force=True)
        state_file = self._checkpoint_id(step)
        self.assertTrue(
            os.path.exists(state_file), "The checkpoint file was not created on disk."
        )
        loaded_state = torch.load(state_file, weights_only=False)
        self.assertEqual(
            loaded_state["trainer"]["my_state"],
            765,
            "Saved state does not match expected value.",
        )

    @mock.patch(
        "torchtitan.components.checkpoint.get_model_state_dict",
        side_effect=fake_get_model_state_dict,
    )
    @mock.patch("torchtitan.components.checkpoint.dcp.load", side_effect=fake_dcp_load)
    @mock.patch("torchtitan.components.checkpoint.dcp.save", side_effect=fake_dcp_save)
    def test_load(self, *_):
        """Test that load() of an existing checkpoint returns True and goes through dcp.load.

        ``dcp.load`` is faked to tag the trainer state with a sentinel
        attribute, so the test checks for that tag rather than file contents.
        """
        manager = self._create_manager(DummyJobConfig(job=self.dummy_job))
        step = 30
        manager.save(curr_step=step, force=True)
        # Simulate a state change.
        manager.states["test"] = 999
        success = manager.load(step=step)
        self.assertTrue(
            success,
            "The load() method should have returned True for an existing checkpoint.",
        )
        self.assertTrue(hasattr(manager.states["trainer"], "dcp_load_is_called"))
        self.assertEqual(
            manager.states["trainer"].dcp_load_is_called,
            7312,
            "The state was not correctly restored after loading.",
        )

    @mock.patch("torchtitan.components.checkpoint.dist.get_rank", return_value=0)
    @mock.patch(
        "torchtitan.components.checkpoint.get_model_state_dict",
        side_effect=fake_get_model_state_dict,
    )
    @mock.patch("torchtitan.components.checkpoint.dcp.save", side_effect=fake_dcp_save)
    def test_purge_stale_checkpoints_rank_zero(self, *_):
        """
        Test that when keep_latest_k is 3 and dist.get_rank() returns 0, the
        two oldest checkpoints are deleted and the three newest survive.
        """
        job_config = DummyJobConfig(job=self.dummy_job)
        job_config.checkpoint.keep_latest_k = 3
        manager = self._create_manager(job_config)
        for step in [10, 20, 30, 40, 50]:
            manager.save(curr_step=step, force=False)
        self._wait_for_purge(manager)
        for step in [10, 20]:
            self.assertFalse(
                os.path.exists(self._checkpoint_id(step)),
                f"The checkpoint for step {step} was not purged.",
            )
        for step in [30, 40, 50]:
            self.assertTrue(
                os.path.exists(self._checkpoint_id(step)),
                f"The checkpoint for step {step} was unexpectedly purged.",
            )

    @mock.patch("torchtitan.components.checkpoint.dist.get_rank", return_value=1)
    @mock.patch(
        "torchtitan.components.checkpoint.get_model_state_dict",
        side_effect=fake_get_model_state_dict,
    )
    @mock.patch("torchtitan.components.checkpoint.dcp.save", side_effect=fake_dcp_save)
    def test_purge_stale_checkpoints_rank_nonzero(self, *_):
        """
        Test that when dist.get_rank() returns a non-zero value, no checkpoint
        is purged even though keep_latest_k is 3.
        """
        job_config = DummyJobConfig(job=self.dummy_job)
        job_config.checkpoint.keep_latest_k = 3
        manager = self._create_manager(job_config)
        for step in [10, 20, 30, 40, 50]:
            manager.save(curr_step=step, force=False)
        self._wait_for_purge(manager)
        for step in [10, 20, 30, 40, 50]:
            self.assertTrue(
                os.path.exists(self._checkpoint_id(step)),
                f"The checkpoint for step {step} was unexpectedly purged.",
            )

    @mock.patch("torchtitan.components.checkpoint.dist.new_group")
    @mock.patch(
        "torchtitan.components.checkpoint.get_model_state_dict",
        side_effect=fake_get_model_state_dict,
    )
    @mock.patch(
        "torchtitan.components.checkpoint.dcp.async_save", side_effect=fake_async_save
    )
    def test_async_save_calls_async_wait(self, *_):
        """
        Test that in async mode (AsyncMode.ASYNC), calling save() twice correctly waits
        on the previous async future via _async_wait().
        """
        # Set async_mode to "async" in the job configuration.
        job_config = DummyJobConfig(job=self.dummy_job)
        job_config.checkpoint.async_mode = "async"
        manager = self._create_manager(job_config)
        # First save: schedules an async save but must not wait on it yet.
        manager.save(curr_step=10, force=False)
        first_future = manager.async_future
        first_future.result.assert_not_called()
        # Second save: must wait on the first future exactly once, then
        # schedule a new one that has not been waited on.
        manager.save(curr_step=20, force=False)
        first_future.result.assert_called_once()
        manager.async_future.result.assert_not_called()

    def _checkpoint_id(self, step):
        """Return the path of the ``state.pt`` file expected for *step*.

        Despite the name, this is the state file inside the ``step-<step>``
        directory under the checkpoint folder, not the directory itself.
        """
        checkpoint_dir = os.path.join(self.checkpoint_folder, f"step-{step}")
        return os.path.join(checkpoint_dir, "state.pt")
if __name__ == "__main__":
    # Allow running this file directly, e.g. ``python test_checkpoint.py``.
    unittest.main()