Skip to content

Commit b851034

Browse files
committed
Fix spark test.
1 parent 1efab87 commit b851034

File tree

1 file changed

+34
-6
lines changed

1 file changed

+34
-6
lines changed

tests/test_distributed/test_with_spark/test_spark_local.py

+34-6
Original file line numberDiff line numberDiff line change
@@ -1359,14 +1359,42 @@ def ltr_data(spark: SparkSession) -> Generator[LTRData, None, None]:
13591359
],
13601360
["features", "label", "qid"],
13611361
)
1362+
X_train = np.array(
1363+
[
1364+
[1.0, 2.0, 3.0],
1365+
[4.0, 5.0, 6.0],
1366+
[9.0, 4.0, 8.0],
1367+
[np.NaN, 1.0, 5.5],
1368+
[np.NaN, 6.0, 7.5],
1369+
[np.NaN, 8.0, 9.5],
1370+
]
1371+
)
1372+
qid_train = np.array([0, 0, 0, 1, 1, 1])
1373+
y_train = np.array([0, 1, 2, 0, 1, 2])
1374+
1375+
X_test = np.array(
1376+
[
1377+
[1.5, 2.0, 3.0],
1378+
[4.5, 5.0, 6.0],
1379+
[9.0, 4.5, 8.0],
1380+
[np.NaN, 1.0, 6.0],
1381+
[np.NaN, 6.0, 7.0],
1382+
[np.NaN, 8.0, 10.5],
1383+
]
1384+
)
1385+
1386+
ltr = xgb.XGBRanker(tree_method="approx", objective="rank:pairwise")
1387+
ltr.fit(X_train, y_train, qid=qid_train)
1388+
predt = ltr.predict(X_test)
1389+
13621390
ranker_df_test = spark.createDataFrame(
13631391
[
1364-
(Vectors.dense(1.5, 2.0, 3.0), 0, -1.75218),
1365-
(Vectors.dense(4.5, 5.0, 6.0), 0, -0.34192949533462524),
1366-
(Vectors.dense(9.0, 4.5, 8.0), 0, 1.7251298427581787),
1367-
(Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, -1.7521828413009644),
1368-
(Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, -1.0988065004348755),
1369-
(Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, 1.632674217224121),
1392+
(Vectors.dense(1.5, 2.0, 3.0), 0, float(predt[0])),
1393+
(Vectors.dense(4.5, 5.0, 6.0), 0, float(predt[1])),
1394+
(Vectors.dense(9.0, 4.5, 8.0), 0, float(predt[2])),
1395+
(Vectors.sparse(3, {1: 1.0, 2: 6.0}), 1, float(predt[3])),
1396+
(Vectors.sparse(3, {1: 6.0, 2: 7.0}), 1, float(predt[4])),
1397+
(Vectors.sparse(3, {1: 8.0, 2: 10.5}), 1, float(predt[5])),
13701398
],
13711399
["features", "qid", "expected_prediction"],
13721400
)

0 commit comments

Comments
 (0)