@@ -831,6 +831,188 @@ TEST(DataTest, CanUseCustomTypeAsIndexType) {
831
831
}
832
832
}
833
833
834
+ TEST (DataTest, DistributedRandomSamplerSingleReplicaProduceCorrectSamples) {
835
+ size_t sample_count = 10 ;
836
+ samplers::DistributedRandomSampler drs (sample_count);
837
+
838
+ std::vector<size_t > res;
839
+ torch::optional<std::vector<size_t >> idx;
840
+ while ((idx = drs.next (3 )).has_value ()) {
841
+ res.insert (std::end (res), std::begin (*idx), std::end (*idx));
842
+ }
843
+
844
+ ASSERT_EQ (res.size (), sample_count);
845
+
846
+ std::sort (res.begin (), res.end ());
847
+ for (size_t i = 0 ; i < res.size (); ++i) {
848
+ ASSERT_EQ (res[i], i);
849
+ }
850
+ }
851
+
852
+ TEST (DataTest, DistributedRandomSamplerMultiReplicaProduceCorrectSamples) {
853
+ size_t sample_count = 10 ;
854
+ size_t num_replicas = 3 ;
855
+
856
+ auto test_function = [&](bool allow_duplicates,
857
+ size_t local_sample_count,
858
+ std::vector<size_t >& output,
859
+ size_t batch_size) {
860
+ std::vector<std::unique_ptr<samplers::DistributedRandomSampler>> samplers;
861
+
862
+ for (size_t i = 0 ; i < num_replicas; ++i) {
863
+ samplers.emplace_back (
864
+ torch::make_unique<samplers::DistributedRandomSampler>(
865
+ sample_count, num_replicas, i, allow_duplicates));
866
+ }
867
+
868
+ std::vector<size_t > res;
869
+ for (size_t i = 0 ; i < num_replicas; ++i) {
870
+ (*samplers[i]).reset ();
871
+ torch::optional<std::vector<size_t >> idx;
872
+ while ((idx = (*samplers[i]).next (batch_size)).has_value ()) {
873
+ res.insert (std::end (res), std::begin (*idx), std::end (*idx));
874
+ }
875
+ ASSERT_EQ (res.size (), local_sample_count * (i + 1 ));
876
+ }
877
+ std::sort (res.begin (), res.end ());
878
+ ASSERT_EQ (res, output);
879
+ };
880
+
881
+ for (size_t batch_size = 1 ; batch_size <= 3 ; ++batch_size) {
882
+ size_t local_sample_count =
883
+ static_cast <size_t >(std::ceil (sample_count * 1.0 / num_replicas));
884
+ std::vector<size_t > output1{0 , 0 , 1 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 };
885
+ test_function (true , local_sample_count, output1, batch_size);
886
+
887
+ local_sample_count =
888
+ static_cast <size_t >(std::floor (sample_count * 1.0 / num_replicas));
889
+ std::vector<size_t > output2{0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 };
890
+ test_function (false , local_sample_count, output2, batch_size);
891
+ }
892
+ }
893
+
894
+ TEST (DataTest, CanSaveAndLoadDistributedRandomSampler) {
895
+ {
896
+ samplers::DistributedRandomSampler a (10 );
897
+ ASSERT_EQ (a.index (), 0 );
898
+ std::stringstream stream;
899
+ torch::save (a, stream);
900
+
901
+ samplers::DistributedRandomSampler b (10 );
902
+ torch::load (b, stream);
903
+ ASSERT_EQ (b.index (), 0 );
904
+ }
905
+ {
906
+ samplers::DistributedRandomSampler a (10 );
907
+ a.next (3 );
908
+ a.next (4 );
909
+ ASSERT_EQ (a.index (), 7 );
910
+ std::stringstream stream;
911
+ torch::save (a, stream);
912
+
913
+ samplers::DistributedRandomSampler b (10 );
914
+ torch::load (b, stream);
915
+ ASSERT_EQ (b.index (), 7 );
916
+ }
917
+ {
918
+ samplers::DistributedRandomSampler a (10 );
919
+ a.set_epoch (3 );
920
+ std::stringstream stream;
921
+ torch::save (a, stream);
922
+
923
+ samplers::DistributedRandomSampler b (10 );
924
+ torch::load (b, stream);
925
+ ASSERT_EQ (b.epoch (), 3 );
926
+ }
927
+ }
928
+
929
+ TEST (DataTest, DistributedSequentialSamplerSingleReplicaProduceCorrectSamples) {
930
+ size_t sample_count = 10 ;
931
+ size_t batch_size = 3 ;
932
+ samplers::DistributedSequentialSampler dss (sample_count);
933
+
934
+ std::vector<size_t > res;
935
+ torch::optional<std::vector<size_t >> idx;
936
+ while ((idx = dss.next (batch_size)).has_value ()) {
937
+ res.insert (std::end (res), std::begin (*idx), std::end (*idx));
938
+ }
939
+
940
+ ASSERT_EQ (res.size (), sample_count);
941
+
942
+ std::sort (res.begin (), res.end ());
943
+ for (size_t i = 0 ; i < res.size (); ++i) {
944
+ ASSERT_EQ (res[i], i);
945
+ }
946
+ }
947
+
948
+ TEST (DataTest, DistributedSequentialSamplerMultiReplicaProduceCorrectSamples) {
949
+ size_t sample_count = 10 ;
950
+ size_t num_replicas = 3 ;
951
+
952
+ auto test_function = [&](bool allow_duplicates,
953
+ size_t local_sample_count,
954
+ std::vector<size_t >& output,
955
+ size_t batch_size) {
956
+ std::vector<std::unique_ptr<samplers::DistributedSequentialSampler>>
957
+ samplers;
958
+
959
+ for (size_t i = 0 ; i < num_replicas; ++i) {
960
+ samplers.emplace_back (
961
+ torch::make_unique<samplers::DistributedSequentialSampler>(
962
+ sample_count, num_replicas, i, allow_duplicates));
963
+ }
964
+
965
+ std::vector<size_t > res;
966
+ for (size_t i = 0 ; i < num_replicas; ++i) {
967
+ (*samplers[i]).reset ();
968
+ torch::optional<std::vector<size_t >> idx;
969
+ while ((idx = (*samplers[i]).next (batch_size)).has_value ()) {
970
+ res.insert (std::end (res), std::begin (*idx), std::end (*idx));
971
+ }
972
+ ASSERT_EQ (res.size (), local_sample_count * (i + 1 ));
973
+ }
974
+ std::sort (res.begin (), res.end ());
975
+ ASSERT_EQ (res, output);
976
+ };
977
+
978
+ for (size_t batch_size = 1 ; batch_size <= 3 ; ++batch_size) {
979
+ size_t local_sample_count =
980
+ static_cast <size_t >(std::ceil (sample_count * 1.0 / num_replicas));
981
+ std::vector<size_t > output1{0 , 0 , 1 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 };
982
+ test_function (true , local_sample_count, output1, batch_size);
983
+
984
+ local_sample_count =
985
+ static_cast <size_t >(std::floor (sample_count * 1.0 / num_replicas));
986
+ std::vector<size_t > output2{0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 };
987
+ test_function (false , local_sample_count, output2, batch_size);
988
+ }
989
+ }
990
+
991
+ TEST (DataTest, CanSaveAndLoadDistributedSequentialSampler) {
992
+ {
993
+ samplers::DistributedSequentialSampler a (10 );
994
+ ASSERT_EQ (a.index (), 0 );
995
+ std::stringstream stream;
996
+ torch::save (a, stream);
997
+
998
+ samplers::DistributedSequentialSampler b (10 );
999
+ torch::load (b, stream);
1000
+ ASSERT_EQ (b.index (), 0 );
1001
+ }
1002
+ {
1003
+ samplers::DistributedSequentialSampler a (10 );
1004
+ a.next (3 );
1005
+ a.next (4 );
1006
+ ASSERT_EQ (a.index (), 7 );
1007
+ std::stringstream stream;
1008
+ torch::save (a, stream);
1009
+
1010
+ samplers::DistributedSequentialSampler b (10 );
1011
+ torch::load (b, stream);
1012
+ ASSERT_EQ (b.index (), 7 );
1013
+ }
1014
+ }
1015
+
834
1016
TEST (DataLoaderTest, DataLoaderOptionsDefaultAsExpected) {
835
1017
DataLoaderOptions partial_options;
836
1018
FullDataLoaderOptions full_options (partial_options);
0 commit comments