|
| 1 | +# pylint: disable=invalid-name |
1 | 2 | """Tests for dask shared by different test modules.""" |
2 | 3 |
|
3 | | -from typing import Any, List, Literal, Tuple, cast |
| 4 | +from typing import Any, List, Literal, Tuple, Type, cast |
4 | 5 |
|
5 | 6 | import numpy as np |
6 | 7 | import pandas as pd |
7 | 8 | from dask import array as da |
8 | 9 | from dask import dataframe as dd |
9 | | -from distributed import Client, get_worker, wait |
| 10 | +from distributed import Client, get_worker |
10 | 11 | from packaging.version import parse as parse_version |
11 | 12 | from sklearn.datasets import make_classification |
12 | 13 |
|
|
17 | 18 |
|
18 | 19 | from .. import dask as dxgb |
19 | 20 | from .._typing import EvalsLog |
20 | | -from ..dask import _DASK_VERSION, _get_rabit_args |
| 21 | +from ..dask import _get_rabit_args |
| 22 | +from ..dask.utils import _DASK_VERSION |
21 | 23 | from .data import make_batches |
22 | 24 | from .data import make_categorical as make_cat_local |
23 | 25 | from .ordinal import make_recoded |
@@ -325,61 +327,77 @@ def pack(**kwargs: Any) -> dd.DataFrame: |
325 | 327 | # pylint: disable=too-many-locals |
326 | 328 | def run_recode(client: Client, device: Device) -> None: |
327 | 329 | """Run re-coding test with the Dask interface.""" |
328 | | - enc, reenc, y, _, _ = make_recoded(device, n_features=96) |
329 | | - workers = get_client_workers(client) |
330 | | - denc, dreenc, dy = ( |
331 | | - dd.from_pandas(enc, npartitions=8).persist(workers=workers), |
332 | | - dd.from_pandas(reenc, npartitions=8).persist(workers=workers), |
333 | | - da.from_array(y, chunks=(y.shape[0] // 8,)).persist(workers=workers), |
334 | | - ) |
335 | 330 |
|
336 | | - wait([denc, dreenc, dy]) |
| 331 | + def create_dmatrix( |
| 332 | + DMatrixT: Type[dxgb.DaskDMatrix], *args: Any, **kwargs: Any |
| 333 | + ) -> dxgb.DaskDMatrix: |
| 334 | + if DMatrixT is dxgb.DaskQuantileDMatrix: |
| 335 | + ref = kwargs.pop("ref", None) |
| 336 | + return DMatrixT(*args, ref=ref, **kwargs) |
337 | 337 |
|
338 | | - if device == "cuda": |
339 | | - denc = denc.to_backend("cudf") |
340 | | - dreenc = dreenc.to_backend("cudf") |
341 | | - dy = dy.to_backend("cupy") |
| 338 | + kwargs.pop("ref", None) |
| 339 | + return DMatrixT(*args, **kwargs) |
342 | 340 |
|
343 | | - Xy = dxgb.DaskQuantileDMatrix(client, denc, dy, enable_categorical=True) |
344 | | - Xy_valid = dxgb.DaskQuantileDMatrix( |
345 | | - client, dreenc, dy, enable_categorical=True, ref=Xy |
346 | | - ) |
347 | | - # Base model |
348 | | - results = dxgb.train(client, {"device": device}, Xy, evals=[(Xy_valid, "Valid")]) |
| 341 | + def run(DMatrixT: Type[dxgb.DaskDMatrix]) -> None: |
| 342 | + enc, reenc, y, _, _ = make_recoded(device, n_features=96) |
| 343 | + to = get_client_workers(client) |
349 | 344 |
|
350 | | - # Training continuation |
351 | | - Xy = dxgb.DaskQuantileDMatrix(client, denc, dy, enable_categorical=True) |
352 | | - Xy_valid = dxgb.DaskQuantileDMatrix( |
353 | | - client, dreenc, dy, enable_categorical=True, ref=Xy |
354 | | - ) |
355 | | - results_1 = dxgb.train( |
356 | | - client, |
357 | | - {"device": device}, |
358 | | - Xy, |
359 | | - evals=[(Xy_valid, "Valid")], |
360 | | - xgb_model=results["booster"], |
361 | | - ) |
| 345 | + denc, dreenc, dy = ( |
| 346 | + dd.from_pandas(enc, npartitions=8).persist(workers=to), |
| 347 | + dd.from_pandas(reenc, npartitions=8).persist(workers=to), |
| 348 | + da.from_array(y, chunks=(y.shape[0] // 8,)).persist(workers=to), |
| 349 | + ) |
362 | 350 |
|
363 | | - # Reversed training continuation |
364 | | - Xy = dxgb.DaskQuantileDMatrix(client, dreenc, dy, enable_categorical=True) |
365 | | - Xy_valid = dxgb.DaskQuantileDMatrix( |
366 | | - client, denc, dy, enable_categorical=True, ref=Xy |
367 | | - ) |
368 | | - results_2 = dxgb.train( |
369 | | - client, |
370 | | - {"device": device}, |
371 | | - Xy, |
372 | | - evals=[(Xy_valid, "Valid")], |
373 | | - xgb_model=results["booster"], |
374 | | - ) |
375 | | - np.testing.assert_allclose( |
376 | | - results_1["history"]["Valid"]["rmse"], results_2["history"]["Valid"]["rmse"] |
377 | | - ) |
| 351 | + if device == "cuda": |
| 352 | + denc = denc.to_backend("cudf") |
| 353 | + dreenc = dreenc.to_backend("cudf") |
| 354 | + dy = dy.to_backend("cupy") |
| 355 | + |
| 356 | + Xy = create_dmatrix(DMatrixT, client, denc, dy, enable_categorical=True) |
| 357 | + Xy_valid = create_dmatrix( |
| 358 | + DMatrixT, client, dreenc, dy, enable_categorical=True, ref=Xy |
| 359 | + ) |
| 360 | + # Base model |
| 361 | + results = dxgb.train( |
| 362 | + client, {"device": device}, Xy, evals=[(Xy_valid, "Valid")] |
| 363 | + ) |
| 364 | + |
| 365 | + # Training continuation |
| 366 | + Xy = create_dmatrix(DMatrixT, client, denc, dy, enable_categorical=True) |
| 367 | + Xy_valid = create_dmatrix( |
| 368 | + DMatrixT, client, dreenc, dy, enable_categorical=True, ref=Xy |
| 369 | + ) |
| 370 | + results_1 = dxgb.train( |
| 371 | + client, |
| 372 | + {"device": device}, |
| 373 | + Xy, |
| 374 | + evals=[(Xy_valid, "Valid")], |
| 375 | + xgb_model=results["booster"], |
| 376 | + ) |
| 377 | + |
| 378 | + # Reversed training continuation |
| 379 | + Xy = create_dmatrix(DMatrixT, client, dreenc, dy, enable_categorical=True) |
| 380 | + Xy_valid = create_dmatrix( |
| 381 | + DMatrixT, client, denc, dy, enable_categorical=True, ref=Xy |
| 382 | + ) |
| 383 | + results_2 = dxgb.train( |
| 384 | + client, |
| 385 | + {"device": device}, |
| 386 | + Xy, |
| 387 | + evals=[(Xy_valid, "Valid")], |
| 388 | + xgb_model=results["booster"], |
| 389 | + ) |
| 390 | + np.testing.assert_allclose( |
| 391 | + results_1["history"]["Valid"]["rmse"], results_2["history"]["Valid"]["rmse"] |
| 392 | + ) |
| 393 | + |
| 394 | + predt_0 = dxgb.inplace_predict(client, results, denc).compute() |
| 395 | + predt_1 = dxgb.inplace_predict(client, results, dreenc).compute() |
| 396 | + assert_allclose(device, predt_0, predt_1) |
378 | 397 |
|
379 | | - predt_0 = dxgb.inplace_predict(client, results, denc).compute() |
380 | | - predt_1 = dxgb.inplace_predict(client, results, dreenc).compute() |
381 | | - assert_allclose(device, predt_0, predt_1) |
| 398 | + predt_0 = dxgb.predict(client, results, Xy).compute() |
| 399 | + predt_1 = dxgb.predict(client, results, Xy_valid).compute() |
| 400 | + assert_allclose(device, predt_0, predt_1) |
382 | 401 |
|
383 | | - predt_0 = dxgb.predict(client, results, Xy).compute() |
384 | | - predt_1 = dxgb.predict(client, results, Xy_valid).compute() |
385 | | - assert_allclose(device, predt_0, predt_1) |
| 402 | + for DMatrixT in [dxgb.DaskDMatrix, dxgb.DaskQuantileDMatrix]: |
| 403 | + run(DMatrixT) |
0 commit comments