@@ -44,6 +44,7 @@
 CLIENT_CLOSE_TIMEOUT = 120
 
 tasks = ['binary-classification', 'multiclass-classification', 'regression', 'ranking']
+distributed_training_algorithms = ['data', 'voting']
 data_output = ['array', 'scipy_csr_matrix', 'dataframe', 'dataframe-with-categorical']
 boosting_types = ['gbdt', 'dart', 'goss', 'rf']
 group_sizes = [5, 5, 5, 10, 10, 10, 20, 20, 20, 50, 50]
@@ -235,14 +236,16 @@ def _unpickle(filepath, serializer):
 @pytest.mark.parametrize('output', data_output)
 @pytest.mark.parametrize('task', ['binary-classification', 'multiclass-classification'])
 @pytest.mark.parametrize('boosting_type', boosting_types)
-def test_classifier(output, task, boosting_type, client):
+@pytest.mark.parametrize('tree_learner', distributed_training_algorithms)
+def test_classifier(output, task, boosting_type, tree_learner, client):
     X, y, w, _, dX, dy, dw, _ = _create_data(
         objective=task,
         output=output
     )
 
     params = {
         "boosting_type": boosting_type,
+        "tree_learner": tree_learner,
         "n_estimators": 50,
         "num_leaves": 31
     }
@@ -273,7 +276,7 @@ def test_classifier(output, task, boosting_type, client):
     p2_proba = local_classifier.predict_proba(X)
     s2 = local_classifier.score(X, y)
 
-    if boosting_type == 'rf' and output == 'dataframe-with-categorical':
+    if boosting_type == 'rf':
         # https://github.com/microsoft/LightGBM/issues/4118
         assert_eq(s1, s2, atol=0.01)
         assert_eq(p1_proba, p2_proba, atol=0.8)
@@ -448,7 +451,8 @@ def test_training_does_not_fail_on_port_conflicts(client):
 
 @pytest.mark.parametrize('output', data_output)
 @pytest.mark.parametrize('boosting_type', boosting_types)
-def test_regressor(output, boosting_type, client):
+@pytest.mark.parametrize('tree_learner', distributed_training_algorithms)
+def test_regressor(output, boosting_type, tree_learner, client):
     X, y, w, _, dX, dy, dw, _ = _create_data(
         objective='regression',
         output=output
@@ -469,7 +473,7 @@ def test_regressor(output, boosting_type, client):
     dask_regressor = lgb.DaskLGBMRegressor(
         client=client,
         time_out=5,
-        tree='data',
+        tree=tree_learner,
         **params
     )
     dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw)
@@ -623,7 +627,8 @@ def test_regressor_quantile(output, client, alpha):
 @pytest.mark.parametrize('output', ['array', 'dataframe', 'dataframe-with-categorical'])
 @pytest.mark.parametrize('group', [None, group_sizes])
 @pytest.mark.parametrize('boosting_type', boosting_types)
-def test_ranker(output, group, boosting_type, client):
+@pytest.mark.parametrize('tree_learner', distributed_training_algorithms)
+def test_ranker(output, group, boosting_type, tree_learner, client):
     if output == 'dataframe-with-categorical':
         X, y, w, g, dX, dy, dw, dg = _create_data(
             objective='ranking',
@@ -666,7 +671,7 @@ def test_ranker(output, group, boosting_type, client):
     dask_ranker = lgb.DaskLGBMRanker(
         client=client,
         time_out=5,
-        tree_learner_type='data_parallel',
+        tree_learner_type=tree_learner,
         **params
     )
     dask_ranker = dask_ranker.fit(dX, dy, sample_weight=dw, group=dg)
@@ -961,22 +966,36 @@ def test_warns_and_continues_on_unrecognized_tree_learner(client):
     client.close(timeout=CLIENT_CLOSE_TIMEOUT)
 
 
-def test_warns_but_makes_no_changes_for_feature_or_voting_tree_learner(client):
-    X = da.random.random((1e3, 10))
-    y = da.random.random((1e3, 1))
-    for tree_learner in ['feature_parallel', 'voting']:
-        dask_regressor = lgb.DaskLGBMRegressor(
-            client=client,
-            time_out=5,
-            tree_learner=tree_learner,
-            n_estimators=1,
-            num_leaves=2
-        )
-        with pytest.warns(UserWarning, match='Support for tree_learner %s in lightgbm' % tree_learner):
-            dask_regressor = dask_regressor.fit(X, y)
+@pytest.mark.parametrize('tree_learner', ['data_parallel', 'voting_parallel'])
+def test_training_respects_tree_learner_aliases(tree_learner, client):
+    task = 'regression'
+    _, _, _, _, dX, dy, dw, dg = _create_data(objective=task, output='array')
+    dask_factory = task_to_dask_factory[task]
+    dask_model = dask_factory(
+        client=client,
+        tree_learner=tree_learner,
+        time_out=5,
+        n_estimators=10,
+        num_leaves=15
+    )
+    dask_model.fit(dX, dy, sample_weight=dw, group=dg)
 
-    assert dask_regressor.fitted_
-    assert dask_regressor.get_params()['tree_learner'] == tree_learner
+    assert dask_model.fitted_
+    assert dask_model.get_params()['tree_learner'] == tree_learner
+
+
+def test_error_on_feature_parallel_tree_learner(client):
+    X = da.random.random((100, 10), chunks=(50, 10))
+    y = da.random.random(100, chunks=50)
+    dask_regressor = lgb.DaskLGBMRegressor(
+        client=client,
+        time_out=5,
+        tree_learner='feature_parallel',
+        n_estimators=1,
+        num_leaves=2
+    )
+    with pytest.raises(lgb.basic.LightGBMError, match='Do not support feature parallel in c api'):
+        dask_regressor = dask_regressor.fit(X, y)
 
     client.close(timeout=CLIENT_CLOSE_TIMEOUT)
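A minimal standalone sketch (not part of this diff) of the behaviour the new tests pin down: the Dask estimators accept LightGBM's `tree_learner` values and their aliases (`data`/`data_parallel`, `voting`/`voting_parallel`), preserve the alias exactly as given in `get_params()`, and raise `LightGBMError` for `feature_parallel`. The cluster size, array shapes, and the choice of `voting_parallel` below are illustrative assumptions, not values taken from the diff.

```python
# Hypothetical usage sketch; assumes lightgbm>=3.2 with its Dask support,
# plus dask and distributed, are installed.
import dask.array as da
import lightgbm as lgb
from distributed import Client, LocalCluster

cluster = LocalCluster(n_workers=2)  # assumption: a small local cluster
client = Client(cluster)

# random regression-style data, partitioned across two chunks
X = da.random.random((1_000, 10), chunks=(500, 10))
y = da.random.random(1_000, chunks=500)

reg = lgb.DaskLGBMRegressor(
    client=client,
    tree_learner='voting_parallel',  # alias of 'voting'; 'feature_parallel' would raise
    n_estimators=10,
    num_leaves=15
)
reg.fit(X, y)

# the alias is kept as given, which is what
# test_training_respects_tree_learner_aliases asserts
assert reg.get_params()['tree_learner'] == 'voting_parallel'

client.close()
cluster.close()
```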