Skip to content

Commit d5e77fb

Browse files
authored
Port interface of store base class from Caffe2 (pytorch#7439)
The file store implementation is new and based on the file initialization method (which uses a single file and file locking) and the interface of the Caffe2 store handler. See pytorch#7434.
1 parent 6547245 commit d5e77fb

File tree

9 files changed

+536
-0
lines changed

9 files changed

+536
-0
lines changed

torch/lib/c10d/CMakeLists.txt

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Require CMake 3.2 or newer; configuration aborts on older versions.
cmake_minimum_required(VERSION 3.2 FATAL_ERROR)
2+
# Prepend the cmake/ directory three levels above this one to the
# module search path.
set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../../../cmake ${CMAKE_MODULE_PATH})
3+
4+
# Build the `store` library from the base class and the file-backed
# implementation.
add_library(store Store.cpp FileStore.cpp)
5+
# PUBLIC: the flag is applied to `store` itself and to every target that
# links against it.
target_compile_options(store PUBLIC "-std=c++11")
6+
7+
# Turn on CTest support and pull in the targets defined under test/.
enable_testing()
8+
add_subdirectory(test)

torch/lib/c10d/FileStore.cpp

+280
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
#include "FileStore.hpp"
2+
3+
#include <assert.h>
4+
#include <stdint.h>
5+
#include <sys/file.h>
6+
#include <fcntl.h>
7+
#include <unistd.h>
8+
#include <sys/stat.h>
9+
10+
#include <chrono>
11+
#include <functional>
12+
#include <iostream>
13+
#include <limits>
14+
#include <sstream>
15+
#include <system_error>
16+
#include <thread>
17+
18+
// Throws std::system_error built from the current `errno` when `rv` is
// negative (the usual failure convention of POSIX calls). Any extra
// arguments are forwarded to the std::system_error constructor as the
// "what" message.
//
// The body is wrapped in do { ... } while (0) so the macro expands to a
// single statement and is safe inside unbraced if/else.
#define SYSASSERT(rv, ...)                                 \
  do {                                                     \
    if ((rv) < 0) {                                        \
      throw std::system_error(                             \
          errno, std::system_category(), ##__VA_ARGS__);   \
    }                                                      \
  } while (0)
25+
26+
namespace c10d {
27+
28+
namespace {
29+
30+
// Invokes `fn` until it either succeeds or fails with something other
// than EINTR, and returns that final result. A result of -1 with
// errno == EINTR means the call was interrupted and is retried.
template <typename F>
typename std::result_of<F()>::type syscall(F fn) {
  for (;;) {
    const auto result = fn();
    const bool interrupted = (result == -1) && (errno == EINTR);
    if (!interrupted) {
      return result;
    }
  }
}
42+
43+
// For a comprehensive overview of file locking methods,
44+
// see: https://gavv.github.io/blog/file-locks/.
45+
// We stick to flock(2) here because we don't care about
46+
// locking byte ranges and don't want locks to be process-wide.
47+
48+
// RAII wrapper around flock(2)
49+
class Lock {
50+
public:
51+
explicit Lock(int fd, int operation) : fd_(fd) {
52+
flock(operation);
53+
}
54+
55+
~Lock() {
56+
unlock();
57+
}
58+
59+
Lock(const Lock& that) = delete;
60+
61+
Lock(Lock&& other) noexcept {
62+
fd_ = other.fd_;
63+
other.fd_ = -1;
64+
}
65+
66+
void unlock() {
67+
if (fd_ >= 0) {
68+
flock(LOCK_UN);
69+
fd_ = -1;
70+
}
71+
}
72+
73+
protected:
74+
int fd_;
75+
76+
void flock(int operation) {
77+
auto rv = syscall(std::bind(::flock, fd_, operation));
78+
SYSASSERT(rv, "flock");
79+
}
80+
};
81+
82+
class File {
83+
public:
84+
explicit File(const std::string& path, int flags) {
85+
fd_ = syscall(std::bind(::open, path.c_str(), flags, 0644));
86+
SYSASSERT(fd_, "open(" + path + ")");
87+
}
88+
89+
~File() {
90+
::close(fd_);
91+
}
92+
93+
Lock lockShared() {
94+
return Lock(fd_, LOCK_SH);
95+
}
96+
97+
Lock lockExclusive() {
98+
return Lock(fd_, LOCK_EX);
99+
}
100+
101+
off_t seek(off_t offset, int whence) {
102+
auto rv = syscall(std::bind(lseek, fd_, offset, whence));
103+
SYSASSERT(rv, "lseek");
104+
return rv;
105+
}
106+
107+
off_t tell() {
108+
auto rv = syscall(std::bind(lseek, fd_, 0, SEEK_CUR));
109+
SYSASSERT(rv, "lseek");
110+
return rv;
111+
}
112+
113+
off_t size() {
114+
auto pos = tell();
115+
auto size = seek(0, SEEK_END);
116+
seek(pos, SEEK_SET);
117+
return size;
118+
}
119+
120+
void write(const void* buf, size_t count) {
121+
while (count > 0) {
122+
auto rv = syscall(std::bind(::write, fd_, buf, count));
123+
SYSASSERT(rv, "write");
124+
buf = (uint8_t*) buf + count;
125+
count -= rv;
126+
}
127+
}
128+
129+
void read(void* buf, size_t count) {
130+
while (count > 0) {
131+
auto rv = syscall(std::bind(::read, fd_, buf, count));
132+
SYSASSERT(rv, "read");
133+
buf = (uint8_t*) buf + count;
134+
count -= rv;
135+
}
136+
}
137+
138+
void write(const std::string& str) {
139+
uint32_t len = str.size();
140+
assert(str.size() <= std::numeric_limits<decltype(len)>::max());
141+
write(&len, sizeof(len));
142+
write(str.c_str(), len);
143+
}
144+
145+
void write(const std::vector<uint8_t>& data) {
146+
uint32_t len = data.size();
147+
assert(data.size() <= std::numeric_limits<decltype(len)>::max());
148+
write(&len, sizeof(len));
149+
write(data.data(), len);
150+
}
151+
152+
void read(std::string& str) {
153+
uint32_t len;
154+
read(&len, sizeof(len));
155+
std::vector<uint8_t> buf(len);
156+
read(buf.data(), len);
157+
str.assign(buf.begin(), buf.end());
158+
}
159+
160+
void read(std::vector<uint8_t>& data) {
161+
uint32_t len;
162+
read(&len, sizeof(len));
163+
data.resize(len);
164+
read(data.data(), len);
165+
}
166+
167+
protected:
168+
int fd_;
169+
};
170+
171+
// Replays every key/value record stored between offset `pos` and the
// current end of `file` into `cache` (later records overwrite earlier
// ones for the same key) and returns the offset up to which the file has
// been consumed. If the file size equals `pos`, nothing is read and the
// file offset is left untouched.
off_t refresh(
    File& file,
    off_t pos,
    std::unordered_map<std::string, std::vector<uint8_t>>& cache) {
  const auto end = file.size();
  if (end == pos) {
    return pos;
  }

  file.seek(pos, SEEK_SET);
  for (; pos < end; pos = file.tell()) {
    std::string key;
    std::vector<uint8_t> value;
    file.read(key);
    file.read(value);
    cache[key] = std::move(value);
  }
  return pos;
}
189+
190+
} // namespace
191+
192+
// Remembers the path of the backing file. The file itself is not opened
// here; reading starts from offset 0 on first use.
FileStore::FileStore(const std::string& path) : path_(path), pos_(0) {}
197+
198+
// Nothing to release: files are opened and closed per operation.
FileStore::~FileStore() = default;
200+
201+
void FileStore::set(const std::string& key, const std::vector<uint8_t>& value) {
202+
File file(path_, O_RDWR | O_CREAT);
203+
auto lock = file.lockExclusive();
204+
file.seek(0, SEEK_END);
205+
file.write(key);
206+
file.write(value);
207+
}
208+
209+
std::vector<uint8_t> FileStore::get(const std::string& key) {
210+
while (cache_.count(key) == 0) {
211+
File file(path_, O_RDONLY);
212+
auto lock = file.lockShared();
213+
auto size = file.size();
214+
if (size == pos_) {
215+
// No new entries; release the shared lock and sleep for a bit
216+
lock.unlock();
217+
std::this_thread::sleep_for(std::chrono::milliseconds(10));
218+
continue;
219+
}
220+
221+
pos_ = refresh(file, pos_, cache_);
222+
}
223+
224+
return cache_[key];
225+
}
226+
227+
// Adds `i` to the integer stored under `key` and returns the new total.
// A missing or empty value counts as 0. The value is stored as its
// decimal string representation. The read-modify-write runs under an
// exclusive file lock.
//
// Throws std::system_error on I/O failure; std::stoll may throw if the
// stored value is not a decimal integer.
int64_t FileStore::add(const std::string& key, int64_t i) {
  File file(path_, O_RDWR | O_CREAT);
  auto lock = file.lockExclusive();
  pos_ = refresh(file, pos_, cache_);

  // Look the key up without operator[]: that would insert an empty
  // placeholder for an unknown key, which readers of the cache would
  // then mistake for a real (empty) value.
  int64_t ti = i;
  const auto it = cache_.find(key);
  if (it != cache_.end() && !it->second.empty()) {
    const auto& value = it->second;
    auto buf = reinterpret_cast<const char*>(value.data());
    auto len = value.size();
    ti += std::stoll(std::string(buf, len));
  }

  // refresh() only moves the cursor when there were new records to read;
  // otherwise the freshly opened file is still at offset 0. Seek
  // explicitly so the record is appended instead of overwriting the
  // beginning of the file.
  file.seek(0, SEEK_END);
  file.write(key);
  file.write(std::to_string(ti));

  return ti;
}
247+
248+
// Reports whether every key in `keys` is present in the store. The cache
// is brought up to date from the backing file (under a shared lock)
// before the lookup. An empty `keys` list yields true.
bool FileStore::check(const std::vector<std::string>& keys) {
  File file(path_, O_RDONLY);
  auto lock = file.lockShared();
  pos_ = refresh(file, pos_, cache_);

  for (auto it = keys.begin(); it != keys.end(); ++it) {
    const bool known = cache_.find(*it) != cache_.end();
    if (!known) {
      return false;
    }
  }
  return true;
}
261+
262+
void FileStore::wait(
263+
const std::vector<std::string>& keys,
264+
const std::chrono::milliseconds& timeout) {
265+
// Not using inotify because it doesn't work on many
266+
// shared filesystems (such as NFS).
267+
const auto start = std::chrono::steady_clock::now();
268+
while (!check(keys)) {
269+
const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
270+
std::chrono::steady_clock::now() - start);
271+
if (timeout != kNoTimeout && elapsed > timeout) {
272+
throw std::runtime_error("Wait timeout");
273+
}
274+
275+
/* sleep override */
276+
std::this_thread::sleep_for(std::chrono::milliseconds(10));
277+
}
278+
}
279+
280+
} // namespace c10d

torch/lib/c10d/FileStore.hpp

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#pragma once
2+
3+
#include <sys/types.h>
4+
5+
#include <unordered_map>
6+
7+
#include "Store.hpp"
8+
9+
namespace c10d {
10+
11+
class FileStore : public Store {
12+
public:
13+
explicit FileStore(const std::string& path);
14+
15+
virtual ~FileStore();
16+
17+
void set(
18+
const std::string& key,
19+
const std::vector<uint8_t>& value) override;
20+
21+
std::vector<uint8_t> get(const std::string& key) override;
22+
23+
int64_t add(const std::string& key, int64_t value) override;
24+
25+
bool check(const std::vector<std::string>& keys) override;
26+
27+
void wait(
28+
const std::vector<std::string>& keys,
29+
const std::chrono::milliseconds& timeout = kDefaultTimeout) override;
30+
31+
protected:
32+
std::string path_;
33+
off_t pos_;
34+
35+
std::unordered_map<std::string, std::vector<uint8_t>> cache_;
36+
};
37+
38+
} // namespace c10d

torch/lib/c10d/README.md

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# THD refactor
2+
3+
This is a work in progress. It is separate from the main THD directory
4+
to avoid disrupting THD users or having to deal with backwards compat
5+
early on. Once this gets to a usable state, we'll add Python bindings
6+
and a compat layer.
7+
8+
See https://github.com/pytorch/pytorch/issues/7434 for the main issue.
9+
10+
This tree is intentionally not part of the main build and will be
11+
buildable/testable in isolation, as long as ATen is available in
12+
`<repository root>/torch/lib/tmp_install`.
13+
14+
To build and install ATen here, navigate to the root of this
15+
repository and run:
16+
17+
``` shell
18+
tools/build_pytorch_libs.sh --with-cuda ATen
19+
```

torch/lib/c10d/Store.cpp

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#include "Store.hpp"
2+
3+
namespace c10d {
4+
5+
constexpr std::chrono::milliseconds Store::kDefaultTimeout;
6+
constexpr std::chrono::milliseconds Store::kNoTimeout;
7+
8+
// Define destructor symbol for abstract base class.
9+
Store::~Store() {
10+
}
11+
12+
} // namespace c10d

torch/lib/c10d/Store.hpp

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#pragma once
2+
3+
#include <chrono>
4+
#include <cstdint>
5+
#include <stdexcept>
6+
#include <string>
7+
#include <vector>
8+
9+
namespace c10d {
10+
11+
// Abstract key/value store interface. Keys are strings; values are byte
// vectors.
class Store {
 public:
  // Timeout used by wait() when the caller does not pass one: 30 seconds.
  static constexpr std::chrono::milliseconds kDefaultTimeout =
      std::chrono::seconds(30);

  // Sentinel timeout value (zero milliseconds).
  static constexpr std::chrono::milliseconds kNoTimeout =
      std::chrono::milliseconds::zero();

  virtual ~Store();

  // Stores `value` under `key`.
  virtual void set(
      const std::string& key,
      const std::vector<uint8_t>& value) = 0;

  // Returns the value stored under `key`.
  virtual std::vector<uint8_t> get(const std::string& key) = 0;

  // Integer add on `key` by `value`; returns an int64_t result.
  virtual int64_t add(const std::string& key, int64_t value) = 0;

  // Checks `keys` against the store and returns a boolean result.
  virtual bool check(const std::vector<std::string>& keys) = 0;

  // Waits on `keys`, bounded by `timeout`.
  virtual void wait(
      const std::vector<std::string>& keys,
      const std::chrono::milliseconds& timeout = kDefaultTimeout) = 0;
};
34+
35+
} // namespace c10d

0 commit comments

Comments
 (0)