Commit 516bde9

[python-package] Allow to pass Arrow array as groups (#6166)

1 parent bc69422 · commit 516bde9

6 files changed, +89 −40 lines changed

include/LightGBM/c_api.h

Lines changed: 2 additions & 1 deletion

```diff
@@ -558,9 +558,10 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetSetField(DatasetHandle handle,
 /*!
  * \brief Set vector to a content in info.
  * \note
+ * - \a group converts input datatype into ``int32``;
  * - \a label and \a weight convert input datatype into ``float32``.
  * \param handle Handle of dataset
- * \param field_name Field name, can be \a label, \a weight
+ * \param field_name Field name, can be \a label, \a weight, \a group
  * \param n_chunks The number of Arrow arrays passed to this function
  * \param chunks Pointer to the list of Arrow arrays
  * \param schema Pointer to the schema of all Arrow arrays
```
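
A minimal sketch of these conversion semantics as seen from the Python package, which routes Arrow data down to this C entry point (assumes pyarrow is installed and a LightGBM build with Arrow support):

```python
import numpy as np
import pyarrow as pa

import lightgbm as lgb

X = np.random.rand(100, 2)
# Any Arrow integer type may be passed as `group`; per the note above it is
# converted to int32 internally, while label/weight become float32.
group = pa.array([40, 60], type=pa.int64())

ds = lgb.Dataset(X, group=group).construct()
print(ds.get_field("group"))  # int32 cumulative boundaries: [  0  40 100]
```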

include/LightGBM/dataset.h

Lines changed: 4 additions & 0 deletions

```diff
@@ -116,6 +116,7 @@ class Metadata {
   void SetWeights(const ArrowChunkedArray& array);
 
   void SetQuery(const data_size_t* query, data_size_t len);
+  void SetQuery(const ArrowChunkedArray& array);
 
   void SetPosition(const data_size_t* position, data_size_t len);
 
@@ -348,6 +349,9 @@ class Metadata {
   void InsertInitScores(const double* init_scores, data_size_t start_index, data_size_t len, data_size_t source_size);
   /*! \brief Insert queries at the given index */
   void InsertQueries(const data_size_t* queries, data_size_t start_index, data_size_t len);
+  /*! \brief Set queries from pointers to the first element and the end of an iterator. */
+  template <typename It>
+  void SetQueriesFromIterator(It first, It last);
   /*! \brief Filename of current data */
   std::string data_filename_;
   /*! \brief Number of data */
```

python-package/lightgbm/basic.py

Lines changed: 9 additions & 6 deletions

```diff
@@ -70,7 +70,9 @@
     List[float],
     List[int],
     np.ndarray,
-    pd_Series
+    pd_Series,
+    pa_Array,
+    pa_ChunkedArray,
 ]
 _LGBM_PositionType = Union[
     np.ndarray,
@@ -1652,7 +1654,7 @@ def __init__(
             If this is Dataset for validation, training data should be used as reference.
         weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
             Weight for each instance. Weights should be non-negative.
-        group : list, numpy 1-D array, pandas Series or None, optional (default=None)
+        group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
             Group/query data.
             Only used in the learning-to-rank task.
             sum(group) = n_samples.
@@ -2432,7 +2434,7 @@ def create_valid(
             Label of the data.
         weight : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
             Weight for each instance. Weights should be non-negative.
-        group : list, numpy 1-D array, pandas Series or None, optional (default=None)
+        group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
             Group/query data.
             Only used in the learning-to-rank task.
             sum(group) = n_samples.
@@ -2889,7 +2891,7 @@ def set_group(
 
         Parameters
         ----------
-        group : list, numpy 1-D array, pandas Series or None
+        group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None
             Group/query data.
             Only used in the learning-to-rank task.
             sum(group) = n_samples.
@@ -2903,7 +2905,8 @@ def set_group(
         """
         self.group = group
         if self._handle is not None and group is not None:
-            group = _list_to_1d_numpy(group, dtype=np.int32, name='group')
+            if not _is_pyarrow_array(group):
+                group = _list_to_1d_numpy(group, dtype=np.int32, name='group')
             self.set_field('group', group)
             # original values can be modified at cpp side
             constructed_group = self.get_field('group')
@@ -4431,7 +4434,7 @@ def refit(
 
             .. versionadded:: 4.0.0
 
-        group : list, numpy 1-D array, pandas Series or None, optional (default=None)
+        group : list, numpy 1-D array, pandas Series, pyarrow Array, pyarrow ChunkedArray or None, optional (default=None)
             Group/query size for ``data``.
             Only used in the learning-to-rank task.
             sum(group) = n_samples.
```
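
A hedged end-to-end sketch of the new `group` parameter in a learning-to-rank workflow (the query sizes below are illustrative; note that sum(group) must equal the number of rows):

```python
import numpy as np
import pyarrow as pa

import lightgbm as lgb

rng = np.random.default_rng(42)
X = rng.random((100, 3))
y = rng.integers(0, 2, size=100)  # relevance labels
# Two queries of 40 and 60 rows; 40 + 60 == n_samples.
group = pa.chunked_array([[40], [60]], type=pa.int32())

train_set = lgb.Dataset(X, label=y, group=group)
booster = lgb.train({"objective": "lambdarank", "verbose": -1}, train_set, num_boost_round=5)
```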

src/io/dataset.cpp

Lines changed: 2 additions & 0 deletions

```diff
@@ -904,6 +904,8 @@ bool Dataset::SetFieldFromArrow(const char* field_name, const ArrowChunkedArray
     metadata_.SetLabel(ca);
   } else if (name == std::string("weight") || name == std::string("weights")) {
     metadata_.SetWeights(ca);
+  } else if (name == std::string("query") || name == std::string("group")) {
+    metadata_.SetQuery(ca);
   } else {
     return false;
   }
```
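
For illustration, a sketch of reaching this routing through `Dataset.set_field` from Python (assumes pyarrow is installed; as the branch above shows, the C++ side accepts both the "group" and "query" spellings of the field name):

```python
import numpy as np
import pyarrow as pa

import lightgbm as lgb

ds = lgb.Dataset(np.random.rand(5, 2)).construct()
# "group" is routed to Metadata::SetQuery by the new branch above.
ds.set_field("group", pa.array([2, 3], type=pa.int32()))
print(ds.get_field("group"))  # [0 2 5]
```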

src/io/metadata.cpp

Lines changed: 20 additions & 8 deletions

```diff
@@ -507,30 +507,34 @@ void Metadata::InsertWeights(const label_t* weights, data_size_t start_index, da
   // CUDA is handled after all insertions are complete
 }
 
-void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
+template <typename It>
+void Metadata::SetQueriesFromIterator(It first, It last) {
   std::lock_guard<std::mutex> lock(mutex_);
-  // save to nullptr
-  if (query == nullptr || len == 0) {
+  // Clear query boundaries on empty input
+  if (last - first == 0) {
     query_boundaries_.clear();
     num_queries_ = 0;
     return;
   }
+
   data_size_t sum = 0;
   #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) reduction(+:sum)
-  for (data_size_t i = 0; i < len; ++i) {
-    sum += query[i];
+  for (data_size_t i = 0; i < last - first; ++i) {
+    sum += first[i];
   }
   if (num_data_ != sum) {
-    Log::Fatal("Sum of query counts is not same with #data");
+    Log::Fatal("Sum of query counts (%i) differs from the length of #data (%i)", num_data_, sum);
   }
-  num_queries_ = len;
+  num_queries_ = last - first;
+
   query_boundaries_.resize(num_queries_ + 1);
   query_boundaries_[0] = 0;
   for (data_size_t i = 0; i < num_queries_; ++i) {
-    query_boundaries_[i + 1] = query_boundaries_[i] + query[i];
+    query_boundaries_[i + 1] = query_boundaries_[i] + first[i];
   }
   CalculateQueryWeights();
   query_load_from_file_ = false;
+
 #ifdef USE_CUDA
   if (cuda_metadata_ != nullptr) {
     if (query_weights_.size() > 0) {
@@ -543,6 +547,14 @@ void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
 #endif  // USE_CUDA
 }
 
+void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
+  SetQueriesFromIterator(query, query + len);
+}
+
+void Metadata::SetQuery(const ArrowChunkedArray& array) {
+  SetQueriesFromIterator(array.begin<data_size_t>(), array.end<data_size_t>());
+}
+
 void Metadata::SetPosition(const data_size_t* positions, data_size_t len) {
   std::lock_guard<std::mutex> lock(mutex_);
   // save to nullptr
```
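
The boundary construction in SetQueriesFromIterator is straightforward to restate outside C++: check that the query sizes cover all rows, then store their cumulative sums. A Python sketch of the same logic (`queries_to_boundaries` is a hypothetical helper, not part of LightGBM):

```python
import numpy as np

def queries_to_boundaries(sizes: np.ndarray, num_data: int) -> np.ndarray:
    """Mirror SetQueriesFromIterator: sizes [2, 3] -> boundaries [0, 2, 5]."""
    if sizes.sum() != num_data:
        raise ValueError("Sum of query counts differs from #data")
    boundaries = np.zeros(len(sizes) + 1, dtype=np.int32)
    boundaries[1:] = np.cumsum(sizes)  # query i spans [boundaries[i], boundaries[i+1])
    return boundaries

print(queries_to_boundaries(np.array([2, 3], dtype=np.int32), 5))  # [0 2 5]
```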

tests/python_package_test/test_arrow.py

Lines changed: 52 additions & 25 deletions

```diff
@@ -1,7 +1,6 @@
 # coding: utf-8
 import filecmp
-from pathlib import Path
-from typing import Any, Callable, Dict
+from typing import Any, Dict
 
 import numpy as np
 import pyarrow as pa
@@ -15,6 +14,21 @@
 # UTILITIES #
 # ----------------------------------------------------------------------------------------------- #
 
+_INTEGER_TYPES = [
+    pa.int8(),
+    pa.int16(),
+    pa.int32(),
+    pa.int64(),
+    pa.uint8(),
+    pa.uint16(),
+    pa.uint32(),
+    pa.uint64(),
+]
+_FLOAT_TYPES = [
+    pa.float32(),
+    pa.float64(),
+]
+
 
 def generate_simple_arrow_table() -> pa.Table:
     columns = [
@@ -85,9 +99,7 @@ def dummy_dataset_params() -> Dict[str, Any]:
         (lambda: generate_random_arrow_table(100, 10000, 43), {}),
     ],
 )
-def test_dataset_construct_fuzzy(
-    tmp_path: Path, arrow_table_fn: Callable[[], pa.Table], dataset_params: Dict[str, Any]
-):
+def test_dataset_construct_fuzzy(tmp_path, arrow_table_fn, dataset_params):
     arrow_table = arrow_table_fn()
 
     arrow_dataset = lgb.Dataset(arrow_table, params=dataset_params)
@@ -108,17 +120,23 @@ def test_dataset_construct_fields_fuzzy():
     arrow_table = generate_random_arrow_table(3, 1000, 42)
     arrow_labels = generate_random_arrow_array(1000, 42)
     arrow_weights = generate_random_arrow_array(1000, 42)
+    arrow_groups = pa.chunked_array([[300, 400, 50], [250]], type=pa.int32())
 
-    arrow_dataset = lgb.Dataset(arrow_table, label=arrow_labels, weight=arrow_weights)
+    arrow_dataset = lgb.Dataset(
+        arrow_table, label=arrow_labels, weight=arrow_weights, group=arrow_groups
+    )
     arrow_dataset.construct()
 
     pandas_dataset = lgb.Dataset(
-        arrow_table.to_pandas(), label=arrow_labels.to_numpy(), weight=arrow_weights.to_numpy()
+        arrow_table.to_pandas(),
+        label=arrow_labels.to_numpy(),
+        weight=arrow_weights.to_numpy(),
+        group=arrow_groups.to_numpy(),
     )
     pandas_dataset.construct()
 
     # Check for equality
-    for field in ("label", "weight"):
+    for field in ("label", "weight", "group"):
         np_assert_array_equal(
             arrow_dataset.get_field(field), pandas_dataset.get_field(field), strict=True
         )
@@ -133,22 +151,8 @@ def test_dataset_construct_fields_fuzzy():
     ["array_type", "label_data"],
     [(pa.array, [0, 1, 0, 0, 1]), (pa.chunked_array, [[0], [1, 0, 0, 1]])],
 )
-@pytest.mark.parametrize(
-    "arrow_type",
-    [
-        pa.int8(),
-        pa.int16(),
-        pa.int32(),
-        pa.int64(),
-        pa.uint8(),
-        pa.uint16(),
-        pa.uint32(),
-        pa.uint64(),
-        pa.float32(),
-        pa.float64(),
-    ],
-)
-def test_dataset_construct_labels(array_type: Any, label_data: Any, arrow_type: Any):
+@pytest.mark.parametrize("arrow_type", _INTEGER_TYPES + _FLOAT_TYPES)
+def test_dataset_construct_labels(array_type, label_data, arrow_type):
     data = generate_dummy_arrow_table()
     labels = array_type(label_data, type=arrow_type)
     dataset = lgb.Dataset(data, label=labels, params=dummy_dataset_params())
@@ -175,11 +179,34 @@ def test_dataset_construct_weights_none():
     [(pa.array, [3, 0.7, 1.5, 0.5, 0.1]), (pa.chunked_array, [[3], [0.7, 1.5, 0.5, 0.1]])],
 )
 @pytest.mark.parametrize("arrow_type", [pa.float32(), pa.float64()])
-def test_dataset_construct_weights(array_type: Any, weight_data: Any, arrow_type: Any):
+def test_dataset_construct_weights(array_type, weight_data, arrow_type):
     data = generate_dummy_arrow_table()
     weights = array_type(weight_data, type=arrow_type)
     dataset = lgb.Dataset(data, weight=weights, params=dummy_dataset_params())
     dataset.construct()
 
     expected = np.array([3, 0.7, 1.5, 0.5, 0.1], dtype=np.float32)
     np_assert_array_equal(expected, dataset.get_weight(), strict=True)
+
+
+# -------------------------------------------- GROUPS ------------------------------------------- #
+
+
+@pytest.mark.parametrize(
+    ["array_type", "group_data"],
+    [
+        (pa.array, [2, 3]),
+        (pa.chunked_array, [[2], [3]]),
+        (pa.chunked_array, [[], [2, 3]]),
+        (pa.chunked_array, [[2], [], [3], []]),
+    ],
+)
+@pytest.mark.parametrize("arrow_type", _INTEGER_TYPES)
+def test_dataset_construct_groups(array_type, group_data, arrow_type):
+    data = generate_dummy_arrow_table()
+    groups = array_type(group_data, type=arrow_type)
+    dataset = lgb.Dataset(data, group=groups, params=dummy_dataset_params())
+    dataset.construct()
+
+    expected = np.array([0, 2, 5], dtype=np.int32)
+    np_assert_array_equal(expected, dataset.get_field("group"), strict=True)
```
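
One detail worth calling out from the expected value above: `get_field("group")` returns cumulative query boundaries rather than the raw sizes, and empty chunks in a ChunkedArray contribute nothing. A standalone check of that invariant (pyarrow and numpy only):

```python
import numpy as np
import pyarrow as pa

for chunks in ([[2, 3]], [[2], [3]], [[], [2, 3]], [[2], [], [3], []]):
    sizes = pa.chunked_array(chunks, type=pa.int32()).to_numpy()
    boundaries = np.concatenate([[0], np.cumsum(sizes)]).astype(np.int32)
    np.testing.assert_array_equal(boundaries, np.array([0, 2, 5], dtype=np.int32))
```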
