Skip to content

Commit 1fe40ca

Browse files
committed
Add test of batched clustering
1 parent 15a0a96 commit 1fe40ca

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

tests/test_batched_clustering.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
2+
#include "CLUEstering/CLUEstering.hpp"
3+
#include "CLUEstering/utils/detail/get_cluster_properties.hpp"
4+
5+
#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN
6+
#include "doctest.h"
7+
8+
TEST_CASE("Test batched clustering with fixed batch size") {
9+
const auto device = clue::get_device(0u);
10+
clue::Queue queue(device);
11+
12+
clue::PointsHost<2> h_points = clue::read_csv<2>(queue, "../../../data/batched_data_1024.csv");
13+
const auto n_points = h_points.size();
14+
clue::PointsDevice<2> d_points(queue, n_points);
15+
16+
const float dc{1.3f}, rhoc{10.f}, outlier{1.3f};
17+
clue::Clusterer<2> algo(queue, dc, rhoc, outlier);
18+
const std::size_t batch_size = 1024;
19+
20+
std::vector<std::size_t> event_sizes(10, batch_size);
21+
algo.make_clusters(queue, h_points, d_points, event_sizes, clue::FlatKernel{.5f});
22+
23+
auto truth = clue::read_output<2>(queue, "../../../data/truth_files/data_1024_truth.csv");
24+
auto truth_n_clusters = clue::detail::compute_nclusters(truth.clusterIndexes());
25+
auto n_clusters = clue::detail::compute_nclusters(h_points.clusterIndexes());
26+
CHECK(n_clusters == truth_n_clusters * 10);
27+
}

0 commit comments

Comments
 (0)