Skip to content

Commit a591369

Browse files
authored
Add Python type hint for tests and demos. (#11795)
1 parent f567d94 commit a591369

24 files changed

+211
-169
lines changed

demo/guide-python/cross_validation.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""
55

66
import os
7+
from typing import Any, Dict, Tuple
78

89
import numpy as np
910

@@ -54,7 +55,9 @@
5455
# used to return the preprocessed training, test data, and parameter
5556
# we can use this to do weight rescale, etc.
5657
# as an example, we try to set scale_pos_weight
57-
def fpreproc(dtrain, dtest, param):
58+
def fpreproc(
59+
dtrain: xgb.DMatrix, dtest: xgb.DMatrix, param: Any
60+
) -> Tuple[xgb.DMatrix, xgb.DMatrix, Dict[str, Any]]:
5861
label = dtrain.get_label()
5962
ratio = float(np.sum(label == 0)) / np.sum(label == 1)
6063
param["scale_pos_weight"] = ratio
@@ -74,15 +77,15 @@ def fpreproc(dtrain, dtest, param):
7477
print("running cross validation, with customized loss function")
7578

7679

77-
def logregobj(preds, dtrain):
80+
def logregobj(preds: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[np.ndarray, np.ndarray]:
7881
labels = dtrain.get_label()
7982
preds = 1.0 / (1.0 + np.exp(-preds))
8083
grad = preds - labels
8184
hess = preds * (1.0 - preds)
8285
return grad, hess
8386

8487

85-
def evalerror(preds, dtrain):
88+
def evalerror(preds: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]:
8689
labels = dtrain.get_label()
8790
preds = 1.0 / (1.0 + np.exp(-preds))
8891
return "error", float(sum(labels != (preds > 0.0))) / len(labels)

demo/guide-python/custom_rmsle.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from time import time
1919
from typing import Dict, List, Tuple
2020

21-
import matplotlib
2221
import numpy as np
2322
from matplotlib import pyplot as plt
2423

@@ -136,7 +135,7 @@ def squared_log(predt: np.ndarray,
136135
def rmsle(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]:
137136
''' Root mean squared log error metric.
138137
139-
:math:`\sqrt{\frac{1}{N}[log(pred + 1) - log(label + 1)]^2}`
138+
:math:`\\sqrt{\\frac{1}{N}[log(pred + 1) - log(label + 1)]^2}`
140139
'''
141140
y = dtrain.get_label()
142141
predt[predt < -1] = -1 + 1e-6
@@ -156,11 +155,16 @@ def rmsle(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]:
156155
return results
157156

158157

159-
def plot_history(rmse_evals, rmsle_evals, py_rmsle_evals):
158+
def plot_history(
159+
rmse_evals: Dict[str, Dict],
160+
rmsle_evals: Dict[str, Dict],
161+
py_rmsle_evals: Dict[str, Dict]
162+
) -> None:
160163
fig, axs = plt.subplots(3, 1)
161-
ax0: matplotlib.axes.Axes = axs[0]
162-
ax1: matplotlib.axes.Axes = axs[1]
163-
ax2: matplotlib.axes.Axes = axs[2]
164+
assert isinstance(axs, np.ndarray)
165+
ax0 = axs[0]
166+
ax1 = axs[1]
167+
ax2 = axs[2]
164168

165169
x = np.arange(0, kBoostRound, 1)
166170

@@ -177,7 +181,7 @@ def plot_history(rmse_evals, rmsle_evals, py_rmsle_evals):
177181
ax2.legend()
178182

179183

180-
def main(args):
184+
def main(args: argparse.Namespace) -> None:
181185
dtrain, dtest = generate_data()
182186
rmse_evals = native_rmse(dtrain, dtest)
183187
rmsle_evals = native_rmsle(dtrain, dtest)

demo/guide-python/evals_result.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
======================================================
44
"""
55
import os
6+
from typing import Any, Dict
67

78
import xgboost as xgb
89

@@ -24,7 +25,7 @@
2425
num_round = 2
2526
watchlist = [(dtest, "eval"), (dtrain, "train")]
2627

27-
evals_result = {}
28+
evals_result: Dict[str, Any] = {}
2829
bst = xgb.train(param, dtrain, num_round, watchlist, evals_result=evals_result)
2930

3031
print("Access logloss metric directly from evals_result:")

demo/guide-python/predict_first_ntree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
test = os.path.join(CURRENT_DIR, "../data/agaricus.txt.test")
1515

1616

17-
def native_interface():
17+
def native_interface() -> None:
1818
# load data and do training
1919
dtrain = xgb.DMatrix(train + "?format=libsvm")
2020
dtest = xgb.DMatrix(test + "?format=libsvm")
@@ -34,7 +34,7 @@ def native_interface():
3434
print("error of ypred2=%f" % (np.sum((ypred2 > 0.5) != label) / float(len(label))))
3535

3636

37-
def sklearn_interface():
37+
def sklearn_interface() -> None:
3838
X_train, y_train = load_svmlight_file(train)
3939
X_test, y_test = load_svmlight_file(test)
4040
clf = xgb.XGBClassifier(n_estimators=3, max_depth=2, eta=1)

demo/guide-python/spark_estimator_examples.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
@author: Weichen Xu
66
"""
77

8+
import numpy as np
89
import sklearn.datasets
910
from pyspark.ml.evaluation import MulticlassClassificationEvaluator, RegressionEvaluator
1011
from pyspark.ml.linalg import Vectors
11-
from pyspark.sql import SparkSession
12+
from pyspark.sql import DataFrame, SparkSession
1213
from pyspark.sql.functions import rand
1314
from sklearn.model_selection import train_test_split
1415

@@ -17,7 +18,7 @@
1718
spark = SparkSession.builder.master("local[*]").getOrCreate()
1819

1920

20-
def create_spark_df(X, y):
21+
def create_spark_df(X: np.ndarray, y: np.ndarray) -> DataFrame:
2122
return spark.createDataFrame(
2223
spark.sparkContext.parallelize(
2324
[(Vectors.dense(features), float(label)) for features, label in zip(X, y)]

ops/script/lint_python.py

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -111,39 +111,16 @@ class LintersPaths:
111111
"tests/python/test_model_io.py",
112112
"tests/python/test_ordinal.py",
113113
"tests/python/test_interaction_constraints.py",
114-
"tests/python-gpu/test_gpu_callbacks.py",
115-
"tests/python-gpu/test_gpu_data_iterator.py",
116-
"tests/python-gpu/test_gpu_ordinal.py",
117-
"tests/python-gpu/load_pickle.py",
118-
"tests/python-gpu/test_gpu_training_continuation.py",
119-
"tests/python-gpu/test_gpu_plotting.py",
120-
"tests/python-gpu/test_gpu_parse_tree.py",
114+
"tests/python-gpu/",
121115
"tests/test_distributed/test_federated/",
122116
"tests/test_distributed/test_gpu_federated/",
123-
"tests/test_distributed/test_with_dask/test_ranking.py",
124-
"tests/test_distributed/test_with_dask/test_external_memory.py",
117+
"tests/test_distributed/test_with_dask/",
125118
"tests/test_distributed/test_with_spark/test_data.py",
126119
"tests/test_distributed/test_gpu_with_spark/test_data.py",
127120
"tests/test_distributed/test_gpu_with_dask/",
128121
# demo
129122
"demo/dask/",
130-
"demo/guide-python/custom_softmax.py",
131-
"demo/guide-python/external_memory.py",
132-
"demo/guide-python/distributed_extmem_basic.py",
133-
"demo/guide-python/sklearn_examples.py",
134-
"demo/guide-python/continuation.py",
135-
"demo/guide-python/callbacks.py",
136-
"demo/guide-python/update_process.py",
137-
"demo/guide-python/cat_in_the_dat.py",
138-
"demo/guide-python/categorical.py",
139-
"demo/guide-python/cat_pipeline.py",
140-
"demo/guide-python/feature_weights.py",
141-
"demo/guide-python/model_parser.py",
142-
"demo/guide-python/individual_trees.py",
143-
"demo/guide-python/quantile_regression.py",
144-
"demo/guide-python/quantile_data_iterator.py",
145-
"demo/guide-python/multioutput_regression.py",
146-
"demo/guide-python/learning_to_rank.py",
123+
"demo/guide-python/",
147124
"demo/aft_survival/aft_survival_viz_demo.py",
148125
# CI
149126
"ops/",

python-package/xgboost/sklearn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def __call__(
103103
self,
104104
y_true: ArrayLike,
105105
y_pred: ArrayLike,
106-
sample_weight: Optional[ArrayLike],
106+
sample_weight: Optional[ArrayLike] = None,
107107
) -> Tuple[ArrayLike, ArrayLike]: ...
108108

109109

tests/python-gpu/conftest.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
1+
from typing import Any, List
2+
13
import pytest
24

35
from xgboost import testing as tm
46

57

6-
def has_rmm():
8+
def has_rmm() -> bool:
79
return tm.no_rmm()["condition"]
810

911

1012
@pytest.fixture(scope="session", autouse=True)
11-
def setup_rmm_pool(request, pytestconfig):
13+
def setup_rmm_pool(request: Any, pytestconfig: pytest.Config) -> None:
1214
tm.setup_rmm_pool(request, pytestconfig)
1315

1416

@@ -18,7 +20,9 @@ def pytest_addoption(parser: pytest.Parser) -> None:
1820
)
1921

2022

21-
def pytest_collection_modifyitems(config, items):
23+
def pytest_collection_modifyitems(
24+
config: pytest.Config, items: List[pytest.Item]
25+
) -> None:
2226
if config.getoption("--use-rmm-pool"):
2327
blocklist = [
2428
"python-gpu/test_gpu_demos.py::test_dask_training",

tests/python-gpu/test_device_quantile_dmatrix.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def test_ref_dmatrix(self) -> None:
203203
strategies.fractions(0, 0.99),
204204
)
205205
@settings(print_blob=True, deadline=None)
206-
def test_to_csr(self, n_samples, n_features, sparsity) -> None:
206+
def test_to_csr(self, n_samples: int, n_features: int, sparsity: float) -> None:
207207
import cupy as cp
208208

209209
X, y = tm.make_sparse_regression(n_samples, n_features, sparsity, False)

0 commit comments

Comments
 (0)