18 changes: 18 additions & 0 deletions python-package/xgboost/_data_utils.py
@@ -749,3 +749,21 @@ def array_interface(self) -> bytes:
@abstractmethod
def shape(self) -> Tuple[int, int]:
"""Return the shape of the dataframe."""


class ArrowCatMixin:
"""Mixin for handling arrow-backed dictionary array in a transformed dataframe."""

def _push_arrow_cat(
self, col: "pa.DictionaryArray", aitfs: AifType, temporary_buffers: List[Tuple]
) -> None:
pa = import_pyarrow()
cats = col.dictionary
codes = col.indices
if not isinstance(cats, (pa.StringArray, pa.LargeStringArray)):
raise TypeError(
"Only string-based categorical index is supported for arrow."
)
jnames, jcodes, buf = arrow_cat_inf(cats, codes)
temporary_buffers.append(buf)
aitfs.append((jnames, jcodes))
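
For reference, a minimal pyarrow-only sketch (not part of the diff) of the input `_push_arrow_cat` expects: a dictionary array whose dictionary holds string categories and whose indices hold the integer codes; a non-string dictionary hits the `TypeError` above.

```python
import pyarrow as pa

# Accepted: string-backed dictionary (categories) with integer indices (codes).
col = pa.DictionaryArray.from_arrays(
    pa.array([0, 2, 1, 0], type=pa.int32()),          # codes
    pa.array(["a", "b", "c"], type=pa.large_utf8()),  # categories
)
assert isinstance(col.dictionary, (pa.StringArray, pa.LargeStringArray))

# Rejected: an integer dictionary fails the isinstance check and raises TypeError.
bad = pa.DictionaryArray.from_arrays(pa.array([0, 1]), pa.array([10, 20]))
assert not isinstance(bad.dictionary, (pa.StringArray, pa.LargeStringArray))
```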
46 changes: 26 additions & 20 deletions python-package/xgboost/data.py
@@ -25,6 +25,7 @@

from ._data_utils import (
AifType,
ArrowCatMixin,
Categories,
DfCatAccessor,
TransformedDf,
@@ -34,7 +35,6 @@
array_hasobject,
array_interface,
array_interface_dict,
arrow_cat_inf,
check_cudf_meta,
cuda_array_interface,
cuda_array_interface_dict,
@@ -493,8 +493,7 @@ def is_pd_sparse_dtype(dtype: PandasDType) -> bool:
return is_sparse(dtype)


def pandas_pa_type(ser: Any) -> np.ndarray:
"""Handle pandas pyarrow extention."""
def pandas_pa_chunk(ser: Any) -> "pa.Array":
pd = import_pandas()

if TYPE_CHECKING:
@@ -511,6 +510,14 @@ def pandas_pa_type(ser: Any) -> np.ndarray:
aa: "pa.ChunkedArray" = d_array.__arrow_array__()
# combine_chunks takes the most significant amount of time
chunk: "pa.Array" = aa.combine_chunks()
return chunk


def pandas_pa_type(ser: Any) -> np.ndarray:
"""Handle numerical pandas pyarrow extention."""
pa = import_pyarrow()

chunk = pandas_pa_chunk(ser)
# When there's null value, we have to use copy
zero_copy = chunk.null_count == 0 and not pa.types.is_boolean(chunk.type)
# Alternately, we can use chunk.buffers(), which returns a list of buffers and
@@ -524,6 +531,11 @@ def pandas_pa_type(ser: Any) -> np.ndarray:
return arr


def pandas_pa_cat_type(ser: Any) -> np.ndarray:
"""Handle categorical pandas pyarrow extention."""
return pandas_pa_chunk(ser)


@functools.cache
def _lazy_has_npdtypes() -> bool:
return np.lib.NumpyVersion(np.__version__) > np.lib.NumpyVersion("1.25.0")
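
As a side note, here is a minimal sketch (assuming pandas >= 2.0 with pyarrow installed; not part of the diff) of what `pandas_pa_chunk` plus `pandas_pa_type` above do for a numeric arrow-backed series; the zero-copy condition mirrors the one in the diff, where a null value or a boolean type forces a copy.

```python
import pandas as pd
import pyarrow as pa

ser = pd.Series([1, 2, None], dtype="int64[pyarrow]")

# Combine the (possibly chunked) arrow data into a single contiguous array.
chunked: pa.ChunkedArray = ser.array.__arrow_array__()
chunk: pa.Array = chunked.combine_chunks()

# Zero-copy is only possible when there are no nulls and the type is not boolean.
zero_copy = chunk.null_count == 0 and not pa.types.is_boolean(chunk.type)
arr = chunk.to_numpy(zero_copy_only=zero_copy, writable=not zero_copy)
print(zero_copy, arr)  # the null above forces a copy here
```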
@@ -582,10 +594,9 @@ def oth_type(ser: "PdSeries") -> np.ndarray:

for col, dtype in zip(data.columns, data.dtypes):
if is_pa_ext_categorical_dtype(dtype):
raise ValueError(
"pyarrow dictionary type is not supported. Use pandas category instead."
)
if is_pd_cat_dtype(dtype):
arr_cat = pandas_pa_cat_type(data[col])
result.append(arr_cat)
elif is_pd_cat_dtype(dtype):
result.append(cat_codes(data[col]))
elif is_pa_ext_dtype(dtype):
result.append(pandas_pa_type(data[col]))
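
A short sketch (assuming pandas >= 2.0 with pyarrow; not from the diff) of the kind of column that previously raised the removed `ValueError` and is now routed through `pandas_pa_cat_type`:

```python
import pandas as pd
import pyarrow as pa

# A dictionary-encoded arrow column wrapped as a pandas extension array.
dict_arr = pa.DictionaryArray.from_arrays(
    pa.array([0, 1, 0]), pa.array(["abc", "def"])
)
df = pd.DataFrame({"c": pd.arrays.ArrowExtensionArray(dict_arr)})

# df["c"].dtype is a pyarrow dictionary extension dtype; with this patch it is
# treated as a categorical column instead of being rejected.
print(df["c"].dtype)
```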
@@ -606,7 +617,7 @@ def oth_type(ser: "PdSeries") -> np.ndarray:
return result


class PandasTransformed(TransformedDf):
class PandasTransformed(TransformedDf, ArrowCatMixin):
"""A storage class for transformed pandas DataFrame."""

def __init__(
@@ -620,8 +631,11 @@ def __init__(

# Get the array interface representation for each column.
for col in self.columns:
if _is_df_cat(col):
# Categorical column
if is_arrow_dict(col):
# Arrow categorical column
self._push_arrow_cat(col, aitfs, self.temporary_buffers)
elif _is_df_cat(col):
# Pandas categorical column
jnames, jcodes, buf = pd_cat_inf(col.categories, col.codes)
self.temporary_buffers.append(buf)
aitfs.append((jnames, jcodes))
@@ -755,7 +769,7 @@ def _from_pandas_series(
)


class ArrowTransformed(TransformedDf):
class ArrowTransformed(TransformedDf, ArrowCatMixin):
"""A storage class for transformed arrow table."""

def __init__(
@@ -776,15 +790,7 @@ def __init__(

def push_series(col: Union["pa.NumericArray", "pa.DictionaryArray"]) -> None:
if isinstance(col, pa.DictionaryArray):
cats = col.dictionary
codes = col.indices
if not isinstance(cats, (pa.StringArray, pa.LargeStringArray)):
raise TypeError(
"Only string-based categorical index is supported for arrow."
)
jnames, jcodes, buf = arrow_cat_inf(cats, codes)
self.temporary_buffers.append(buf)
aitfs.append((jnames, jcodes))
self._push_arrow_cat(col, aitfs, self.temporary_buffers)
else:
jdata = _arrow_array_inf(col)
aitfs.append(jdata)
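
For context, a hedged sketch (not part of the diff) of the kind of table `ArrowTransformed` consumes: numeric columns go through `_arrow_array_inf`, dictionary columns through the shared `_push_arrow_cat` mixin.

```python
import pyarrow as pa

table = pa.table(
    {
        "f0": pa.array([1.0, 2.0, 3.0]),
        "f1": pa.DictionaryArray.from_arrays(
            pa.array([0, 1, 0]), pa.array(["abc", "def"])
        ),
    }
)
# Dictionary columns carry .dictionary (categories) and .indices (codes);
# numeric columns carry only the values.
for name in table.column_names:
    print(name, table.column(name).type)
```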
21 changes: 19 additions & 2 deletions python-package/xgboost/testing/ordinal.py
@@ -335,15 +335,32 @@ def run_basic_predict(DMatrixT: Type, device: Device, tdevice: Device) -> None:
assert_allclose(device, predt0, predt2)


def run_cat_predict(device: Device) -> None:
def run_cat_predict(device: Device, use_arrow: bool) -> None:
"""Basic tests for re-coding during prediction."""
Df, _ = get_df_impl(device)

for dm in (DMatrix, QuantileDMatrix):
run_basic_predict(dm, device, device)

def run_mixed(DMatrixT: Type) -> None:
df = Df({"b": [2, 1, 3], "c": ["cdef", "abc", "def"]}, dtype="category")
b_list = [2, 1, 3]
c_list = ["cdef", "abc", "def"]
if use_arrow:
import pandas as pd
import pyarrow as pa

c_typ = pa.DictionaryArray.from_arrays(
pa.array([0, 1, 2]),
pa.array(["cdef", "abc", "def"], type=pa.large_utf8()),
)
c_ser = pd.Series(c_typ, dtype=pd.ArrowDtype(c_typ.type))
Comment on lines +352 to +356

When constructing c_ser from c_typ, I would suggest doing something like

c_array = pd.arrays.ArrowExtensionArray(c_typ)
c_ser = pd.Series(c_array)

Member Author: Thank you for pointing that out. I need to create a new test, this one doesn't work, as there's no cat attribute for arrow dictionary-backed columns (the existing test relies on this interface).
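
A small illustration of the point in the reply above (assuming pandas >= 2.0 with pyarrow; reusing the names from the diff): an arrow dictionary-backed series, constructed either way, does not expose the pandas `.cat` accessor that the existing test relies on.

```python
import pandas as pd
import pyarrow as pa

c_typ = pa.DictionaryArray.from_arrays(
    pa.array([0, 1, 2]),
    pa.array(["cdef", "abc", "def"], type=pa.large_utf8()),
)
c_ser = pd.Series(c_typ, dtype=pd.ArrowDtype(c_typ.type))  # construction used in run_mixed
c_alt = pd.Series(pd.arrays.ArrowExtensionArray(c_typ))    # the reviewer's suggested form

# Neither arrow-backed series has the `.cat` accessor, unlike a pandas category column.
assert not hasattr(c_ser, "cat") and not hasattr(c_alt, "cat")
assert hasattr(pd.Series(["a", "b"], dtype="category"), "cat")
```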

b_typ = pa.DictionaryArray.from_arrays(
pa.array([0, 1, 2]), pa.array(b_list)
)
b_ser = pd.Series(b_typ, dtype=pd.ArrowDtype(b_typ.type))
df = Df({"b": b_ser, "c": c_ser})
else:
df = Df({"b": b_list, "c": c_list}, dtype="category")
y = np.array([0, 1, 2])

# used with the next df
2 changes: 1 addition & 1 deletion tests/python-gpu/test_gpu_ordinal.py
@@ -42,7 +42,7 @@ def test_cat_container_iter() -> None:


def test_cat_predict() -> None:
run_cat_predict("cuda")
run_cat_predict("cuda", False)


def test_cat_invalid() -> None:
7 changes: 6 additions & 1 deletion tests/python/test_ordinal.py
@@ -34,7 +34,12 @@ def test_cat_container_iter() -> None:


def test_cat_predict() -> None:
run_cat_predict("cpu")
run_cat_predict("cpu", False)


@pytest.mark.skipif(**tm.no_arrow())
def test_cat_predict_arrow() -> None:
run_cat_predict("cpu", True)


def test_cat_invalid() -> None: