Skip to content

Commit b245b90

Browse files
authored
[enc][dask] Avoid to_backend call. (#11665)
To make sure the data is partitioned as required.
1 parent 12df100 commit b245b90

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

python-package/xgboost/testing/dask.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -348,11 +348,6 @@ def run(DMatrixT: Type[dxgb.DaskDMatrix]) -> None:
348348
da.from_array(y, chunks=(y.shape[0] // 8,)).persist(workers=to),
349349
)
350350

351-
if device == "cuda":
352-
denc = denc.to_backend("cudf")
353-
dreenc = dreenc.to_backend("cudf")
354-
dy = dy.to_backend("cupy")
355-
356351
Xy = create_dmatrix(DMatrixT, client, denc, dy, enable_categorical=True)
357352
Xy_valid = create_dmatrix(
358353
DMatrixT, client, dreenc, dy, enable_categorical=True, ref=Xy

tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343

4444
try:
4545
import cudf
46+
import dask
4647
import dask.dataframe as dd
4748
from dask import __version__ as dask_version
4849
from dask import array as da
@@ -596,7 +597,13 @@ def test_categorical(local_cuda_client: Client) -> None:
596597

597598
@pytest.mark.skipif(**tm.no_dask_cudf())
598599
def test_recode(local_cuda_client: Client) -> None:
599-
run_recode(local_cuda_client, "cuda")
600+
with dask.config.set(
601+
{
602+
"array.backend": "cupy",
603+
"dataframe.backend": "cudf",
604+
}
605+
):
606+
run_recode(local_cuda_client, "cuda")
600607

601608

602609
@pytest.mark.skipif(**tm.no_cupy())

0 commit comments

Comments
 (0)