Skip to content

Commit 9477c14

Browse files
jaliyae authored and facebook-github-bot committed
C++ Frontend: adding two distributed samples (Random and Sequential) (pytorch#16910)
Summary: Adding two distributed samplers, Random and Sequential, to the mix. Similar to the Python counterpart, DistributedSampler introduces a new method `set_epoch(size_t epoch)` which can be used to shuffle data deterministically between distributed processes. Pull Request resolved: pytorch#16910 Differential Revision: D14130980 Pulled By: soumith fbshipit-source-id: ec08b7130c01e2fc6dc3693f7ac622a0a6d60f10
1 parent 8852e21 commit 9477c14

File tree

5 files changed

+484
-0
lines changed

5 files changed

+484
-0
lines changed

test/cpp/api/dataloader.cpp

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,188 @@ TEST(DataTest, CanUseCustomTypeAsIndexType) {
831831
}
832832
}
833833

834+
// Verifies that a single-replica DistributedRandomSampler yields every index
// in [0, sample_count) exactly once, regardless of shuffle order.
TEST(DataTest, DistributedRandomSamplerSingleReplicaProduceCorrectSamples) {
  const size_t sample_count = 10;
  samplers::DistributedRandomSampler sampler(sample_count);

  // Drain the sampler in batches of three until it is exhausted.
  std::vector<size_t> drawn;
  for (auto batch = sampler.next(3); batch.has_value();
       batch = sampler.next(3)) {
    drawn.insert(drawn.end(), batch->begin(), batch->end());
  }

  ASSERT_EQ(drawn.size(), sample_count);

  // After sorting, the samples must be exactly 0, 1, ..., sample_count - 1.
  std::sort(drawn.begin(), drawn.end());
  for (size_t expected = 0; expected < drawn.size(); ++expected) {
    ASSERT_EQ(drawn[expected], expected);
  }
}
851+
852+
// Verifies that several DistributedRandomSampler replicas together cover the
// whole dataset, both with and without duplicated indices at the boundary.
TEST(DataTest, DistributedRandomSamplerMultiReplicaProduceCorrectSamples) {
  const size_t sample_count = 10;
  const size_t num_replicas = 3;

  // Drains one sampler per replica and checks the merged, sorted indices
  // against the expected sequence.
  auto check_replicas = [&](bool allow_duplicates,
                            size_t local_sample_count,
                            std::vector<size_t>& expected,
                            size_t batch_size) {
    std::vector<std::unique_ptr<samplers::DistributedRandomSampler>> replicas;
    for (size_t rank = 0; rank < num_replicas; ++rank) {
      replicas.emplace_back(
          torch::make_unique<samplers::DistributedRandomSampler>(
              sample_count, num_replicas, rank, allow_duplicates));
    }

    std::vector<size_t> merged;
    for (size_t rank = 0; rank < num_replicas; ++rank) {
      replicas[rank]->reset();
      torch::optional<std::vector<size_t>> batch;
      while ((batch = replicas[rank]->next(batch_size)).has_value()) {
        merged.insert(merged.end(), batch->begin(), batch->end());
      }
      // Each replica must contribute exactly its local share of the samples.
      ASSERT_EQ(merged.size(), local_sample_count * (rank + 1));
    }
    std::sort(merged.begin(), merged.end());
    ASSERT_EQ(merged, expected);
  };

  for (size_t batch_size = 1; batch_size <= 3; ++batch_size) {
    // allow_duplicates = true: each replica draws ceil(10 / 3) = 4 samples,
    // so two indices appear twice in the merged output.
    size_t local_sample_count =
        static_cast<size_t>(std::ceil(sample_count * 1.0 / num_replicas));
    std::vector<size_t> with_duplicates{0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    check_replicas(true, local_sample_count, with_duplicates, batch_size);

    // allow_duplicates = false: each replica draws floor(10 / 3) = 3 samples,
    // so one index is dropped from the merged output.
    local_sample_count =
        static_cast<size_t>(std::floor(sample_count * 1.0 / num_replicas));
    std::vector<size_t> without_duplicates{0, 1, 2, 3, 4, 5, 6, 7, 8};
    check_replicas(false, local_sample_count, without_duplicates, batch_size);
  }
}
893+
894+
// Checks that DistributedRandomSampler serialization round-trips both the
// current index and the epoch.
TEST(DataTest, CanSaveAndLoadDistributedRandomSampler) {
  {
    // Fresh sampler: index is zero before and after the round trip.
    samplers::DistributedRandomSampler original(10);
    ASSERT_EQ(original.index(), 0);
    std::stringstream stream;
    torch::save(original, stream);

    samplers::DistributedRandomSampler restored(10);
    torch::load(restored, stream);
    ASSERT_EQ(restored.index(), 0);
  }
  {
    // Partially-consumed sampler: the index (3 + 4 = 7) survives the
    // round trip.
    samplers::DistributedRandomSampler original(10);
    original.next(3);
    original.next(4);
    ASSERT_EQ(original.index(), 7);
    std::stringstream stream;
    torch::save(original, stream);

    samplers::DistributedRandomSampler restored(10);
    torch::load(restored, stream);
    ASSERT_EQ(restored.index(), 7);
  }
  {
    // The epoch set via set_epoch() is serialized as well.
    samplers::DistributedRandomSampler original(10);
    original.set_epoch(3);
    std::stringstream stream;
    torch::save(original, stream);

    samplers::DistributedRandomSampler restored(10);
    torch::load(restored, stream);
    ASSERT_EQ(restored.epoch(), 3);
  }
}
928+
929+
// Verifies that a single-replica DistributedSequentialSampler yields every
// index in [0, sample_count) exactly once.
TEST(DataTest, DistributedSequentialSamplerSingleReplicaProduceCorrectSamples) {
  const size_t sample_count = 10;
  const size_t batch_size = 3;
  samplers::DistributedSequentialSampler sampler(sample_count);

  // Drain the sampler batch by batch until it is exhausted.
  std::vector<size_t> drawn;
  for (auto batch = sampler.next(batch_size); batch.has_value();
       batch = sampler.next(batch_size)) {
    drawn.insert(drawn.end(), batch->begin(), batch->end());
  }

  ASSERT_EQ(drawn.size(), sample_count);

  // After sorting, the samples must be exactly 0, 1, ..., sample_count - 1.
  std::sort(drawn.begin(), drawn.end());
  for (size_t expected = 0; expected < drawn.size(); ++expected) {
    ASSERT_EQ(drawn[expected], expected);
  }
}
947+
948+
// Verifies that several DistributedSequentialSampler replicas together cover
// the whole dataset, both with and without duplicated boundary indices.
TEST(DataTest, DistributedSequentialSamplerMultiReplicaProduceCorrectSamples) {
  const size_t sample_count = 10;
  const size_t num_replicas = 3;

  // Drains one sampler per replica and checks the merged, sorted indices
  // against the expected sequence.
  auto check_replicas = [&](bool allow_duplicates,
                            size_t local_sample_count,
                            std::vector<size_t>& expected,
                            size_t batch_size) {
    std::vector<std::unique_ptr<samplers::DistributedSequentialSampler>>
        replicas;
    for (size_t rank = 0; rank < num_replicas; ++rank) {
      replicas.emplace_back(
          torch::make_unique<samplers::DistributedSequentialSampler>(
              sample_count, num_replicas, rank, allow_duplicates));
    }

    std::vector<size_t> merged;
    for (size_t rank = 0; rank < num_replicas; ++rank) {
      replicas[rank]->reset();
      torch::optional<std::vector<size_t>> batch;
      while ((batch = replicas[rank]->next(batch_size)).has_value()) {
        merged.insert(merged.end(), batch->begin(), batch->end());
      }
      // Each replica must contribute exactly its local share of the samples.
      ASSERT_EQ(merged.size(), local_sample_count * (rank + 1));
    }
    std::sort(merged.begin(), merged.end());
    ASSERT_EQ(merged, expected);
  };

  for (size_t batch_size = 1; batch_size <= 3; ++batch_size) {
    // allow_duplicates = true: each replica draws ceil(10 / 3) = 4 samples,
    // so two indices appear twice in the merged output.
    size_t local_sample_count =
        static_cast<size_t>(std::ceil(sample_count * 1.0 / num_replicas));
    std::vector<size_t> with_duplicates{0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9};
    check_replicas(true, local_sample_count, with_duplicates, batch_size);

    // allow_duplicates = false: each replica draws floor(10 / 3) = 3 samples,
    // so one index is dropped from the merged output.
    local_sample_count =
        static_cast<size_t>(std::floor(sample_count * 1.0 / num_replicas));
    std::vector<size_t> without_duplicates{0, 1, 2, 3, 4, 5, 6, 7, 8};
    check_replicas(false, local_sample_count, without_duplicates, batch_size);
  }
}
990+
991+
// Checks that DistributedSequentialSampler serialization round-trips the
// current index.
TEST(DataTest, CanSaveAndLoadDistributedSequentialSampler) {
  {
    // Fresh sampler: index is zero before and after the round trip.
    samplers::DistributedSequentialSampler original(10);
    ASSERT_EQ(original.index(), 0);
    std::stringstream stream;
    torch::save(original, stream);

    samplers::DistributedSequentialSampler restored(10);
    torch::load(restored, stream);
    ASSERT_EQ(restored.index(), 0);
  }
  {
    // Partially-consumed sampler: the index (3 + 4 = 7) survives the
    // round trip.
    samplers::DistributedSequentialSampler original(10);
    original.next(3);
    original.next(4);
    ASSERT_EQ(original.index(), 7);
    std::stringstream stream;
    torch::save(original, stream);

    samplers::DistributedSequentialSampler restored(10);
    torch::load(restored, stream);
    ASSERT_EQ(restored.index(), 7);
  }
}
1015+
8341016
TEST(DataLoaderTest, DataLoaderOptionsDefaultAsExpected) {
8351017
DataLoaderOptions partial_options;
8361018
FullDataLoaderOptions full_options(partial_options);

torch/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ if (NOT NO_API)
225225
list(APPEND TORCH_SRCS
226226
${TORCH_SRC_DIR}/csrc/api/src/cuda.cpp
227227
${TORCH_SRC_DIR}/csrc/api/src/data/datasets/mnist.cpp
228+
${TORCH_SRC_DIR}/csrc/api/src/data/samplers/distributed.cpp
228229
${TORCH_SRC_DIR}/csrc/api/src/data/samplers/random.cpp
229230
${TORCH_SRC_DIR}/csrc/api/src/data/samplers/sequential.cpp
230231
${TORCH_SRC_DIR}/csrc/api/src/data/samplers/stream.cpp

torch/csrc/api/include/torch/data/samplers.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <torch/data/samplers/base.h>
44
#include <torch/data/samplers/custom_batch_request.h>
5+
#include <torch/data/samplers/distributed.h>
56
#include <torch/data/samplers/random.h>
67
#include <torch/data/samplers/sequential.h>
78
#include <torch/data/samplers/serialize.h>
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
#pragma once
2+
3+
#include <torch/csrc/WindowsTorchApiMacro.h>
4+
#include <torch/data/samplers/base.h>
5+
6+
#include <cstddef>
7+
#include <vector>
8+
9+
namespace torch {
10+
namespace serialize {
11+
class OutputArchive;
12+
class InputArchive;
13+
} // namespace serialize
14+
} // namespace torch
15+
16+
namespace torch {
17+
namespace data {
18+
namespace samplers {
19+
20+
/// A `Sampler` that selects a subset of indices to sample from and defines a
21+
/// sampling behavior. In a distributed setting, this selects a subset of the
22+
/// indices depending on the provided num_replicas and rank parameters. The
23+
/// `Sampler` performs a rounding operation based on the `allow_duplicates`
24+
/// parameter to decide the local sample count.
25+
template <typename BatchRequest = std::vector<size_t>>
26+
class DistributedSampler : public Sampler<BatchRequest> {
27+
public:
28+
TORCH_API DistributedSampler(
29+
size_t size,
30+
size_t num_replicas = 1,
31+
size_t rank = 0,
32+
bool allow_duplicates = true)
33+
: size_(size),
34+
num_replicas_(num_replicas),
35+
rank_(rank),
36+
epoch_(0),
37+
allow_duplicates_(allow_duplicates) {}
38+
39+
/// Set the epoch for the current enumeration. This can be used to alter the
40+
/// sample selection and shuffling behavior.
41+
TORCH_API void set_epoch(size_t epoch) {
42+
epoch_ = epoch;
43+
}
44+
45+
TORCH_API size_t epoch() const {
46+
return epoch_;
47+
}
48+
49+
protected:
50+
size_t local_sample_count() {
51+
if (allow_duplicates_) {
52+
return (size_ + num_replicas_ - 1) / num_replicas_;
53+
} else {
54+
return size_ / num_replicas_;
55+
}
56+
}
57+
58+
size_t size_;
59+
size_t num_replicas_;
60+
size_t rank_;
61+
size_t epoch_;
62+
bool allow_duplicates_;
63+
};
64+
65+
/// Select samples randomly. The sampling order is shuffled at each `reset()`
66+
/// call.
67+
class DistributedRandomSampler : public DistributedSampler<> {
68+
public:
69+
TORCH_API DistributedRandomSampler(
70+
size_t size,
71+
size_t num_replicas = 1,
72+
size_t rank = 0,
73+
bool allow_duplicates = true);
74+
75+
/// Resets the `DistributedRandomSampler` to a new set of indices.
76+
TORCH_API void reset(optional<size_t> new_size = nullopt) override;
77+
78+
/// Returns the next batch of indices.
79+
TORCH_API optional<std::vector<size_t>> next(size_t batch_size) override;
80+
81+
/// Serializes the `DistributedRandomSampler` to the `archive`.
82+
TORCH_API void save(serialize::OutputArchive& archive) const override;
83+
84+
/// Deserializes the `DistributedRandomSampler` from the `archive`.
85+
TORCH_API void load(serialize::InputArchive& archive) override;
86+
87+
/// Returns the current index of the `DistributedRandomSampler`.
88+
TORCH_API size_t index() const noexcept;
89+
90+
private:
91+
void populate_indices();
92+
93+
size_t begin_index_;
94+
size_t end_index_;
95+
size_t sample_index_;
96+
std::vector<size_t> all_indices_;
97+
};
98+
99+
/// Select samples sequentially.
100+
class DistributedSequentialSampler : public DistributedSampler<> {
101+
public:
102+
TORCH_API DistributedSequentialSampler(
103+
size_t size,
104+
size_t num_replicas = 1,
105+
size_t rank = 0,
106+
bool allow_duplicates = true);
107+
108+
/// Resets the `DistributedSequentialSampler` to a new set of indices.
109+
TORCH_API void reset(optional<size_t> new_size = nullopt) override;
110+
111+
/// Returns the next batch of indices.
112+
TORCH_API optional<std::vector<size_t>> next(size_t batch_size) override;
113+
114+
/// Serializes the `DistributedSequentialSampler` to the `archive`.
115+
TORCH_API void save(serialize::OutputArchive& archive) const override;
116+
117+
/// Deserializes the `DistributedSequentialSampler` from the `archive`.
118+
TORCH_API void load(serialize::InputArchive& archive) override;
119+
120+
/// Returns the current index of the `DistributedSequentialSampler`.
121+
TORCH_API size_t index() const noexcept;
122+
123+
private:
124+
void populate_indices();
125+
126+
size_t begin_index_;
127+
size_t end_index_;
128+
size_t sample_index_;
129+
std::vector<size_t> all_indices_;
130+
};
131+
132+
} // namespace samplers
133+
} // namespace data
134+
} // namespace torch

0 commit comments

Comments
 (0)