Skip to content

Commit

Permalink
Added lower_bound function that works in device code
Browse files Browse the repository at this point in the history
Co-authored-by: Andrea Bocci <[email protected]>
  • Loading branch information
ariostas and fwyzard committed Jan 17, 2025
1 parent 1416747 commit baa91b3
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
22 changes: 19 additions & 3 deletions HeterogeneousCore/AlpakaInterface/interface/alpakastdAlgorithm.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,28 @@

#include <alpaka/alpaka.hpp>

// reimplementation of std algorithms able to compile with Alpaka,
// mostly by declaring them constexpr (until C++20, which will make it
// constexpr by default. TODO: drop when moving to C++20)
// reimplementation of std algorithms able to work on device code

namespace alpaka_std {

template <typename RandomIt, typename T, typename Compare = std::less<T>>
ALPAKA_FN_HOST_ACC constexpr RandomIt lower_bound(RandomIt first, RandomIt last, const T &value, Compare comp = {}) {
auto count = last - first;

while (count > 0) {
auto it = first;
auto step = count / 2;
it += step;
if (comp(*it, value)) {
first = ++it;
count -= step + 1;
} else {
count = step;
}
}
return first;
}

template <typename RandomIt, typename T, typename Compare = std::less<T>>
ALPAKA_FN_HOST_ACC constexpr RandomIt upper_bound(RandomIt first, RandomIt last, const T &value, Compare comp = {}) {
auto count = last - first;
Expand Down
5 changes: 3 additions & 2 deletions RecoTracker/LSTCore/src/alpaka/Hit.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define RecoTracker_LSTCore_src_alpaka_Hit_h

#include "HeterogeneousCore/AlpakaInterface/interface/workdivision.h"
#include "HeterogeneousCore/AlpakaInterface/interface/alpakastdAlgorithm.h"

#include "RecoTracker/LSTCore/interface/alpaka/Common.h"
#include "RecoTracker/LSTCore/interface/ModulesSoA.h"
Expand Down Expand Up @@ -103,7 +104,7 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst {
((ihit_z > 0) - (ihit_z < 0)) *
alpaka::math::acosh(
acc, alpaka::math::sqrt(acc, ihit_x * ihit_x + ihit_y * ihit_y + ihit_z * ihit_z) / hits.rts()[ihit]);
auto found_pointer = std::lower_bound(modules.mapdetId(), modules.mapdetId() + nModules, iDetId);
auto found_pointer = alpaka_std::lower_bound(modules.mapdetId(), modules.mapdetId() + nModules, iDetId);
int found_index = std::distance(modules.mapdetId(), found_pointer);
if (found_pointer == modules.mapdetId() + nModules)
found_index = -1;
Expand All @@ -112,7 +113,7 @@ namespace ALPAKA_ACCELERATOR_NAMESPACE::lst {
hits.moduleIndices()[ihit] = lastModuleIndex;

if (modules.subdets()[lastModuleIndex] == Endcap && modules.moduleType()[lastModuleIndex] == TwoS) {
found_pointer = std::lower_bound(geoMapDetId, geoMapDetId + nEndCapMap, iDetId);
found_pointer = alpaka_std::lower_bound(geoMapDetId, geoMapDetId + nEndCapMap, iDetId);
found_index = std::distance(geoMapDetId, found_pointer);
if (found_pointer == geoMapDetId + nEndCapMap)
found_index = -1;
Expand Down

0 comments on commit baa91b3

Please sign in to comment.