diff --git a/doc/algorithm.qbk b/doc/algorithm.qbk index 02a156236..be2cd8d65 100644 --- a/doc/algorithm.qbk +++ b/doc/algorithm.qbk @@ -233,6 +233,8 @@ Convert a sequence of hexadecimal characters into a sequence of integers or char Convert a sequence of integral types into a lower case hexadecimal sequence of characters [endsect:hex_lower] +[include indirect_sort.qbk] + [include is_palindrome.qbk] [include is_partitioned_until.qbk] diff --git a/doc/indirect_sort.qbk b/doc/indirect_sort.qbk new file mode 100644 index 000000000..6ac6de569 --- /dev/null +++ b/doc/indirect_sort.qbk @@ -0,0 +1,111 @@ +[/ File indirect_sort.qbk] + +[section:indirect_sort indirect_sort ] + +[/license +Copyright (c) 2023 Marshall Clow + +Distributed under the Boost Software License, Version 1.0. +(See accompanying file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +] + +There are times that you want a sorted version of a sequence, but for some reason you don't want to modify it. Maybe the elements in the sequence can't be moved/copied, e.g. the sequence is const, or they're just really expensive to move around. An example of this might be a sequence of records from a database. + +That's where indirect sorting comes in. In a "normal" sort, the elements of the sequence to be sorted are shuffled in place. In indirect sorting, the elements are unchanged, but the sort algorithm returns a "permutation" of the elements that, when applied, will put the elements in the sequence in a sorted order. + +Assume have a sequence `[first, last)` of 1000 items that are expensive to swap: +``` + std::sort(first, last); // ['O(N ln N)] comparisons and ['O(N ln N)] swaps (of the element type). +``` + +On the other hand, using indirect sorting: +``` + auto perm = indirect_sort(first, last); // ['O(N lg N)] comparisons and ['O(N lg N)] swaps (of size_t). + apply_permutation(first, last, perm.begin(), perm.end()); // ['O(N)] swaps (of the element type) +``` + +If the element type is sufficiently expensive to swap, then 10,000 swaps of size_t + 1000 swaps of the element_type could be cheaper than 10,000 swaps of the element_type. + +Or maybe you don't need the elements to actually be sorted - you just want to traverse them in a sorted order: +``` + auto permutation = indirect_sort(first, last); + for (size_t idx: permutation) + std::cout << first[idx] << std::endl; +``` + + +Assume that instead of an "array of structures", you have a "struct of arrays". +``` +struct AType { + Type0 key; + Type1 value1; + Type1 value2; + }; + +std::array arrayOfStruct; +``` + +versus: + +``` +template +struct AType { + std::array key; + std::array value1; + std::array value2; + }; + +AType<1000> structOfArrays; +``` + +Sorting the first one is easy, because each set of fields (`key`, `value1`, `value2`) are part of the same struct. But with indirect sorting, the second one is easy to sort as well - just sort the keys, then apply the permutation to the keys and the values: +``` + auto perm = indirect_sort(std::begin(structOfArrays.key), std::end(structOfArrays.key)); + apply_permutation(structOfArrays.key.begin(), structOfArrays.key.end(), perm.begin(), perm.end()); + apply_permutation(structOfArrays.value1.begin(), structOfArrays.value1.end(), perm.begin(), perm.end()); + apply_permutation(structOfArrays.value2.begin(), structOfArrays.value2.end(), perm.begin(), perm.end()); +``` + +[heading interface] + +The function `indirect_sort` returns a `vector` containing the permutation necessary to put the input sequence into a sorted order. One version uses `std::less` to do the comparisons; the other lets the caller pass predicate to do the comparisons. + +There is also a variant called `indirect_stable_sort`; it bears the same relation to `indirect_sort` that `std::stable_sort` does to `std::sort`. + +``` +template +std::vector indirect_sort (RAIterator first, RAIterator last); + +template +std::vector indirect_sort (RAIterator first, RAIterator last, BinaryPredicate pred); + +template +std::vector indirect_stable_sort (RAIterator first, RAIterator last); + +template +std::vector indirect_stable_sort (RAIterator first, RAIterator last, BinaryPredicate pred); +``` + +[heading Examples] + +[heading Iterator Requirements] + +`indirect_sort` requires random-access iterators. + +[heading Complexity] + +Both of the variants of `indirect_sort` run in ['O(N lg N)] time; they are not more (or less) efficient than `std::sort`. There is an extra layer of indirection on each comparison, but all of the swaps are done on values of type `size_t` + +[heading Exception Safety] + +[heading Notes] + +In numpy, this algorithm is known as `argsort`. + +[endsect] + +[/ File indirect_sort.qbk +Copyright 2023 Marshall Clow +Distributed under the Boost Software License, Version 1.0. +(See accompanying file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt). +] diff --git a/include/boost/algorithm/indirect_sort.hpp b/include/boost/algorithm/indirect_sort.hpp new file mode 100644 index 000000000..72c2f2e12 --- /dev/null +++ b/include/boost/algorithm/indirect_sort.hpp @@ -0,0 +1,207 @@ +/* + Copyright (c) Marshall Clow 2023. + + Distributed under the Boost Software License, Version 1.0. (See accompanying + file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) + +*/ + +/// \file indirect_sort.hpp +/// \brief indirect sorting algorithms +/// \author Marshall Clow +/// + +#ifndef BOOST_ALGORITHM_INDIRECT_SORT +#define BOOST_ALGORITHM_INDIRECT_SORT + +#include // for std::sort (and others) +#include // for std::less +#include // for std::vector + +#include + +namespace boost { namespace algorithm { + +typedef std::vector Permutation; + +namespace detail { + + template + struct indirect_predicate { + indirect_predicate (Predicate pred, Iter iter) + : pred_(pred), iter_(iter) {} + + bool operator ()(size_t a, size_t b) const { + return pred_(iter_[a], iter_[b]); + } + + Predicate pred_; + Iter iter_; + }; + + // Initialize a permutation of size 'size'. [ 0, 1, 2, ... size-1 ] + // Note: it would be nice to use 'iota' here, but that call writes over + // existing elements - not append them. I don't want to initialize + // the elements of the permutation to zero, and then immediately + // overwrite them. + void init_permutation (Permutation &p, size_t size) { + p.reserve(size); + boost::algorithm::iota_n( + std::back_insert_iterator(p), size_t(0), size); + } +} + + // ===== sort ===== + +/// \fn indirect_sort (RAIterator first, RAIterator last, Predicate pred) +/// \returns a permutation of the elements in the range [first, last) +/// such that when the permutation is applied to the sequence, +/// the result is ordered as if 'std::sort(first, last, pred)' +// was called on the sequence. +/// +/// \param first The start of the input sequence +/// \param last The end of the input sequence +/// \param pred The predicate to compare elements with +/// +template +Permutation indirect_sort (RAIterator first, RAIterator last, Pred pred) { + + Permutation ret; + detail::init_permutation(ret, std::distance(first, last)); + std::sort(ret.begin(), ret.end(), + detail::indirect_predicate(pred, first)); + return ret; +} + +/// \fn indirect_sort (RAIterator first, RAIterator last) +/// \returns a permutation of the elements in the range [first, last) +/// such that when the permutation is applied to the sequence, +/// the result is ordered as if 'std::sort(first, last)' +// was called on the sequence. +/// +/// \param first The start of the input sequence +/// \param last The end of the input sequence +/// +template +Permutation indirect_sort (RAIterator first, RAIterator last) { + return indirect_sort(first, last, + std::less::value_type>()); +} + + // ===== stable_sort ===== + +/// \fn indirect_stable_sort (RAIterator first, RAIterator last, Predicate pred) +/// \returns a permutation of the elements in the range [first, last) +/// such that when the permutation is applied to the sequence, +/// the result is ordered as if 'std::stable_sort(first, last, pred)' +// was called on the sequence. +/// +/// \param first The start of the input sequence +/// \param last The end of the input sequence +/// \param pred The predicate to compare elements with +/// +template +Permutation indirect_stable_sort (RAIterator first, RAIterator last, Pred pred) { + Permutation ret; + detail::init_permutation(ret, std::distance(first, last)); + std::stable_sort(ret.begin(), ret.end(), + detail::indirect_predicate(pred, first)); + return ret; +} + +/// \fn indirect_stable_sort (RAIterator first, RAIterator last) +/// \returns a permutation of the elements in the range [first, last) +/// such that when the permutation is applied to the sequence, +/// the result is ordered as if 'std::stable_sort(first, last)' +// was called on the sequence. +/// +/// \param first The start of the input sequence +/// \param last The end of the input sequence +/// +template +Permutation indirect_stable_sort (RAIterator first, RAIterator last) { + return indirect_stable_sort(first, last, + std::less::value_type>()); +} + + // ===== partial_sort ===== + +/// \fn indirect_partial_sort (RAIterator first, RAIterator middle, RAIterator last, Predicate pred) +/// \returns a permutation of the elements in the range [first, last) +/// such that when the permutation is applied to the sequence, +/// the result is ordered as if 'std::partial_sort(first, middle, last, pred)' +// was called on the sequence. +/// +/// \param first The start of the input sequence +/// \param middle The end of the range to be sorted +/// \param last The end of the input sequence +/// \param pred The predicate to compare elements with +/// +template +Permutation indirect_partial_sort (RAIterator first, RAIterator middle, + RAIterator last, Pred pred) { + Permutation ret; + detail::init_permutation(ret, std::distance(first, last)); + std::partial_sort(ret.begin(), ret.begin() + std::distance(first, middle), ret.end(), + detail::indirect_predicate(pred, first)); + return ret; +} + +/// \fn indirect_partial_sort (RAIterator first, RAIterator middle, RAIterator last) +/// \returns a permutation of the elements in the range [first, last) +/// such that when the permutation is applied to the sequence, +/// the result is ordered as if 'std::partial_sort(first, middle, last)' +// was called on the sequence. +/// +/// \param first The start of the input sequence +/// \param middle The end of the range to be sorted +/// \param last The end of the input sequence +/// +template +Permutation indirect_partial_sort (RAIterator first, RAIterator middle, RAIterator last) { + return indirect_partial_sort(first, middle, last, + std::less::value_type>()); +} + + // ===== nth_element ===== + +/// \fn indirect_nth_element (RAIterator first, RAIterator nth, RAIterator last, Predicate p) +/// \returns a permutation of the elements in the range [first, last) +/// such that when the permutation is applied to the sequence, +/// the result is ordered as if 'std::nth_element(first, nth, last, p)' +// was called on the sequence. +/// +/// \param first The start of the input sequence +/// \param nth The sort partition point in the input sequence +/// \param last The end of the input sequence +/// \param pred The predicate to compare elements with +/// +template +Permutation indirect_nth_element (RAIterator first, RAIterator nth, + RAIterator last, Pred pred) { + Permutation ret; + detail::init_permutation(ret, std::distance(first, last)); + std::nth_element(ret.begin(), ret.begin() + std::distance(first, nth), ret.end(), + detail::indirect_predicate(pred, first)); + return ret; +} + +/// \fn indirect_nth_element (RAIterator first, RAIterator nth, RAIterator last) +/// \returns a permutation of the elements in the range [first, last) +/// such that when the permutation is applied to the sequence, +/// the result is ordered as if 'std::nth_element(first, nth, last)' +// was called on the sequence. +/// +/// \param first The start of the input sequence +/// \param nth The sort partition point in the input sequence +/// \param last The end of the input sequence +/// +template +Permutation indirect_nth_element (RAIterator first, RAIterator nth, RAIterator last) { + return indirect_nth_element(first, nth, last, + std::less::value_type>()); +} + +}} + +#endif // BOOST_ALGORITHM_INDIRECT_SORT diff --git a/test/Jamfile.v2 b/test/Jamfile.v2 index aef6bdb38..3390234f2 100644 --- a/test/Jamfile.v2 +++ b/test/Jamfile.v2 @@ -88,6 +88,10 @@ alias unit_test_framework # Apply_permutation tests [ run apply_permutation_test.cpp unit_test_framework : : : : apply_permutation_test ] + +# Indirect_sort tests + [ run indirect_sort_test.cpp unit_test_framework : : : : indirect_sort_test ] + # Find tests [ run find_not_test.cpp unit_test_framework : : : : find_not_test ] [ run find_backward_test.cpp unit_test_framework : : : : find_backward_test ] diff --git a/test/indirect_sort_test.cpp b/test/indirect_sort_test.cpp new file mode 100644 index 000000000..fa1b7e24f --- /dev/null +++ b/test/indirect_sort_test.cpp @@ -0,0 +1,348 @@ +/* + Copyright (c) Marshall Clow 2023. + + Distributed under the Boost Software License, Version 1.0. (See accompanying + file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) + + For more information, see http://www.boost.org +*/ + +#include +#include +#include +#include +#include + +#define BOOST_TEST_MAIN +#include + +#include +#include +#include +#include + +using boost::algorithm::Permutation; + +// A permutation of size N is a sequence of values in the range [0..N) +// such that no value appears more than once in the permutation. +bool is_a_permutation(Permutation p, size_t N) { + if (p.size() != N) return false; + +// Sort the permutation, and ensure that each value appears exactly once. + std::sort(p.begin(), p.end()); + for (size_t i = 0; i < N; ++i) + if (p[i] != i) return false; + return true; +} + +template ::value_type> > +struct indirect_comp { + indirect_comp (Iter it, Comp c = Comp()) + : iter_(it), comp_(c) {} + + bool operator ()(size_t a, size_t b) const { return comp_(iter_[a], iter_[b]);} + + Iter iter_; + Comp comp_; +}; + + //// ======================= + //// ==== indirect_sort ==== + //// ======================= +template +void test_one_sort(Iter first, Iter last) { + Permutation perm = boost::algorithm::indirect_sort(first, last); + BOOST_CHECK (is_a_permutation(perm, std::distance(first, last))); + BOOST_CHECK (boost::algorithm::is_sorted(perm.begin(), perm.end(), indirect_comp(first))); + +// Make a copy of the data, apply the permutation, and ensure that it is sorted. + typedef std::vector::value_type> Vector; + Vector v(first, last); + boost::algorithm::apply_permutation(v.begin(), v.end(), perm.begin(), perm.end()); + BOOST_CHECK (boost::algorithm::is_sorted(v.begin(), v.end())); +} + +template +void test_one_sort(Iter first, Iter last, Comp comp) { + Permutation perm = boost::algorithm::indirect_sort(first, last, comp); + BOOST_CHECK (is_a_permutation(perm, std::distance(first, last))); + BOOST_CHECK (boost::algorithm::is_sorted(perm.begin(), perm.end(), + indirect_comp(first, comp))); + +// Make a copy of the data, apply the permutation, and ensure that it is sorted. + typedef std::vector::value_type> Vector; + Vector v(first, last); + boost::algorithm::apply_permutation(v.begin(), v.end(), perm.begin(), perm.end()); + BOOST_CHECK (boost::algorithm::is_sorted(v.begin(), v.end(), comp)); +} + + +BOOST_AUTO_TEST_CASE(test_sort) { + int num[] = { 1,3,5,7,9, 2, 4, 6, 8, 10 }; + const int sz = sizeof (num)/sizeof(num[0]); + int *first = &num[0]; + int const *cFirst = &num[0]; + +// Test subsets + for (size_t i = 0; i <= sz; ++i) { + test_one_sort(first, first + i); + test_one_sort(first, first + i, std::greater()); + + // test with constant inputs + test_one_sort(cFirst, cFirst + i); + test_one_sort(cFirst, cFirst + i, std::greater()); + } + +// make sure we work with iterators as well as pointers + std::vector v(first, first + sz); + test_one_sort(v.begin(), v.end()); + test_one_sort(v.begin(), v.end(), std::greater()); + } + + + //// ============================== + //// ==== indirect_stable_sort ==== + //// ============================== + +template +struct MyPair { + MyPair () {} + + MyPair (const T1 &t1, const T2 &t2) + : first(t1), second(t2) {} + + T1 first; + T2 second; +}; + +template +bool operator < (const MyPair& lhs, const MyPair& rhs) { + return lhs.first < rhs.first; // compare only the first elements +} + +template +bool MyGreater (const MyPair& lhs, const MyPair& rhs) { + return lhs.first > rhs.first; // compare only the first elements +} + +template +void test_one_stable_sort(Iter first, Iter last) { + Permutation perm = boost::algorithm::indirect_stable_sort(first, last); + BOOST_CHECK (is_a_permutation(perm, std::distance(first, last))); + BOOST_CHECK (boost::algorithm::is_sorted(perm.begin(), perm.end(), indirect_comp(first))); + + if (first != last) { + Iter iFirst = first; + Iter iSecond = first; ++iSecond; + + while (iSecond != last) { + if (iFirst->first == iSecond->first) + BOOST_CHECK(iFirst->second < iSecond->second); + ++iFirst; + ++iSecond; + } + } + +// Make a copy of the data, apply the permutation, and ensure that it is sorted. + typedef std::vector::value_type> Vector; + Vector v(first, last); + boost::algorithm::apply_permutation(v.begin(), v.end(), perm.begin(), perm.end()); + BOOST_CHECK (boost::algorithm::is_sorted(v.begin(), v.end())); +} + +template +void test_one_stable_sort(Iter first, Iter last, Comp comp) { + Permutation perm = boost::algorithm::indirect_stable_sort(first, last, comp); + BOOST_CHECK (is_a_permutation(perm, std::distance(first, last))); + BOOST_CHECK (boost::algorithm::is_sorted(perm.begin(), perm.end(), indirect_comp(first, comp))); + + if (first != last) { + Iter iFirst = first; + Iter iSecond = first; ++iSecond; + + while (iSecond != last) { + if (iFirst->first == iSecond->first) + BOOST_CHECK(iFirst->second < iSecond->second); + ++iFirst; + ++iSecond; + } + } + +// Make a copy of the data, apply the permutation, and ensure that it is sorted. + typedef std::vector::value_type> Vector; + Vector v(first, last); + boost::algorithm::apply_permutation(v.begin(), v.end(), perm.begin(), perm.end()); + BOOST_CHECK (boost::algorithm::is_sorted(v.begin(), v.end(), comp)); +} + +BOOST_AUTO_TEST_CASE(test_stable_sort) { + typedef MyPair Pair; + const int sz = 10; + Pair vals[sz]; + + for (int i = 0; i < sz; ++i) { + vals[i].first = 100 - (i >> 1); + vals[i].second = i; + } + + Pair *first = &vals[0]; + Pair const *cFirst = &vals[0]; + +// Test subsets + for (size_t i = 0; i <= sz; ++i) { + test_one_stable_sort(first, first + i); + test_one_stable_sort(first, first + i, MyGreater); + + // test with constant inputs + test_one_sort(cFirst, cFirst + i); + test_one_sort(cFirst, cFirst + i, MyGreater); + } +} + + //// =============================== + //// ==== indirect_partial_sort ==== + //// =============================== + +template +void test_one_partial_sort(Iter first, Iter middle, Iter last) { + const size_t middleIdx = std::distance(first, middle); + Permutation perm = boost::algorithm::indirect_partial_sort(first, middle, last); + BOOST_CHECK (is_a_permutation(perm, std::distance(first, last))); + BOOST_CHECK (boost::algorithm::is_sorted(perm.begin(), perm.begin() + middleIdx, indirect_comp(first))); + +// Make a copy of the data, apply the permutation, and ensure that it is sorted. + typedef std::vector::value_type> Vector; + Vector v(first, last); + boost::algorithm::apply_permutation(v.begin(), v.end(), perm.begin(), perm.end()); + BOOST_CHECK (boost::algorithm::is_sorted(v.begin(), v.begin() + middleIdx)); + +// Make sure that [middle, end) are all "greater" than the sorted part + if (middleIdx > 0) { + typename Vector::iterator lastSorted = v.begin() + middleIdx - 1; + for (typename Vector::iterator it = v.begin () + middleIdx; it != v.end(); ++it) + BOOST_CHECK(*lastSorted < *it); + } +} + +template +void test_one_partial_sort(Iter first, Iter middle, Iter last, Comp comp) { + const size_t middleIdx = std::distance(first, middle); + Permutation perm = boost::algorithm::indirect_partial_sort(first, middle, last, comp); + BOOST_CHECK (is_a_permutation(perm, std::distance(first, last))); + BOOST_CHECK (boost::algorithm::is_sorted(perm.begin(), perm.begin() + middleIdx, + indirect_comp(first, comp))); + +// Make a copy of the data, apply the permutation, and ensure that it is sorted. + typedef std::vector::value_type> Vector; + Vector v(first, last); + boost::algorithm::apply_permutation(v.begin(), v.end(), perm.begin(), perm.end()); + BOOST_CHECK (boost::algorithm::is_sorted(v.begin(), v.begin() + middleIdx, comp)); + +// Make sure that [middle, end) are all "greater" than the sorted part + if (middleIdx > 0) { + typename Vector::iterator lastSorted = v.begin() + middleIdx - 1; + for (typename Vector::iterator it = v.begin () + middleIdx; it != v.end(); ++it) + BOOST_CHECK(comp(*lastSorted, *it)); + } +} + + +BOOST_AUTO_TEST_CASE(test_partial_sort) { + int num[] = { 1,3,5,7,9, 2, 4, 6, 8, 10 }; + const int sz = sizeof (num)/sizeof(num[0]); + int *first = &num[0]; + int const *cFirst = &num[0]; + +// Test subsets + for (size_t i = 0; i <= sz; ++i) { + for (size_t j = 0; j < i; ++j) { + test_one_partial_sort(first, first + j, first + i); + test_one_partial_sort(first, first + j, first + i, std::greater()); + + // test with constant inputs + test_one_partial_sort(cFirst, cFirst + j, cFirst + i); + test_one_partial_sort(cFirst, cFirst + j, cFirst + i, std::greater()); + } + } + +// make sure we work with iterators as well as pointers + std::vector v(first, first + sz); + test_one_partial_sort(v.begin(), v.begin() + (sz / 2), v.end()); + test_one_partial_sort(v.begin(), v.begin() + (sz / 2), v.end(), std::greater()); + } + + + //// =================================== + //// ==== indirect_nth_element_sort ==== + //// =================================== + +template +void test_one_nth_element(Iter first, Iter nth, Iter last) { + const size_t nthIdx = std::distance(first, nth); + Permutation perm = boost::algorithm::indirect_nth_element(first, nth, last); + BOOST_CHECK (is_a_permutation(perm, std::distance(first, last))); + + for (size_t i = 0; i < nthIdx; ++i) + BOOST_CHECK(!(first[perm[nthIdx]] < first[perm[i]])); // all items before the nth element are <= the nth element + for (size_t i = nthIdx; i < std::distance(first, last); ++i) + BOOST_CHECK(!(first[perm[i]] < first[perm[nthIdx]])); // all items before the nth element are >= the nth element + +// Make a copy of the data, apply the permutation, and ensure that the result is correct. + typedef std::vector::value_type> Vector; + Vector v(first, last); + boost::algorithm::apply_permutation(v.begin(), v.end(), perm.begin(), perm.end()); + + for (size_t i = 0; i < nthIdx; ++i) + BOOST_CHECK(!(v[nthIdx] < v[i])); // all items before the nth element are <= the nth element + for (size_t i = nthIdx; i < v.size(); ++i) + BOOST_CHECK(!(v[i] < v[nthIdx])); // all items before the nth element are >= the nth element +} + +template +void test_one_nth_element(Iter first, Iter nth, Iter last, Comp comp) { + const size_t nthIdx = std::distance(first, nth); + + Permutation perm = boost::algorithm::indirect_nth_element(first, nth, last, comp); + BOOST_CHECK (is_a_permutation(perm, std::distance(first, last))); + for (size_t i = 0; i < nthIdx; ++i) + BOOST_CHECK(!comp(first[perm[nthIdx]], first[perm[i]])); // all items before the nth element are <= the nth element + for (size_t i = nthIdx; i < std::distance(first, last); ++i) + BOOST_CHECK(!comp(first[perm[i]], first[perm[nthIdx]])); // all items before the nth element are >= the nth element + + +// Make a copy of the data, apply the permutation, and ensure that the result is correct. + typedef std::vector::value_type> Vector; + Vector v(first, last); + boost::algorithm::apply_permutation(v.begin(), v.end(), perm.begin(), perm.end()); + + for (size_t i = 0; i < nthIdx; ++i) + BOOST_CHECK(!comp(v[nthIdx], v[i])); // all items before the nth element are <= the nth element + for (size_t i = nthIdx; i < v.size(); ++i) + BOOST_CHECK(!comp(v[i], v[nthIdx])); // all items before the nth element are >= the nth element +} + + +BOOST_AUTO_TEST_CASE(test_nth_element) { + int num[] = { 1, 3, 5, 7, 9, 2, 4, 6, 8, 10, 1, 2, 3, 4, 5 }; + const int sz = sizeof (num)/sizeof(num[0]); + int *first = &num[0]; + int const *cFirst = &num[0]; + +// Test subsets + for (size_t i = 0; i <= sz; ++i) { + for (size_t j = 0; j < i; ++j) { + test_one_nth_element(first, first + j, first + i); + test_one_nth_element(first, first + j, first + i, std::greater()); + + // test with constant inputs + test_one_nth_element(cFirst, cFirst + j, cFirst + i); + test_one_nth_element(cFirst, cFirst + j, cFirst + i, std::greater()); + } + } + +// make sure we work with iterators as well as pointers + std::vector v(first, first + sz); + test_one_nth_element(v.begin(), v.begin() + (sz / 2), v.end()); + test_one_nth_element(v.begin(), v.begin() + (sz / 2), v.end(), std::greater()); + }