@@ -825,176 +825,6 @@ XGBOOST_DEVICE auto tcrend(xgboost::common::Span<T> const &span) { // NOLINT
825
825
return tcrbegin (span) + span.size ();
826
826
}
827
827
828
- // This type sorts an array which is divided into multiple groups. The sorting is influenced
829
- // by the function object 'Comparator'
830
- template <typename T>
831
- class SegmentSorter {
832
- private:
833
- // Items sorted within the group
834
- caching_device_vector<T> ditems_;
835
-
836
- // Original position of the items before they are sorted descending within their groups
837
- caching_device_vector<uint32_t > doriginal_pos_;
838
-
839
- // Segments within the original list that delineates the different groups
840
- caching_device_vector<uint32_t > group_segments_;
841
-
842
- // Need this on the device as it is used in the kernels
843
- caching_device_vector<uint32_t > dgroups_; // Group information on device
844
-
845
- // Where did the item that was originally present at position 'x' move to after they are sorted
846
- caching_device_vector<uint32_t > dindexable_sorted_pos_;
847
-
848
- // Initialize everything but the segments
849
- void Init (uint32_t num_elems) {
850
- ditems_.resize (num_elems);
851
-
852
- doriginal_pos_.resize (num_elems);
853
- thrust::sequence (doriginal_pos_.begin (), doriginal_pos_.end ());
854
- }
855
-
856
- // Initialize all with group info
857
- void Init (const std::vector<uint32_t > &groups) {
858
- uint32_t num_elems = groups.back ();
859
- this ->Init (num_elems);
860
- this ->CreateGroupSegments (groups);
861
- }
862
-
863
- public:
864
- // This needs to be public due to device lambda
865
- void CreateGroupSegments (const std::vector<uint32_t > &groups) {
866
- uint32_t num_elems = groups.back ();
867
- group_segments_.resize (num_elems, 0 );
868
-
869
- dgroups_ = groups;
870
-
871
- if (GetNumGroups () == 1 ) return ; // There are no segments; hence, no need to compute them
872
-
873
- // Define the segments by assigning a group ID to each element
874
- const uint32_t *dgroups = dgroups_.data ().get ();
875
- uint32_t ngroups = dgroups_.size ();
876
- auto ComputeGroupIDLambda = [=] __device__ (uint32_t idx) {
877
- return thrust::upper_bound (thrust::seq, dgroups, dgroups + ngroups, idx) -
878
- dgroups - 1 ;
879
- }; // NOLINT
880
-
881
- thrust::transform (thrust::make_counting_iterator (static_cast <uint32_t >(0 )),
882
- thrust::make_counting_iterator (num_elems),
883
- group_segments_.begin (),
884
- ComputeGroupIDLambda);
885
- }
886
-
887
- // Accessors that returns device pointer
888
- inline uint32_t GetNumItems () const { return ditems_.size (); }
889
- inline const xgboost::common::Span<const T> GetItemsSpan () const {
890
- return { ditems_.data ().get (), ditems_.size () };
891
- }
892
-
893
- inline const xgboost::common::Span<const uint32_t > GetOriginalPositionsSpan () const {
894
- return { doriginal_pos_.data ().get (), doriginal_pos_.size () };
895
- }
896
-
897
- inline const xgboost::common::Span<const uint32_t > GetGroupSegmentsSpan () const {
898
- return { group_segments_.data ().get (), group_segments_.size () };
899
- }
900
-
901
- inline uint32_t GetNumGroups () const { return dgroups_.size () - 1 ; }
902
- inline const xgboost::common::Span<const uint32_t > GetGroupsSpan () const {
903
- return { dgroups_.data ().get (), dgroups_.size () };
904
- }
905
-
906
- inline const xgboost::common::Span<const uint32_t > GetIndexableSortedPositionsSpan () const {
907
- return { dindexable_sorted_pos_.data ().get (), dindexable_sorted_pos_.size () };
908
- }
909
-
910
- // Sort an array that is divided into multiple groups. The array is sorted within each group.
911
- // This version provides the group information that is on the host.
912
- // The array is sorted based on an adaptable binary predicate. By default a stateless predicate
913
- // is used.
914
- template <typename Comparator = thrust::greater<T>>
915
- void SortItems (const T *ditems, uint32_t item_size, const std::vector<uint32_t > &groups,
916
- const Comparator &comp = Comparator()) {
917
- this ->Init (groups);
918
- this ->SortItems (ditems, item_size, this ->GetGroupSegmentsSpan (), comp);
919
- }
920
-
921
- // Sort an array that is divided into multiple groups. The array is sorted within each group.
922
- // This version provides the group information that is on the device.
923
- // The array is sorted based on an adaptable binary predicate. By default a stateless predicate
924
- // is used.
925
- template <typename Comparator = thrust::greater<T>>
926
- void SortItems (const T *ditems, uint32_t item_size,
927
- const xgboost::common::Span<const uint32_t > &group_segments,
928
- const Comparator &comp = Comparator()) {
929
- this ->Init (item_size);
930
-
931
- // Sort the items that are grouped. We would like to avoid using predicates to perform the sort,
932
- // as thrust resorts to using a merge sort as opposed to a much much faster radix sort
933
- // when comparators are used. Hence, the following algorithm is used. This is done so that
934
- // we can grab the appropriate related values from the original list later, after the
935
- // items are sorted.
936
- //
937
- // Here is the internal representation:
938
- // dgroups_: [ 0, 3, 5, 8, 10 ]
939
- // group_segments_: 0 0 0 | 1 1 | 2 2 2 | 3 3
940
- // doriginal_pos_: 0 1 2 | 3 4 | 5 6 7 | 8 9
941
- // ditems_: 1 0 1 | 2 1 | 1 3 3 | 4 4 (from original items)
942
- //
943
- // Sort the items first and make a note of the original positions in doriginal_pos_
944
- // based on the sort
945
- // ditems_: 4 4 3 3 2 1 1 1 1 0
946
- // doriginal_pos_: 8 9 6 7 3 0 2 4 5 1
947
- // NOTE: This consumes space, but is much faster than some of the other approaches - sorting
948
- // in kernel, sorting using predicates etc.
949
-
950
- ditems_.assign (thrust::device_ptr<const T>(ditems),
951
- thrust::device_ptr<const T>(ditems) + item_size);
952
-
953
- // Allocator to be used by sort for managing space overhead while sorting
954
- dh::XGBCachingDeviceAllocator<char > alloc;
955
-
956
- thrust::stable_sort_by_key (thrust::cuda::par (alloc),
957
- ditems_.begin (), ditems_.end (),
958
- doriginal_pos_.begin (), comp);
959
-
960
- if (GetNumGroups () == 1 ) return ; // The entire array is sorted, as it isn't segmented
961
-
962
- // Next, gather the segments based on the doriginal_pos_. This is to reflect the
963
- // holisitic item sort order on the segments
964
- // group_segments_c_: 3 3 2 2 1 0 0 1 2 0
965
- // doriginal_pos_: 8 9 6 7 3 0 2 4 5 1 (stays the same)
966
- caching_device_vector<uint32_t > group_segments_c (item_size);
967
- thrust::gather (doriginal_pos_.begin (), doriginal_pos_.end (),
968
- dh::tcbegin (group_segments), group_segments_c.begin ());
969
-
970
- // Now, sort the group segments so that you may bring the items within the group together,
971
- // in the process also noting the relative changes to the doriginal_pos_ while that happens
972
- // group_segments_c_: 0 0 0 1 1 2 2 2 3 3
973
- // doriginal_pos_: 0 2 1 3 4 6 7 5 8 9
974
- thrust::stable_sort_by_key (thrust::cuda::par (alloc),
975
- group_segments_c.begin (), group_segments_c.end (),
976
- doriginal_pos_.begin (), thrust::less<uint32_t >());
977
-
978
- // Finally, gather the original items based on doriginal_pos_ to sort the input and
979
- // to store them in ditems_
980
- // doriginal_pos_: 0 2 1 3 4 6 7 5 8 9 (stays the same)
981
- // ditems_: 1 1 0 2 1 3 3 1 4 4 (from unsorted items - ditems)
982
- thrust::gather (doriginal_pos_.begin (), doriginal_pos_.end (),
983
- thrust::device_ptr<const T>(ditems), ditems_.begin ());
984
- }
985
-
986
- // Determine where an item that was originally present at position 'x' has been relocated to
987
- // after a sort. Creation of such an index has to be explicitly requested after a sort
988
- void CreateIndexableSortedPositions () {
989
- dindexable_sorted_pos_.resize (GetNumItems ());
990
- thrust::scatter (thrust::make_counting_iterator (static_cast <uint32_t >(0 )),
991
- thrust::make_counting_iterator (GetNumItems ()), // Rearrange indices...
992
- // ...based on this map
993
- dh::tcbegin (GetOriginalPositionsSpan ()),
994
- dindexable_sorted_pos_.begin ()); // Write results into this
995
- }
996
- };
997
-
998
828
// Atomic add function for gradients
999
829
template <typename OutputGradientT, typename InputGradientT>
1000
830
XGBOOST_DEV_INLINE void AtomicAddGpair (OutputGradientT* dest,
0 commit comments