Skip to content

Commit f44aaa5

Browse files
authored
checkpointing/HTTPTransport: added streaming serialization and parallel transfer support (#106)
1 parent e55542a commit f44aaa5

File tree

7 files changed

+469
-85
lines changed

7 files changed

+469
-85
lines changed

torchft/checkpointing/_rwlock.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# -*- coding: utf-8 -*-
2+
""" rwlock.py
3+
4+
Adapted from: https://github.com/tylerneylon/rwlock/blob/main/rwlock.py
5+
6+
A class to implement read-write locks on top of the standard threading
7+
library.
8+
9+
This is implemented with two mutexes (threading.Lock instances) as per this
10+
wikipedia pseudocode:
11+
12+
https://en.wikipedia.org/wiki/Readers%E2%80%93writer_lock#Using_two_mutexes
13+
14+
__________________________
15+
License info (MIT):
16+
17+
*******
18+
19+
Copyright 2023 Tyler Neylon and contributors
20+
21+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
22+
documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
23+
rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
24+
persons to whom the Software is furnished to do so, subject to the following conditions:
25+
26+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
27+
28+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
29+
WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
30+
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
31+
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
32+
33+
*******
34+
"""
35+
36+
37+
from contextlib import contextmanager
38+
from threading import Lock
39+
from typing import Generator
40+
41+
42+
class RWLock(object):
43+
"""RWLock class; this is meant to allow an object to be read from by
44+
multiple threads, but only written to by a single thread at a time. See:
45+
https://en.wikipedia.org/wiki/Readers%E2%80%93writer_lock
46+
47+
All operations are timed and will throw TimeoutError if the timeout is
48+
exceeded.
49+
50+
Usage:
51+
52+
from rwlock import RWLock
53+
54+
my_obj_rwlock = RWLock(timeout=60.0)
55+
56+
# When reading from my_obj:
57+
with my_obj_rwlock.r_lock():
58+
do_read_only_things_with(my_obj)
59+
60+
# When writing to my_obj:
61+
with my_obj_rwlock.w_lock():
62+
mutate(my_obj)
63+
"""
64+
65+
def __init__(self, timeout: float = -1) -> None:
66+
self.timeout = timeout
67+
68+
self._w_lock = Lock()
69+
self._num_r_lock = Lock()
70+
self._num_r = 0
71+
72+
# ___________________________________________________________________
73+
# Reading methods.
74+
75+
def r_acquire(self) -> None:
76+
if not self._num_r_lock.acquire(timeout=self.timeout):
77+
raise TimeoutError(
78+
f"Timed out waiting for rlock after {self.timeout} seconds"
79+
)
80+
81+
self._num_r += 1
82+
if self._num_r == 1:
83+
if not self._w_lock.acquire(timeout=self.timeout):
84+
self._num_r -= 1
85+
self._num_r_lock.release()
86+
raise TimeoutError(
87+
f"Timed out waiting for wlock after {self.timeout} seconds"
88+
)
89+
90+
self._num_r_lock.release()
91+
92+
def r_release(self) -> None:
93+
assert self._num_r > 0
94+
self._num_r_lock.acquire()
95+
self._num_r -= 1
96+
if self._num_r == 0:
97+
self._w_lock.release()
98+
self._num_r_lock.release()
99+
100+
@contextmanager
101+
def r_lock(self) -> Generator[None, None, None]:
102+
"""This method is designed to be used via the `with` statement."""
103+
self.r_acquire()
104+
try:
105+
yield
106+
finally:
107+
self.r_release()
108+
109+
# ___________________________________________________________________
110+
# Writing methods.
111+
112+
def w_acquire(self) -> None:
113+
if not self._w_lock.acquire(timeout=self.timeout):
114+
raise TimeoutError(
115+
f"Timed out waiting for wlock after {self.timeout} seconds"
116+
)
117+
118+
def w_release(self) -> None:
119+
self._w_lock.release()
120+
121+
@contextmanager
122+
def w_lock(self) -> Generator[None, None, None]:
123+
"""This method is designed to be used via the `with` statement."""
124+
self.w_acquire()
125+
try:
126+
yield
127+
finally:
128+
self.w_release()
129+
130+
def w_locked(self) -> bool:
131+
"""Returns True if the lock is currently locked for reading."""
132+
return self._w_lock.locked()
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import io
2+
import warnings
3+
from typing import IO
4+
5+
import torch
6+
7+
8+
def _fallback_save(obj: object, f: IO[bytes]) -> None:
9+
warnings.warn(
10+
"using slow fallback torch.save implementation, please upgrade to PT 2.7+ for fast streaming saves"
11+
)
12+
13+
torch.save(obj, f)
14+
15+
16+
def _fallback_load(f: IO[bytes], weights_only: bool = True) -> object:
17+
warnings.warn(
18+
"using slow fallback torch.load implementation, please upgrade to PT 2.7+ for fast streaming loads"
19+
)
20+
21+
# torch.load requires a seekable file object
22+
buf = f.read()
23+
reader = io.BytesIO(buf)
24+
25+
return torch.load(reader, weights_only=weights_only)
26+
27+
28+
try:
29+
# pyre-fixme[21]: upgrade to PT 2.7 once released
30+
from torch.distributed._serialization import _streaming_load, _streaming_save
31+
except ImportError:
32+
_streaming_load = _fallback_load
33+
_streaming_save = _fallback_save

0 commit comments

Comments
 (0)