Skip to content

Commit 8f0d125

Browse files
authored
checkpointing: move to subfolder (#105)
1 parent 9533676 commit 8f0d125

File tree

6 files changed

+98
-73
lines changed

6 files changed

+98
-73
lines changed

torchft/checkpointing/__init__.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Checkpointing
9+
==============
10+
11+
This module implements methods for checkpointing and resuming training from a checkpoint.
12+
"""
13+
14+
from torchft.checkpointing.http_transport import HTTPTransport
15+
from torchft.checkpointing.transport import CheckpointTransport
16+
17+
__all__ = [
18+
"HTTPTransport",
19+
"CheckpointTransport",
20+
]

torchft/checkpointing.py renamed to torchft/checkpointing/http_transport.py

Lines changed: 2 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,6 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
"""
8-
Checkpointing
9-
==============
10-
11-
This module implements methods for checkpointing and resuming training from a checkpoint.
12-
"""
13-
147
import io
158
import logging
169
import socket
@@ -24,70 +17,14 @@
2417

2518
import torch
2619

20+
from torchft.checkpointing.transport import CheckpointTransport
2721
from torchft.http import _IPv6HTTPServer
2822

2923
logger: logging.Logger = logging.getLogger(__name__)
3024

3125
T = TypeVar("T")
3226

3327

34-
class CheckpointTransport(Generic[T], ABC):
35-
@abstractmethod
36-
def metadata(self) -> str:
37-
"""
38-
Returns a string that will be used by the remote CheckpointTransport to fetch the checkpoint.
39-
"""
40-
...
41-
42-
@abstractmethod
43-
def send_checkpoint(
44-
self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta
45-
) -> None:
46-
"""
47-
Sends the checkpoint, only called when there is a rank that is behind.
48-
49-
This may be async.
50-
51-
Args:
52-
dst_ranks: the ranks to send to
53-
step: the step number to send
54-
state_dict: the state dict to send
55-
timeout: the timeout to wait for the checkpoint to be sent
56-
"""
57-
...
58-
59-
def disallow_checkpoint(self) -> None:
60-
"""
61-
Called after send_checkpoint to wait for the checkpoint to be sent.
62-
63-
Once this returns, the state_dict may be mutated so no further data should be sent.
64-
"""
65-
...
66-
67-
@abstractmethod
68-
def recv_checkpoint(
69-
self, src_rank: int, metadata: str, step: int, timeout: timedelta
70-
) -> T:
71-
"""
72-
Receives the checkpoint from the given rank.
73-
74-
Args:
75-
src_rank: the rank to receive the checkpoint from
76-
metadata: the metadata returned by the remote CheckpointTransport
77-
step: the step number to receive
78-
timeout: the timeout to wait for the checkpoint
79-
"""
80-
...
81-
82-
def shutdown(self, wait: bool = True) -> None:
83-
"""
84-
Called to shutdown the checkpoint transport.
85-
86-
Args:
87-
wait: whether to wait for the transport to shutdown
88-
"""
89-
90-
9128
@contextmanager
9229
def _timed_acquire(
9330
lock: threading.Lock, timeout: timedelta
@@ -107,7 +44,7 @@ def _timed_acquire(
10744
lock.release()
10845

10946

110-
class CheckpointServer(CheckpointTransport[T]):
47+
class HTTPTransport(CheckpointTransport[T]):
11148
"""
11249
This is an HTTP server that can be used to transfer checkpoints
11350
between workers.

torchft/checkpointing_test.py renamed to torchft/checkpointing/http_transport_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,15 @@
1010
from unittest import TestCase
1111
from unittest.mock import MagicMock
1212

13-
from torchft.checkpointing import CheckpointServer, _timed_acquire
13+
from torchft.checkpointing.http_transport import HTTPTransport, _timed_acquire
1414

1515

1616
class TestCheckpointing(TestCase):
1717
def test_checkpoint_server(self) -> None:
1818
expected = {"state": "dict"}
1919
state_dict_fn = MagicMock()
2020
state_dict_fn.return_value = expected
21-
server = CheckpointServer(
21+
server = HTTPTransport(
2222
timeout=timedelta(seconds=10),
2323
)
2424

@@ -58,7 +58,7 @@ def test_checkpoint_server(self) -> None:
5858
server.shutdown()
5959

6060
def test_checkpoint_server_locking(self) -> None:
61-
server = CheckpointServer(
61+
server = HTTPTransport(
6262
timeout=timedelta(seconds=10),
6363
)
6464

torchft/checkpointing/transport.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from abc import ABC, abstractmethod
8+
from datetime import timedelta
9+
from typing import Generic, List, TypeVar
10+
11+
T = TypeVar("T")
12+
13+
14+
class CheckpointTransport(Generic[T], ABC):
15+
@abstractmethod
16+
def metadata(self) -> str:
17+
"""
18+
Returns a string that will be used by the remote CheckpointTransport to fetch the checkpoint.
19+
"""
20+
...
21+
22+
@abstractmethod
23+
def send_checkpoint(
24+
self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta
25+
) -> None:
26+
"""
27+
Sends the checkpoint, only called when there is a rank that is behind.
28+
29+
This may be async.
30+
31+
Args:
32+
dst_ranks: the ranks to send to
33+
step: the step number to send
34+
state_dict: the state dict to send
35+
timeout: the timeout to wait for the checkpoint to be sent
36+
"""
37+
...
38+
39+
def disallow_checkpoint(self) -> None:
40+
"""
41+
Called after send_checkpoint to wait for the checkpoint to be sent.
42+
43+
Once this returns, the state_dict may be mutated so no further data should be sent.
44+
"""
45+
...
46+
47+
@abstractmethod
48+
def recv_checkpoint(
49+
self, src_rank: int, metadata: str, step: int, timeout: timedelta
50+
) -> T:
51+
"""
52+
Receives the checkpoint from the given rank.
53+
54+
Args:
55+
src_rank: the rank to receive the checkpoint from
56+
metadata: the metadata returned by the remote CheckpointTransport
57+
step: the step number to receive
58+
timeout: the timeout to wait for the checkpoint
59+
"""
60+
...
61+
62+
def shutdown(self, wait: bool = True) -> None:
63+
"""
64+
Called to shut down the checkpoint transport.
65+
66+
Args:
67+
wait: whether to wait for the transport to shutdown
68+
"""

torchft/fsdp_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ def _test_fsdp(world_size: int, rank: int) -> None:
6666
# pyre-ignore[56]: Pyre was not able to infer the type of argument
6767
@unittest.skipIf(torch.cuda.device_count() < 4, "Not enough GPUs")
6868
def test_fsdp(self) -> None:
69-
multiprocessing.set_start_method("spawn")
70-
with ProcessPoolExecutor(max_workers=4) as executor:
69+
context = multiprocessing.get_context("spawn")
70+
with ProcessPoolExecutor(max_workers=4, mp_context=context) as executor:
7171
futures = []
7272
for i in range(4):
7373
future = executor.submit(self._test_fsdp, 4, i)

torchft/manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
import torch
3939
from torch.distributed import ReduceOp, TCPStore
4040

41-
from torchft.checkpointing import CheckpointServer, CheckpointTransport
41+
from torchft.checkpointing import CheckpointTransport, HTTPTransport
4242
from torchft.futures import future_timeout
4343
from torchft.torchft import Manager as _Manager, ManagerClient
4444

@@ -141,7 +141,7 @@ def __init__(
141141
replica_id: if rank==0, the replica_id for this group
142142
hostname: if rank==0, the hostname to advertise to the lighthouse server
143143
checkpoint_transport: the checkpoint transport to use for
144-
transfering checkpoints to recovering replicas
144+
transferring checkpoints to recovering replicas, defaults to HTTPTransport
145145
"""
146146
self._load_state_dict = load_state_dict
147147
self._user_state_dict = state_dict
@@ -160,7 +160,7 @@ def __init__(
160160
self._min_replica_size = min_replica_size
161161

162162
if checkpoint_transport is None:
163-
checkpoint_transport = CheckpointServer[Dict[str, T]](
163+
checkpoint_transport = HTTPTransport[Dict[str, T]](
164164
timeout=timeout,
165165
)
166166

0 commit comments

Comments
 (0)