109 changes: 67 additions & 42 deletions python-package/xgboost/dask/__init__.py
@@ -32,7 +32,6 @@
import collections
import logging
import platform
import socket
import warnings
from collections import defaultdict
from contextlib import contextmanager
@@ -155,6 +154,25 @@
LOGGER = logging.getLogger("[xgboost.dask]")


def _DASK_VERSION():
from packaging.version import parse as parse_version

return parse_version(dask.__version__)


def _DASK_2024_12_1() -> bool:
from packaging.version import parse as parse_version

return _DASK_VERSION() >= parse_version("2024.12.1")


def _DASK_2025_3_0() -> bool:
from packaging.version import parse as parse_version

return _DASK_VERSION() >= parse_version("2025.3.0")


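The version helpers above gate workarounds on the installed dask release. A minimal sketch of the same comparison pattern, assuming only that dask and packaging are importable (the threshold values below are illustrative, not the ones the PR targets):

import dask
from packaging.version import parse as parse_version

def _dask_at_least(threshold: str) -> bool:
    # Compare the installed dask release against a threshold string.
    return parse_version(dask.__version__) >= parse_version(threshold)

# Apply a workaround only on an affected release range (bounds are illustrative).
if _dask_at_least("2024.12.1") and not _dask_at_least("2025.3.0"):
    ...  # version-specific workaround goes here
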
def _try_start_tracker(
n_workers: int,
addrs: List[Union[Optional[str], Optional[Tuple[str, int]]]],
@@ -307,6 +325,8 @@ def __init__(
feature_weights: Optional[_DaskCollection] = None,
enable_categorical: bool = False,
) -> None:
from distributed import Future

_assert_dask_support()
client = _xgb_get_client(client)

@@ -332,7 +352,7 @@

self._n_cols = data.shape[1]
assert isinstance(self._n_cols, int)
self.worker_map: Dict[str, List[distributed.Future]] = defaultdict(list)
self.worker_map: Dict[str, List[Future]] = defaultdict(list)
self.is_quantile: bool = False

self._init = client.sync(
@@ -365,6 +385,7 @@ async def _map_local_data(
) -> "DaskDMatrix":
"""Obtain references to local data."""
from dask.delayed import Delayed
from distributed import Future

def inconsistent(
left: List[Any], left_name: str, right: List[Any], right_name: str
@@ -376,48 +397,39 @@ def inconsistent(
)
return msg

def check_columns(parts: numpy.ndarray) -> None:
# x is required to be 2 dim in __init__
assert parts.ndim == 1 or parts.shape[1], (
"Data should be"
" partitioned by row. To avoid this specify the number"
" of columns for your dask Array explicitly. e.g."
" chunks=(partition_size, X.shape[1])"
)

def to_delayed(d: _DaskCollection) -> List[Delayed]:
"""Breaking data into partitions, a trick borrowed from dask_xgboost. `to_delayed`
downgrades high-level objects into numpy or pandas equivalents.

"""
def to_futures(d: _DaskCollection) -> List[Future]:
"""Breaking data into partitions."""
d = client.persist(d)
delayed_obj = d.to_delayed()
if isinstance(delayed_obj, numpy.ndarray):
# da.Array returns an array of delayed objects
check_columns(delayed_obj)
delayed_list: List[Delayed] = delayed_obj.flatten().tolist()
else:
# dd.DataFrame
delayed_list = delayed_obj
return delayed_list
if (
hasattr(d.partitions, "shape")
and len(d.partitions.shape) > 1
and d.partitions.shape[1] > 1
):
raise ValueError(
"Data should be"
" partitioned by row. To avoid this specify the number"
" of columns for your dask Array explicitly. e.g."
" chunks=(partition_size, -1])"
)
return client.futures_of(d)

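The new to_futures helper above rejects dask arrays that are chunked across columns. As a usage sketch (the array shape and chunk sizes are made up), a caller can re-chunk by row before building the DMatrix:

import dask.array as da
import numpy as np

rng = np.random.default_rng(0)
X = da.from_array(rng.random((1000, 32)), chunks=(250, 16))  # chunked along both axes

# Keep the second axis whole so each partition spans every column.
X = X.rechunk((250, -1))
print(X.numblocks)  # (4, 1): partitioned by row only
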
def flatten_meta(meta: Optional[_DaskCollection]) -> Optional[List[Delayed]]:
def flatten_meta(meta: Optional[_DaskCollection]) -> Optional[List[Future]]:
if meta is not None:
meta_parts: List[Delayed] = to_delayed(meta)
meta_parts: List[Future] = to_futures(meta)
return meta_parts
return None

X_parts = to_delayed(data)
X_parts = to_futures(data)
y_parts = flatten_meta(label)
w_parts = flatten_meta(weights)
margin_parts = flatten_meta(base_margin)
qid_parts = flatten_meta(qid)
ll_parts = flatten_meta(label_lower_bound)
lu_parts = flatten_meta(label_upper_bound)

parts: Dict[str, List[Delayed]] = {"data": X_parts}
parts: Dict[str, List[Future]] = {"data": X_parts}

def append_meta(m_parts: Optional[List[Delayed]], name: str) -> None:
def append_meta(m_parts: Optional[List[Future]], name: str) -> None:
if m_parts is not None:
assert len(X_parts) == len(m_parts), inconsistent(
X_parts, "X", m_parts, name
@@ -431,12 +443,12 @@ def append_meta(m_parts: Optional[List[Delayed]], name: str) -> None:
append_meta(ll_parts, "label_lower_bound")
append_meta(lu_parts, "label_upper_bound")
# At this point, `parts` looks like:
# [(x0, x1, ..), (y0, y1, ..), ..] in delayed form
# [(x0, x1, ..), (y0, y1, ..), ..] in future form

# turn into list of dictionaries.
packed_parts: List[Dict[str, Delayed]] = []
packed_parts: List[Dict[str, Future]] = []
for i in range(len(X_parts)):
part_dict: Dict[str, Delayed] = {}
part_dict: Dict[str, Future] = {}
for key, value in parts.items():
part_dict[key] = value[i]
packed_parts.append(part_dict)
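The loop above transposes a mapping of field name to per-partition lists into one dictionary per partition. A compact equivalent of that reshaping, with illustrative field names:

parts = {"data": ["x0", "x1"], "label": ["y0", "y1"]}
packed = [dict(zip(parts, row)) for row in zip(*parts.values())]
# packed == [{"data": "x0", "label": "y0"}, {"data": "x1", "label": "y1"}]
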
@@ -445,16 +457,17 @@ def append_meta(m_parts: Optional[List[Delayed]], name: str) -> None:
# pylint: disable=no-member
delayed_parts: List[Delayed] = list(map(dask.delayed, packed_parts))
# At this point, the mental model should look like:
# [(x0, y0, ..), (x1, y1, ..), ..] in delayed form
# [{"data": x0, "label": y0, ..}, {"data": x1, "label": y1, ..}, ..]

# convert delayed objects into futures and make sure they are realized
fut_parts: List[distributed.Future] = client.compute(delayed_parts)
# Convert delayed objects into futures and make sure they are realized.
#
# This also makes the partitions align (co-locate) on workers (X_0 and y_0 should
# be on the same worker).
fut_parts: List[Future] = client.compute(delayed_parts)
await distributed.wait(fut_parts) # async wait for parts to be computed

# maybe we can call dask.align_partitions here to ease the partition alignment?

for part in fut_parts:
# Each part is [x0, y0, w0, ...] in future form.
# Each part is [{"data": x0, "label": y0, ..}, ...] in future form.
assert part.status == "finished", part.status

# Preserving the partition order for prediction.
@@ -467,7 +480,7 @@ def append_meta(m_parts: Optional[List[Delayed]], name: str) -> None:
keys=[part.key for part in fut_parts]
)

worker_map: Dict[str, List[distributed.Future]] = defaultdict(list)
worker_map: Dict[str, List[Future]] = defaultdict(list)

for key, workers in who_has.items():
worker_map[next(iter(workers))].append(key_to_partition[key])
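The two hunks above follow the usual distributed flow: wrap each partition dict with dask.delayed, turn the delayed objects into futures via client.compute, wait for them to finish, then group the futures by the worker that holds each one. A standalone sketch of that flow (the in-process client and toy payloads are illustrative only):

from collections import defaultdict

import dask
from distributed import Client, wait

client = Client(processes=False)  # toy in-process cluster, for illustration only

parts = [dask.delayed({"data": i, "label": 2 * i}) for i in range(4)]
futures = client.compute(parts)   # one Future per partition dict
wait(futures)                     # block until every partition is realized
assert all(fut.status == "finished" for fut in futures)

who_has = client.who_has(futures)             # future key -> worker addresses
worker_map = defaultdict(list)
for fut in futures:
    worker_map[next(iter(who_has[fut.key]))].append(fut)
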
@@ -1645,6 +1658,18 @@ async def _predict_async(
)
if isinstance(predts, dd.DataFrame):
predts = predts.to_dask_array()
# Make sure the booster is part of the task graph implicitly;
# only needed for certain versions of dask.
if _DASK_2024_12_1() and not _DASK_2025_3_0():
# Fixes this issue for dask>=2024.12.1,<2025.3.0
# Dask==2025.3.0 fails with:
# RuntimeError: Attempting to use an asynchronous
# Client in a synchronous context of `dask.compute`
#
# Dask==2025.4.0 fails with:
# TypeError: Value type is not supported for data
# iterator:<class 'distributed.client.Future'>
predts = predts.persist()
else:
test_dmatrix = await DaskDMatrix(
self.client,
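The gated branch above persists the prediction array so that, on the affected dask releases, the graph holding the booster future is materialised before any later dask.compute call. A rough sketch of the persist mechanics only (the array below is illustrative and unrelated to the booster fix):

import dask.array as da

arr = da.ones((8, 4), chunks=(4, 4))
arr = arr.persist()                  # compute and pin the chunks now; values unchanged
total = float(arr.sum().compute())   # a later compute() reuses the persisted chunks
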
@@ -1769,7 +1794,7 @@ def _client_sync(self, func: Callable, **kwargs: Any) -> Any:
@xgboost_model_doc(
"""Implementation of the Scikit-Learn API for XGBoost.""", ["estimators", "model"]
)
class DaskXGBRegressor(DaskScikitLearnBase, XGBRegressorBase):
class DaskXGBRegressor(XGBRegressorBase, DaskScikitLearnBase):
"""dummy doc string to workaround pylint, replaced by the decorator."""

async def _fit_async(
@@ -1859,7 +1884,7 @@ def fit(
"Implementation of the scikit-learn API for XGBoost classification.",
["estimators", "model"],
)
class DaskXGBClassifier(DaskScikitLearnBase, XGBClassifierBase):
class DaskXGBClassifier(XGBClassifierBase, DaskScikitLearnBase):
# pylint: disable=missing-class-docstring
async def _fit_async(
self,
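Both class-definition changes above swap the base-class order, which changes Python's method resolution order: attributes defined on both bases now resolve through XGBRegressorBase or XGBClassifierBase before DaskScikitLearnBase. A minimal illustration with hypothetical stand-in classes:

class DaskBase:
    def predict(self) -> str:
        return "dask"


class SklearnBase:
    def predict(self) -> str:
        return "sklearn"


class OldOrder(DaskBase, SklearnBase):
    pass


class NewOrder(SklearnBase, DaskBase):
    pass


print(OldOrder().predict())  # "dask": the first listed base wins under the old order
print(NewOrder().predict())  # "sklearn": after the swap the other base resolves first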