Skip to content

Commit 5a0ee90

Browse files
committed
Simple tests.
1 parent 7b62f8f commit 5a0ee90

File tree

4 files changed

+47
-4
lines changed

4 files changed

+47
-4
lines changed

src/common/linalg_op.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
*
66
* Client code can use utilities like @ref ElementWiseKernel by including this file in the
77
* right translation unit. For CUDA-compatible kernels, include this header in a .cu TU.
8+
*
9+
* Be aware of potential violation of the one definition rule (ODR). The dispatching
10+
* functions should never be used in an inline function without a system tag.
811
*/
912
#ifndef XGBOOST_COMMON_LINALG_OP_H_
1013
#define XGBOOST_COMMON_LINALG_OP_H_

tests/cpp/common/test_linalg.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <vector> // for vector
1212

1313
#include "../../../src/common/linalg_op.h"
14+
#include "test_linalg.h" // for TestLinalgDispatch
1415

1516
namespace xgboost::linalg {
1617
namespace {
@@ -410,4 +411,9 @@ TEST(Linalg, IO) {
410411
check(loaded);
411412
}
412413
}
414+
415+
TEST(Linalg, CpuDispatch) {
416+
Context ctx;
417+
TestLinalgDispatch(&ctx, [](auto v) { return v + 1; });
418+
}
413419
} // namespace xgboost::linalg

tests/cpp/common/test_linalg.cu

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "../../../src/common/linalg_op.h"
1111
#include "../../../src/common/optional_weight.h" // for MakeOptionalWeights
1212
#include "../helpers.h"
13+
#include "test_linalg.h" // for TestLinalgDispatch
1314
#include "thrust/random.h" // for default_random_engine
1415
#include "thrust/shuffle.h" // for shuffle
1516
#include "xgboost/context.h"
@@ -28,8 +29,7 @@ void TestElementWiseKernel() {
2829
// GPU view
2930
auto t = l.View(device).Slice(linalg::All(), 1, linalg::All());
3031
ASSERT_FALSE(t.CContiguous());
31-
cuda_impl::TransformIdxKernel(&ctx, t,
32-
[] XGBOOST_DEVICE(std::size_t i, float) { return i; });
32+
cuda_impl::TransformIdxKernel(&ctx, t, [] XGBOOST_DEVICE(std::size_t i, float) { return i; });
3333
// CPU view
3434
t = l.View(DeviceOrd::CPU()).Slice(linalg::All(), 1, linalg::All());
3535
std::size_t k = 0;
@@ -56,8 +56,7 @@ void TestElementWiseKernel() {
5656
* Contiguous
5757
*/
5858
auto t = l.View(device);
59-
cuda_impl::TransformIdxKernel(&ctx, t,
60-
[] XGBOOST_DEVICE(size_t i, float) { return i; });
59+
cuda_impl::TransformIdxKernel(&ctx, t, [] XGBOOST_DEVICE(size_t i, float) { return i; });
6160
ASSERT_TRUE(t.CContiguous());
6261
// CPU view
6362
t = l.View(DeviceOrd::CPU());
@@ -151,4 +150,11 @@ TEST(Linalg, SmallHistogram) {
151150
ASSERT_EQ(h_bins[i], cnt);
152151
}
153152
}
153+
namespace {
154+
void TestGpuDispatch() {
155+
auto ctx = MakeCUDACtx(0);
156+
TestLinalgDispatch(&ctx, [] XGBOOST_DEVICE(double v) { return v + 1; });
157+
}
158+
} // namespace
159+
TEST(Linalg, GpuDispatch) { TestGpuDispatch(); }
154160
} // namespace xgboost::linalg

tests/cpp/common/test_linalg.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/**
2+
* Copyright 2025, XGBoost Contributors
3+
*/
4+
#pragma once
5+
6+
#include <gtest/gtest.h>
7+
#include <xgboost/context.h>
8+
#include <xgboost/linalg.h> // for Vector
9+
10+
#include <numeric> // for iota
11+
#include <vector> // for vector
12+
13+
#include "../../../src/common/linalg_op.h"
14+
15+
namespace xgboost::linalg {
16+
template <typename Fn>
17+
void TestLinalgDispatch(Context const* ctx, Fn&& fn) {
18+
std::vector<double> data(128, 0);
19+
std::iota(data.begin(), data.end(), 0.0);
20+
Vector<double> vec(data.begin(), data.end(), {data.size()}, DeviceOrd::CPU());
21+
22+
TransformKernel(ctx, vec.View(ctx->Device()), [=] XGBOOST_DEVICE(double v) { return fn(v); });
23+
auto h_v = vec.HostView();
24+
for (std::size_t i = 0; i < h_v.Size(); ++i) {
25+
ASSERT_EQ(h_v(i), fn(i));
26+
}
27+
}
28+
} // namespace xgboost::linalg

0 commit comments

Comments
 (0)