@@ -1327,9 +1327,7 @@ def test_unsupported_params(self):
1327
1327
SparkXGBClassifier (evals_result = {})
1328
1328
1329
1329
1330
- LTRData = namedtuple (
1331
- "LTRData" , ("df_train" , "df_test" , "df_train_1" )
1332
- )
1330
+ LTRData = namedtuple ("LTRData" , ("df_train" , "df_test" , "df_train_1" ))
1333
1331
1334
1332
1335
1333
@pytest .fixture
@@ -1358,22 +1356,22 @@ def ltr_data(spark: SparkSession) -> Generator[LTRData, None, None]:
1358
1356
["features" , "qid" , "expected_prediction" ],
1359
1357
)
1360
1358
ranker_df_train_1 = spark .createDataFrame (
1361
- [
1362
- (Vectors .sparse (3 , {1 : 1.0 , 2 : 5.5 }), 0 , 9 ),
1363
- (Vectors .sparse (3 , {1 : 6.0 , 2 : 7.5 }), 1 , 9 ),
1364
- (Vectors .sparse (3 , {1 : 8.0 , 2 : 9.5 }), 2 , 9 ),
1365
- (Vectors .dense (1.0 , 2.0 , 3.0 ), 0 , 8 ),
1366
- (Vectors .dense (4.0 , 5.0 , 6.0 ), 1 , 8 ),
1367
- (Vectors .dense (9.0 , 4.0 , 8.0 ), 2 , 8 ),
1368
- (Vectors .sparse (3 , {1 : 1.0 , 2 : 5.5 }), 0 , 7 ),
1369
- (Vectors .sparse (3 , {1 : 6.0 , 2 : 7.5 }), 1 , 7 ),
1370
- (Vectors .sparse (3 , {1 : 8.0 , 2 : 9.5 }), 2 , 7 ),
1371
- (Vectors .dense (1.0 , 2.0 , 3.0 ), 0 , 6 ),
1372
- (Vectors .dense (4.0 , 5.0 , 6.0 ), 1 , 6 ),
1373
- (Vectors .dense (9.0 , 4.0 , 8.0 ), 2 , 6 ),
1374
- ]
1375
- * 4 ,
1376
- ["features" , "label" , "qid" ],
1359
+ [
1360
+ (Vectors .sparse (3 , {1 : 1.0 , 2 : 5.5 }), 0 , 9 ),
1361
+ (Vectors .sparse (3 , {1 : 6.0 , 2 : 7.5 }), 1 , 9 ),
1362
+ (Vectors .sparse (3 , {1 : 8.0 , 2 : 9.5 }), 2 , 9 ),
1363
+ (Vectors .dense (1.0 , 2.0 , 3.0 ), 0 , 8 ),
1364
+ (Vectors .dense (4.0 , 5.0 , 6.0 ), 1 , 8 ),
1365
+ (Vectors .dense (9.0 , 4.0 , 8.0 ), 2 , 8 ),
1366
+ (Vectors .sparse (3 , {1 : 1.0 , 2 : 5.5 }), 0 , 7 ),
1367
+ (Vectors .sparse (3 , {1 : 6.0 , 2 : 7.5 }), 1 , 7 ),
1368
+ (Vectors .sparse (3 , {1 : 8.0 , 2 : 9.5 }), 2 , 7 ),
1369
+ (Vectors .dense (1.0 , 2.0 , 3.0 ), 0 , 6 ),
1370
+ (Vectors .dense (4.0 , 5.0 , 6.0 ), 1 , 6 ),
1371
+ (Vectors .dense (9.0 , 4.0 , 8.0 ), 2 , 6 ),
1372
+ ]
1373
+ * 4 ,
1374
+ ["features" , "label" , "qid" ],
1377
1375
)
1378
1376
yield LTRData (ranker_df_train , ranker_df_test , ranker_df_train_1 )
1379
1377
0 commit comments