Changes from 18 commits
Commits
37 commits
11d4b63
optimise data initialisation
Apr 4, 2025
4589533
Merge branch 'dmlc:master' into dev/cpu/init_data_optimisation
razdoburdin Apr 4, 2025
3464f8c
linting
Apr 7, 2025
e211ab9
changing the capture for inner lambdas
Apr 8, 2025
9221573
fix
Apr 8, 2025
0a793e3
set default
Apr 8, 2025
e249a3b
linting
Apr 8, 2025
1be6f5d
fix test
Apr 8, 2025
396f4b3
Merge branch 'dmlc:master' into dev/cpu/init_data_optimisation
razdoburdin Apr 8, 2025
55a89d7
submodule fix
Apr 8, 2025
8a15c70
fix for i386
Apr 8, 2025
085627f
proteckt thread-unsafe code
Apr 9, 2025
b1e714f
fix for multi-batch
Apr 10, 2025
70fd6bc
fix compilation error
Apr 10, 2025
edef9e7
remove critical section; avoid using of bit filds
Apr 10, 2025
606c537
tildy
Apr 10, 2025
6f885b0
return deleted code
Apr 11, 2025
c0dbd7e
remove unactual code
Apr 11, 2025
1cb3693
address comments
May 5, 2025
560a67a
fix calling ColumnMatrix constructor
May 5, 2025
0ac338e
switch back to bitfield
May 19, 2025
61b3878
linting
May 19, 2025
98ef541
Update src/common/column_matrix.h
razdoburdin Jun 2, 2025
2b090e6
Update src/common/column_matrix.h
razdoburdin Jun 2, 2025
9f5ba75
Merge branch 'master' into dev/cpu/init_data_optimisation
trivialfis Jun 15, 2025
8cdd7db
Cleanup, typos.
trivialfis Jul 1, 2025
15aa65b
rename.
trivialfis Jul 1, 2025
58920a0
typos.
trivialfis Jul 1, 2025
0b7037a
update comment
Jul 21, 2025
820b79a
Update src/common/column_matrix.h
razdoburdin Nov 4, 2025
6d553fb
Update src/common/column_matrix.h
razdoburdin Nov 4, 2025
ea5c4fe
Update src/common/column_matrix.h
razdoburdin Nov 4, 2025
7dc15f2
Update src/common/column_matrix.h
razdoburdin Nov 4, 2025
390efd1
Update src/common/column_matrix.h
razdoburdin Nov 4, 2025
3250cc0
Update src/common/column_matrix.h
razdoburdin Nov 4, 2025
a8df33b
linting
razdoburdin Nov 5, 2025
0fcd8a7
remove whitespace
razdoburdin Nov 5, 2025
8 changes: 5 additions & 3 deletions src/common/column_matrix.cc
@@ -17,7 +17,8 @@
#include "xgboost/span.h" // for Span

namespace xgboost::common {
void ColumnMatrix::InitStorage(GHistIndexMatrix const& gmat, double sparse_threshold) {
void ColumnMatrix::InitStorage(GHistIndexMatrix const& gmat, double sparse_threshold,
int n_threads) {
auto const nfeature = gmat.Features();
const size_t nrow = gmat.Size();
// identify type of each column
@@ -61,10 +62,11 @@ void ColumnMatrix::InitStorage(GHistIndexMatrix const& gmat, double sparse_thres
auto storage_size =
feature_offsets_.back() * static_cast<std::underlying_type_t<BinTypeSize>>(bins_type_size_);

index_ = common::MakeFixedVecWithMalloc(storage_size, std::uint8_t{0});
index_ = common::MakeFixedVecWithMalloc(storage_size, std::uint8_t{0}, n_threads);

if (!all_dense_column) {
row_ind_ = common::MakeFixedVecWithMalloc(feature_offsets_[nfeature], std::size_t{0});
row_ind_ = common::MakeFixedVecWithMalloc(feature_offsets_[nfeature],
std::size_t{0}, n_threads);
}

// store least bin id for each feature
131 changes: 93 additions & 38 deletions src/common/column_matrix.h
@@ -13,6 +13,7 @@
#include <cstdint> // for uint8_t
#include <limits>
#include <memory>
#include <vector>
#include <type_traits> // for enable_if_t, is_same_v, is_signed_v

#include "../data/adapter.h"
@@ -113,18 +114,18 @@ class DenseColumnIter : public Column<BinIdxT> {
private:
using Base = Column<BinIdxT>;
/* flags for missing values in dense columns */
LBitField32 missing_flags_;
Span<uint8_t> missing_flags_;
size_t feature_offset_;

public:
explicit DenseColumnIter(common::Span<const BinIdxT> index, bst_bin_t index_base,
LBitField32 missing_flags, size_t feature_offset)
Span<uint8_t> missing_flags, size_t feature_offset)
: Base{index, index_base}, missing_flags_{missing_flags}, feature_offset_{feature_offset} {}
DenseColumnIter(DenseColumnIter const&) = delete;
DenseColumnIter(DenseColumnIter&&) = default;

[[nodiscard]] bool IsMissing(size_t ridx) const {
return missing_flags_.Check(feature_offset_ + ridx);
return missing_flags_[feature_offset_ + ridx];
}

bst_bin_t operator[](size_t ridx) const {
@@ -148,15 +149,11 @@ class ColumnMatrix {
* @brief A bit set for indicating whether an element in a dense column is missing.
*/
struct MissingIndicator {
using BitFieldT = LBitField32;
using T = typename BitFieldT::value_type;

BitFieldT missing;
RefResourceView<T> storage;
static_assert(std::is_same_v<T, std::uint32_t>);
Span<uint8_t> missing;
RefResourceView<uint8_t> storage;

template <typename U>
[[nodiscard]] std::enable_if_t<!std::is_signed_v<U>, U> static InitValue(bool init) {
[[nodiscard]] std::enable_if_t<!std::is_signed_v<U>, U> static InitValue(uint8_t init) {
return init ? ~U{0} : U{0};
}

Expand All @@ -166,45 +163,40 @@ class ColumnMatrix {
* @param init Initialize the indicator to true or false.
*/
MissingIndicator(std::size_t n_elements, bool init) {
auto m_size = missing.ComputeStorageSize(n_elements);
storage = common::MakeFixedVecWithMalloc(m_size, InitValue<T>(init));
storage = common::MakeFixedVecWithMalloc(n_elements, static_cast<uint8_t>(init));
this->InitView();
}
/** @brief Set the i^th element to be a valid element (instead of missing). */
void SetValid(typename LBitField32::index_type i) { missing.Clear(i); }
void SetValid(size_t i) { missing[i] = 0; }
/** @brief assign the storage to the view. */
void InitView() {
missing = LBitField32{Span{storage.data(), static_cast<size_t>(storage.size())}};
missing = Span{storage.data(), static_cast<size_t>(storage.size())};
}

void GrowTo(std::size_t n_elements, bool init) {
CHECK(storage.Resource()->Type() == ResourceHandler::kMalloc)
<< "[Internal Error]: Cannot grow the vector when external memory is used.";
auto m_size = missing.ComputeStorageSize(n_elements);
CHECK_GE(m_size, storage.size());
if (m_size == storage.size()) {
CHECK_GE(n_elements, storage.size());
if (n_elements == storage.size()) {
return;
}
// grow the storage
auto resource = std::dynamic_pointer_cast<common::MallocResource>(storage.Resource());
CHECK(resource);
resource->Resize(m_size * sizeof(T), InitValue<std::byte>(init));
storage = RefResourceView<T>{resource->DataAs<T>(), m_size, resource};
resource->Resize(n_elements * sizeof(uint8_t), InitValue<std::byte>(init));
storage = RefResourceView<uint8_t>{resource->DataAs<uint8_t>(), n_elements, resource};

this->InitView();
}
};

void InitStorage(GHistIndexMatrix const& gmat, double sparse_threshold);
void InitStorage(GHistIndexMatrix const& gmat, double sparse_threshold, int n_threads);

template <typename ColumnBinT, typename BinT, typename RIdx>
void SetBinSparse(BinT bin_id, RIdx rid, bst_feature_t fid, ColumnBinT* local_index) {
Member: Is this function still used now that we have a new SetBinSparse?

Contributor Author: The original SetBinSparse is also used.
if (type_[fid] == kDenseColumn) {
ColumnBinT* begin = &local_index[feature_offsets_[fid]];
begin[rid] = bin_id - index_base_[fid];
// not thread-safe with bit field.
// FIXME(jiamingy): We can directly assign kMissingId to the index to avoid missing
// flags.
missing_.SetValid(feature_offsets_[fid] + rid);
} else {
ColumnBinT* begin = &local_index[feature_offsets_[fid]];
Expand All @@ -214,15 +206,28 @@ class ColumnMatrix {
}
}

template <typename ColumnBinT, typename BinT, typename RIdx>
void SetBinSparse(BinT bin_id, RIdx rid, bst_feature_t fid, ColumnBinT* local_index, size_t nnz) {
if (type_[fid] == kDenseColumn) {
ColumnBinT* begin = &local_index[feature_offsets_[fid]];
begin[rid] = bin_id - index_base_[fid];
Member: These two lines look exactly the same as the following two lines.

Contributor Author: I moved the first line outside the branches; the second one differs.
missing_.SetValid(feature_offsets_[fid] + rid);
} else {
ColumnBinT* begin = &local_index[feature_offsets_[fid]];
begin[nnz] = bin_id - index_base_[fid];
row_ind_[feature_offsets_[fid] + nnz] = rid;
}
}

public:
// get number of features
[[nodiscard]] bst_feature_t GetNumFeature() const {
return static_cast<bst_feature_t>(type_.size());
}

ColumnMatrix() = default;
ColumnMatrix(GHistIndexMatrix const& gmat, double sparse_threshold) {
this->InitStorage(gmat, sparse_threshold);
ColumnMatrix(GHistIndexMatrix const& gmat, double sparse_threshold, int n_threads = 1) {
Member: In which case is n_threads 1, and what are the other cases?

Contributor Author: Fixed.
this->InitStorage(gmat, sparse_threshold, n_threads);
}

/**
Expand All @@ -232,7 +237,7 @@ class ColumnMatrix {
void InitFromSparse(SparsePage const& page, const GHistIndexMatrix& gmat, double sparse_threshold,
int32_t n_threads) {
auto batch = data::SparsePageAdapterBatch{page.GetView()};
this->InitStorage(gmat, sparse_threshold);
this->InitStorage(gmat, sparse_threshold, n_threads);
// ignore base row id here as we always has one column matrix for each sparse page.
this->PushBatch(n_threads, batch, std::numeric_limits<float>::quiet_NaN(), gmat, 0);
}
@@ -283,7 +288,7 @@
SetIndexNoMissing(base_rowid, gmat.index.data<RowBinIdxT>(), size, n_features, n_threads);
});
} else {
SetIndexMixedColumns(base_rowid, batch, gmat, missing);
SetIndexMixedColumns(base_rowid, batch, gmat, missing, n_threads);
}
}

Expand Down Expand Up @@ -349,7 +354,7 @@ class ColumnMatrix {
*/
template <typename Batch>
void SetIndexMixedColumns(size_t base_rowid, Batch const& batch, const GHistIndexMatrix& gmat,
float missing) {
float missing, int n_threads) {
auto n_features = gmat.Features();

missing_.GrowTo(feature_offsets_[n_features], true);
Expand All @@ -366,19 +371,69 @@ class ColumnMatrix {
using ColumnBinT = decltype(t);
ColumnBinT* local_index = reinterpret_cast<ColumnBinT*>(index_.data());
size_t const batch_size = batch.Size();
size_t k{0};
for (size_t rid = 0; rid < batch_size; ++rid) {
auto line = batch.GetLine(rid);
for (size_t i = 0; i < line.Size(); ++i) {
auto coo = line.GetElement(i);
if (is_valid(coo)) {
auto fid = coo.column_idx;
const uint32_t bin_id = row_index[k];
SetBinSparse(bin_id, rid + base_rowid, fid, local_index);
++k;

dmlc::OMPException exc;
std::vector<size_t> n_elements((n_threads + 1) * n_features, 0);
std::vector<size_t> k_offsets(n_threads + 1, 0);
size_t block_size = DivRoundUp(batch_size, n_threads);
#pragma omp parallel num_threads(n_threads)
{
exc.Run([&, is_valid]() {
int tid = omp_get_thread_num();
size_t begin = block_size * tid;
size_t end = std::min(begin + block_size, batch_size);
for (size_t rid = begin; rid < end; ++rid) {
const auto& line = batch.GetLine(rid);
for (size_t i = 0; i < line.Size(); ++i) {
auto coo = line.GetElement(i);
if (is_valid(coo)) {
auto fid = coo.column_idx;
if ((type_[fid] != kDenseColumn)) {
n_elements[(tid + 1) * n_features + fid] += 1;
}
k_offsets[tid + 1] += 1;
}
}
}
});
}
exc.Rethrow();

ParallelFor(n_features, n_threads, [&](auto fid) {
n_elements[fid] += num_nonzeros_[fid];
for (int tid = 0; tid < n_threads; ++tid) {
n_elements[(tid + 1) * n_features + fid] +=
n_elements[tid * n_features + fid];
}
num_nonzeros_[fid] = n_elements[n_threads * n_features + fid];
});
std::partial_sum(k_offsets.cbegin(), k_offsets.cend(), k_offsets.begin());

#pragma omp parallel num_threads(n_threads)
{
std::vector<size_t> nnz_offsets(n_features, 0);
exc.Run([&, is_valid, base_rowid, row_index]() {
int tid = omp_get_thread_num();
size_t begin = block_size * tid;
size_t end = std::min(begin + block_size, batch_size);
size_t k = 0;
for (size_t rid = begin; rid < end; ++rid) {
const auto& line = batch.GetLine(rid);
for (size_t i = 0; i < line.Size(); ++i) {
auto coo = line.GetElement(i);
if (is_valid(coo)) {
auto fid = coo.column_idx;
const uint32_t bin_id = row_index[k_offsets[tid] + k];
size_t nnz = n_elements[tid * n_features + fid] + nnz_offsets[fid];
SetBinSparse(bin_id, rid + base_rowid, fid, local_index, nnz);
++k;
nnz_offsets[fid] += (type_[fid] != kDenseColumn);
}
}
}
});
}
exc.Rethrow();
});
}

Expand Down
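The new SetIndexMixedColumns above follows a two-pass scheme: a first parallel pass counts valid elements per thread (and per sparse feature), prefix sums over those counts yield disjoint write offsets, and a second parallel pass scatters the bin indices without any locking. A minimal self-contained sketch of the same pattern, with illustrative names that are not part of the XGBoost API:

```cpp
#include <algorithm>  // for min
#include <cstddef>    // for size_t
#include <numeric>    // for partial_sum
#include <vector>
#include <omp.h>

// Compact the non-negative values of `in` into a dense output, preserving order,
// using the same count / prefix-sum / scatter pattern as SetIndexMixedColumns.
std::vector<int> ParallelCompact(std::vector<int> const& in, int n_threads) {
  std::size_t n = in.size();
  std::size_t block = (n + n_threads - 1) / n_threads;
  std::vector<std::size_t> offsets(n_threads + 1, 0);

  // Pass 1: every thread counts the valid elements in its block.
#pragma omp parallel num_threads(n_threads)
  {
    int tid = omp_get_thread_num();
    std::size_t begin = std::min(static_cast<std::size_t>(tid) * block, n);
    std::size_t end = std::min(begin + block, n);
    for (std::size_t i = begin; i < end; ++i) {
      offsets[tid + 1] += (in[i] >= 0);
    }
  }
  // Prefix sum turns per-thread counts into per-thread write positions.
  std::partial_sum(offsets.cbegin(), offsets.cend(), offsets.begin());

  std::vector<int> out(offsets[n_threads]);
  // Pass 2: every thread writes into its own disjoint range, no synchronisation.
#pragma omp parallel num_threads(n_threads)
  {
    int tid = omp_get_thread_num();
    std::size_t begin = std::min(static_cast<std::size_t>(tid) * block, n);
    std::size_t end = std::min(begin + block, n);
    std::size_t k = offsets[tid];
    for (std::size_t i = begin; i < end; ++i) {
      if (in[i] >= 0) {
        out[k++] = in[i];
      }
    }
  }
  return out;
}
```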
22 changes: 22 additions & 0 deletions src/common/ref_resource_view.h
@@ -164,6 +164,28 @@ template <typename T>
return ref;
}

/**
* @brief Make a fixed size `RefResourceView` with malloc resource.
* Use n_threads to initialise the storage in parallel.
*/
template <typename T>
[[nodiscard]] RefResourceView<T> MakeFixedVecWithMalloc(std::size_t n_elements, T const& init,
int n_threads) {
auto resource = std::make_shared<common::MallocResource>(n_elements * sizeof(T));
auto ref = RefResourceView{resource->DataAs<T>(), n_elements, resource};

size_t block_size = n_elements / n_threads + (n_elements % n_threads > 0);
#pragma omp parallel num_threads(n_threads)
Member: Is this faster than std::fill_n for primitive data? Seems unlikely.

Contributor Author: It is, if the number of elements is high; there is a significant speed-up for ~1e8-1e9 elements. A rough benchmark sketch follows this file's diff.

{
int tid = omp_get_thread_num();
auto begin = tid * block_size;
auto end = std::min((tid + 1) * block_size, n_elements);
auto size = end > begin ? end - begin : 0;
std::fill_n(ref.data() + begin, size, init);
}
return ref;
}

template <typename T>
class ReallocVector : public RefResourceView<T> {
static_assert(!std::is_reference_v<T>);
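To put numbers behind the std::fill_n discussion above, here is a rough benchmark sketch one could run. The buffer size, thread count, and any resulting timings are assumptions for illustration, not measurements from this PR:

```cpp
#include <algorithm>  // for fill_n, min
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <vector>
#include <omp.h>

// Times a serial std::fill_n against an OpenMP block-wise fill on a large buffer.
// Whether the parallel version wins depends on memory bandwidth, first-touch page
// placement, and NUMA topology, so treat the output as indicative only.
int main() {
  std::size_t n = std::size_t{1} << 28;  // roughly 2.7e8 bytes; adjust to taste
  int n_threads = omp_get_max_threads();
  std::vector<std::uint8_t> buf(n);

  auto t0 = std::chrono::steady_clock::now();
  std::fill_n(buf.data(), n, std::uint8_t{1});
  auto t1 = std::chrono::steady_clock::now();

  std::size_t block = (n + n_threads - 1) / n_threads;
#pragma omp parallel num_threads(n_threads)
  {
    std::size_t begin = std::min(static_cast<std::size_t>(omp_get_thread_num()) * block, n);
    std::size_t end = std::min(begin + block, n);
    std::fill_n(buf.data() + begin, end - begin, std::uint8_t{2});
  }
  auto t2 = std::chrono::steady_clock::now();

  using ms = std::chrono::duration<double, std::milli>;
  std::printf("serial: %.1f ms, parallel: %.1f ms\n",
              ms(t1 - t0).count(), ms(t2 - t1).count());
  return 0;
}
```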
14 changes: 7 additions & 7 deletions src/data/gradient_index.cc
@@ -27,15 +27,15 @@ GHistIndexMatrix::GHistIndexMatrix(Context const *ctx, DMatrix *p_fmat, bst_bin_
cut = common::SketchOnDMatrix(ctx, p_fmat, max_bins_per_feat, sorted_sketch, hess);

const uint32_t nbins = cut.Ptrs().back();
hit_count = common::MakeFixedVecWithMalloc(nbins, std::size_t{0});
hit_count = common::MakeFixedVecWithMalloc(nbins, std::size_t{0}, ctx->Threads());
hit_count_tloc_.resize(ctx->Threads() * nbins, 0);

size_t new_size = 1;
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
new_size += batch.Size();
}

row_ptr = common::MakeFixedVecWithMalloc(new_size, std::size_t{0});
row_ptr = common::MakeFixedVecWithMalloc(new_size, std::size_t{0}, ctx->Threads());

const bool isDense = p_fmat->IsDense();
this->isDense_ = isDense;
@@ -132,13 +132,13 @@ INSTANTIATION_PUSH(data::SparsePageAdapterBatch)
INSTANTIATION_PUSH(data::ColumnarAdapterBatch)
#undef INSTANTIATION_PUSH

void GHistIndexMatrix::ResizeColumns(double sparse_thresh) {
void GHistIndexMatrix::ResizeColumns(double sparse_thresh, int n_threads) {
CHECK(!std::isnan(sparse_thresh));
this->columns_ = std::make_unique<common::ColumnMatrix>(*this, sparse_thresh);
this->columns_ = std::make_unique<common::ColumnMatrix>(*this, sparse_thresh, n_threads);
}

void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) {
auto make_index = [this, n_index](auto t, common::BinTypeSize t_size) {
void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense, int n_threads) {
auto make_index = [this, n_index, n_threads](auto t, common::BinTypeSize t_size) {
// Must resize instead of allocating a new one. This function is called everytime a
// new batch is pushed, and we grow the size accordingly without loosing the data in
// the previous batches.
Expand All @@ -150,7 +150,7 @@ void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) {
decltype(this->data) new_vec;
if (!resource) {
CHECK(this->data.empty());
new_vec = common::MakeFixedVecWithMalloc(n_bytes, std::uint8_t{0});
new_vec = common::MakeFixedVecWithMalloc(n_bytes, std::uint8_t{0}, n_threads);
} else {
CHECK(resource->Type() == common::ResourceHandler::kMalloc);
auto malloc_resource = std::dynamic_pointer_cast<common::MallocResource>(resource);
8 changes: 4 additions & 4 deletions src/data/gradient_index.h
@@ -121,7 +121,7 @@ class GHistIndexMatrix {

auto n_bins_total = cut.TotalBins();
const size_t n_index = row_ptr[rbegin + batch.Size()]; // number of entries in this page
ResizeIndex(n_index, isDense_);
ResizeIndex(n_index, isDense_, n_threads);
if (isDense_) {
index.SetBinOffset(cut.Ptrs());
}
Expand All @@ -142,7 +142,7 @@ class GHistIndexMatrix {
}

// The function is only created to avoid using the column matrix in the header.
void ResizeColumns(double sparse_thresh);
void ResizeColumns(double sparse_thresh, int n_threads);

public:
/** @brief row pointer to rows by element position */
@@ -224,7 +224,7 @@

if (rbegin + batch.Size() == n_samples_total) {
// finished
this->ResizeColumns(sparse_thresh);
this->ResizeColumns(sparse_thresh, ctx->Threads());
}
}

Expand All @@ -233,7 +233,7 @@ class GHistIndexMatrix {
void PushAdapterBatchColumns(Context const* ctx, Batch const& batch, float missing,
size_t rbegin);

void ResizeIndex(const size_t n_index, const bool isDense);
void ResizeIndex(const size_t n_index, const bool isDense, int n_threads = 1);
Member: Could you please share in which case nthread = 1, and what are the other cases?

Contributor Author: I fixed the code; there is no default value now.

void GetFeatureCounts(size_t* counts) const {
auto nfeature = cut.Ptrs().size() - 1;
2 changes: 1 addition & 1 deletion tests/cpp/common/test_column_matrix.cc
@@ -136,7 +136,7 @@ TEST(ColumnMatrix, GrowMissing) {
auto const& column_matrix = page.Transpose();
auto const& missing = column_matrix.Missing();
auto n = NumpyArrayIterForTest::Rows() * NumpyArrayIterForTest::Cols();
auto expected = std::remove_reference_t<decltype(missing)>::BitFieldT::ComputeStorageSize(n);
auto expected = n;
auto got = missing.storage.size();
ASSERT_EQ(expected, got);
DispatchBinType(column_matrix.GetTypeSize(), [&](auto dtype) {