Skip to content

Commit d517ba1

Browse files
authored
[tests][dask] Add voting_parallel algorithm in tests (fixes #3834) (#4088)
* include voting_parallel tree_learner in test_regressor, test_classifier and test_ranker
* remove test for warnings and test for error when using feature_parallel
* use real names for tree_learner in test and include test for aliases; use the error message in the test for error in feature_parallel
* split all tests with rf in test_classifier
* remove task parametrization for tree_learner aliases test; smaller input data for feature_parallel error
* define task for tree_learner aliases
1 parent 46a20ab commit d517ba1

File tree

2 files changed

+40
-27
lines changed

2 files changed

+40
-27
lines changed

Diff for: python-package/lightgbm/dask.py

-6
Original file line number | Diff line number | Diff line change
@@ -309,12 +309,6 @@ def _train(
309309
_log_warning('Parameter tree_learner set to %s, which is not allowed. Using "data" as default' % params['tree_learner'])
310310
params['tree_learner'] = 'data'
311311

312-
if params['tree_learner'] not in {'data', 'data_parallel'}:
313-
_log_warning(
314-
'Support for tree_learner %s in lightgbm.dask is experimental and may break in a future release. \n'
315-
'Use "data" for a stable, well-tested interface.' % params['tree_learner']
316-
)
317-
318312
# Some passed-in parameters can be removed:
319313
# * 'num_machines': set automatically from Dask worker list
320314
# * 'num_threads': overridden to match nthreads on each Dask process

Diff for: tests/python_package_test/test_dask.py

+40-21
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,7 @@
4444
CLIENT_CLOSE_TIMEOUT = 120
4545

4646
tasks = ['binary-classification', 'multiclass-classification', 'regression', 'ranking']
47+
distributed_training_algorithms = ['data', 'voting']
4748
data_output = ['array', 'scipy_csr_matrix', 'dataframe', 'dataframe-with-categorical']
4849
boosting_types = ['gbdt', 'dart', 'goss', 'rf']
4950
group_sizes = [5, 5, 5, 10, 10, 10, 20, 20, 20, 50, 50]
@@ -235,14 +236,16 @@ def _unpickle(filepath, serializer):
235236
@pytest.mark.parametrize('output', data_output)
236237
@pytest.mark.parametrize('task', ['binary-classification', 'multiclass-classification'])
237238
@pytest.mark.parametrize('boosting_type', boosting_types)
238-
def test_classifier(output, task, boosting_type, client):
239+
@pytest.mark.parametrize('tree_learner', distributed_training_algorithms)
240+
def test_classifier(output, task, boosting_type, tree_learner, client):
239241
X, y, w, _, dX, dy, dw, _ = _create_data(
240242
objective=task,
241243
output=output
242244
)
243245

244246
params = {
245247
"boosting_type": boosting_type,
248+
"tree_learner": tree_learner,
246249
"n_estimators": 50,
247250
"num_leaves": 31
248251
}
@@ -273,7 +276,7 @@ def test_classifier(output, task, boosting_type, client):
273276
p2_proba = local_classifier.predict_proba(X)
274277
s2 = local_classifier.score(X, y)
275278

276-
if boosting_type == 'rf' and output == 'dataframe-with-categorical':
279+
if boosting_type == 'rf':
277280
# https://github.com/microsoft/LightGBM/issues/4118
278281
assert_eq(s1, s2, atol=0.01)
279282
assert_eq(p1_proba, p2_proba, atol=0.8)
@@ -448,7 +451,8 @@ def test_training_does_not_fail_on_port_conflicts(client):
448451

449452
@pytest.mark.parametrize('output', data_output)
450453
@pytest.mark.parametrize('boosting_type', boosting_types)
451-
def test_regressor(output, boosting_type, client):
454+
@pytest.mark.parametrize('tree_learner', distributed_training_algorithms)
455+
def test_regressor(output, boosting_type, tree_learner, client):
452456
X, y, w, _, dX, dy, dw, _ = _create_data(
453457
objective='regression',
454458
output=output
@@ -469,7 +473,7 @@ def test_regressor(output, boosting_type, client):
469473
dask_regressor = lgb.DaskLGBMRegressor(
470474
client=client,
471475
time_out=5,
472-
tree='data',
476+
tree=tree_learner,
473477
**params
474478
)
475479
dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw)
@@ -623,7 +627,8 @@ def test_regressor_quantile(output, client, alpha):
623627
@pytest.mark.parametrize('output', ['array', 'dataframe', 'dataframe-with-categorical'])
624628
@pytest.mark.parametrize('group', [None, group_sizes])
625629
@pytest.mark.parametrize('boosting_type', boosting_types)
626-
def test_ranker(output, group, boosting_type, client):
630+
@pytest.mark.parametrize('tree_learner', distributed_training_algorithms)
631+
def test_ranker(output, group, boosting_type, tree_learner, client):
627632
if output == 'dataframe-with-categorical':
628633
X, y, w, g, dX, dy, dw, dg = _create_data(
629634
objective='ranking',
@@ -666,7 +671,7 @@ def test_ranker(output, group, boosting_type, client):
666671
dask_ranker = lgb.DaskLGBMRanker(
667672
client=client,
668673
time_out=5,
669-
tree_learner_type='data_parallel',
674+
tree_learner_type=tree_learner,
670675
**params
671676
)
672677
dask_ranker = dask_ranker.fit(dX, dy, sample_weight=dw, group=dg)
@@ -961,22 +966,36 @@ def test_warns_and_continues_on_unrecognized_tree_learner(client):
961966
client.close(timeout=CLIENT_CLOSE_TIMEOUT)
962967

963968

964-
def test_warns_but_makes_no_changes_for_feature_or_voting_tree_learner(client):
965-
X = da.random.random((1e3, 10))
966-
y = da.random.random((1e3, 1))
967-
for tree_learner in ['feature_parallel', 'voting']:
968-
dask_regressor = lgb.DaskLGBMRegressor(
969-
client=client,
970-
time_out=5,
971-
tree_learner=tree_learner,
972-
n_estimators=1,
973-
num_leaves=2
974-
)
975-
with pytest.warns(UserWarning, match='Support for tree_learner %s in lightgbm' % tree_learner):
976-
dask_regressor = dask_regressor.fit(X, y)
969+
@pytest.mark.parametrize('tree_learner', ['data_parallel', 'voting_parallel'])
970+
def test_training_respects_tree_learner_aliases(tree_learner, client):
971+
task = 'regression'
972+
_, _, _, _, dX, dy, dw, dg = _create_data(objective=task, output='array')
973+
dask_factory = task_to_dask_factory[task]
974+
dask_model = dask_factory(
975+
client=client,
976+
tree_learner=tree_learner,
977+
time_out=5,
978+
n_estimators=10,
979+
num_leaves=15
980+
)
981+
dask_model.fit(dX, dy, sample_weight=dw, group=dg)
977982

978-
assert dask_regressor.fitted_
979-
assert dask_regressor.get_params()['tree_learner'] == tree_learner
983+
assert dask_model.fitted_
984+
assert dask_model.get_params()['tree_learner'] == tree_learner
985+
986+
987+
def test_error_on_feature_parallel_tree_learner(client):
988+
X = da.random.random((100, 10), chunks=(50, 10))
989+
y = da.random.random(100, chunks=50)
990+
dask_regressor = lgb.DaskLGBMRegressor(
991+
client=client,
992+
time_out=5,
993+
tree_learner='feature_parallel',
994+
n_estimators=1,
995+
num_leaves=2
996+
)
997+
with pytest.raises(lgb.basic.LightGBMError, match='Do not support feature parallel in c api'):
998+
dask_regressor = dask_regressor.fit(X, y)
980999

9811000
client.close(timeout=CLIENT_CLOSE_TIMEOUT)
9821001

0 commit comments

Comments
 (0)