diff --git a/include/LightGBM/arrow.h b/include/LightGBM/arrow.h index 3d1c74713bd3..75511e17e72a 100644 --- a/include/LightGBM/arrow.h +++ b/include/LightGBM/arrow.h @@ -117,6 +117,7 @@ class ArrowChunkedArray { const struct ArrowSchema* schema) { chunks_.reserve(n_chunks); for (auto k = 0; k < n_chunks; ++k) { + if (chunks[k].length == 0) continue; chunks_.push_back(&chunks[k]); } schema_ = schema; @@ -220,6 +221,7 @@ class ArrowTable { std::vector children_chunks; children_chunks.reserve(n_chunks); for (int64_t k = 0; k < n_chunks; ++k) { + if (chunks[k].length == 0) continue; children_chunks.push_back(chunks[k].children[j]); } columns_.emplace_back(children_chunks, schema->children[j]); diff --git a/tests/python_package_test/test_arrow.py b/tests/python_package_test/test_arrow.py index 5e09465e34b3..7542368dcd63 100644 --- a/tests/python_package_test/test_arrow.py +++ b/tests/python_package_test/test_arrow.py @@ -30,18 +30,19 @@ ] -def generate_simple_arrow_table() -> pa.Table: +def generate_simple_arrow_table(empty_chunks: bool = False) -> pa.Table: + c: list[list[int]] = [[]] if empty_chunks else [] columns = [ - pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.uint8()), - pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.int8()), - pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.uint16()), - pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.int16()), - pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.uint32()), - pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.int32()), - pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.uint64()), - pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.int64()), - pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.float32()), - pa.chunked_array([[1, 2, 3, 4, 5]], type=pa.float64()), + pa.chunked_array(c + [[1, 2, 3]] + c + [[4, 5]] + c, type=pa.uint8()), + pa.chunked_array(c + [[1, 2, 3]] + c + [[4, 5]] + c, type=pa.int8()), + pa.chunked_array(c + [[1, 2, 3]] + c + [[4, 5]] + c, type=pa.uint16()), + pa.chunked_array(c + [[1, 2, 3]] + c + [[4, 5]] + c, type=pa.int16()), + pa.chunked_array(c + [[1, 2, 3]] + c + [[4, 5]] + c, type=pa.uint32()), + pa.chunked_array(c + [[1, 2, 3]] + c + [[4, 5]] + c, type=pa.int32()), + pa.chunked_array(c + [[1, 2, 3]] + c + [[4, 5]] + c, type=pa.uint64()), + pa.chunked_array(c + [[1, 2, 3]] + c + [[4, 5]] + c, type=pa.int64()), + pa.chunked_array(c + [[1, 2, 3]] + c + [[4, 5]] + c, type=pa.float32()), + pa.chunked_array(c + [[1, 2, 3]] + c + [[4, 5]] + c, type=pa.float64()), ] return pa.Table.from_arrays(columns, names=[f"col_{i}" for i in range(len(columns))]) @@ -104,6 +105,7 @@ def dummy_dataset_params() -> Dict[str, Any]: ("arrow_table_fn", "dataset_params"), [ # Use lambda functions here to minimize memory consumption (lambda: generate_simple_arrow_table(), dummy_dataset_params()), + (lambda: generate_simple_arrow_table(empty_chunks=True), dummy_dataset_params()), (lambda: generate_dummy_arrow_table(), dummy_dataset_params()), (lambda: generate_nullable_arrow_table(), dummy_dataset_params()), (lambda: generate_random_arrow_table(3, 1000, 42), {}), @@ -160,7 +162,12 @@ def test_dataset_construct_fields_fuzzy(): @pytest.mark.parametrize( ["array_type", "label_data"], - [(pa.array, [0, 1, 0, 0, 1]), (pa.chunked_array, [[0], [1, 0, 0, 1]])], + [ + (pa.array, [0, 1, 0, 0, 1]), + (pa.chunked_array, [[0], [1, 0, 0, 1]]), + (pa.chunked_array, [[], [0], [1, 0, 0, 1]]), + (pa.chunked_array, [[0], [], [1, 0], [], [], [0, 1], []]), + ], ) @pytest.mark.parametrize("arrow_type", _INTEGER_TYPES + _FLOAT_TYPES) def test_dataset_construct_labels(array_type, label_data, arrow_type): @@ -187,7 +194,12 @@ def test_dataset_construct_weights_none(): @pytest.mark.parametrize( ["array_type", "weight_data"], - [(pa.array, [3, 0.7, 1.5, 0.5, 0.1]), (pa.chunked_array, [[3], [0.7, 1.5, 0.5, 0.1]])], + [ + (pa.array, [3, 0.7, 1.5, 0.5, 0.1]), + (pa.chunked_array, [[3], [0.7, 1.5, 0.5, 0.1]]), + (pa.chunked_array, [[], [3], [0.7, 1.5, 0.5, 0.1]]), + (pa.chunked_array, [[3], [0.7], [], [], [1.5, 0.5, 0.1], []]), + ], ) @pytest.mark.parametrize("arrow_type", _FLOAT_TYPES) def test_dataset_construct_weights(array_type, weight_data, arrow_type):