Commit 6b01796

[mt] Small improvements for the demo.
- Use named subplots.
Parent: c5ba21d

File tree

1 file changed: +29 −19 lines


demo/guide-python/multioutput_regression.py

Lines changed: 29 additions & 19 deletions
@@ -16,23 +16,22 @@
 """

 import argparse
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple

+import matplotlib
 import numpy as np
 from matplotlib import pyplot as plt

 import xgboost as xgb


-def plot_predt(y: np.ndarray, y_predt: np.ndarray, name: str) -> None:
+def plot_predt(
+    y: np.ndarray, y_predt: np.ndarray, name: str, ax: matplotlib.axes.Axes
+) -> None:
     s = 25
-    plt.scatter(y[:, 0], y[:, 1], c="navy", s=s, edgecolor="black", label="data")
-    plt.scatter(
-        y_predt[:, 0], y_predt[:, 1], c="cornflowerblue", s=s, edgecolor="black"
-    )
-    plt.xlim([-1, 2])
-    plt.ylim([-1, 2])
-    plt.show()
+    ax.scatter(y[:, 0], y[:, 1], c="navy", s=s, edgecolor="black", label=name)
+    ax.scatter(y_predt[:, 0], y_predt[:, 1], c="cornflowerblue", s=s, edgecolor="black")
+    ax.legend()


 def gen_circle() -> Tuple[np.ndarray, np.ndarray]:
@@ -46,7 +45,9 @@ def gen_circle() -> Tuple[np.ndarray, np.ndarray]:
     return X, y


-def rmse_model(plot_result: bool, strategy: str) -> None:
+def rmse_model(
+    plot_result: bool, strategy: str, ax: Optional[matplotlib.axes.Axes]
+) -> None:
     """Draw a circle with 2-dim coordinate as target variables."""
     X, y = gen_circle()
     # Train a regressor on it
@@ -61,11 +62,13 @@ def rmse_model(plot_result: bool, strategy: str) -> None:
     reg.fit(X, y, eval_set=[(X, y)])

     y_predt = reg.predict(X)
-    if plot_result:
-        plot_predt(y, y_predt, "multi")
+    if ax:
+        plot_predt(y, y_predt, f"RMSE-{strategy}", ax)


-def custom_rmse_model(plot_result: bool, strategy: str) -> None:
+def custom_rmse_model(
+    plot_result: bool, strategy: str, ax: Optional[matplotlib.axes.Axes]
+) -> None:
     """Train using Python implementation of Squared Error."""

     def gradient(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray:
@@ -111,8 +114,8 @@ def rmse(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]:
     )

     y_predt = booster.inplace_predict(X)
-    if plot_result:
-        plot_predt(y, y_predt, "multi")
+    if ax:
+        plot_predt(y, y_predt, f"PyRMSE-{strategy}", ax)

     np.testing.assert_allclose(
         results["Train"]["rmse"], results["Train"]["PyRMSE"], rtol=1e-2
@@ -123,17 +126,24 @@ def rmse(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]:
     parser = argparse.ArgumentParser()
     parser.add_argument("--plot", choices=[0, 1], type=int, default=1)
     args = parser.parse_args()
+    if args.plot == 1:
+        _, axs = plt.subplots(2, 2)
+    else:
+        axs = np.zeros(shape=(2, 2))
+    assert isinstance(axs, np.ndarray)

     # Train with builtin RMSE objective
     # - One model per output.
-    rmse_model(args.plot == 1, "one_output_per_tree")
+    rmse_model(args.plot == 1, "one_output_per_tree", axs[0, 0])
     # - One model for all outputs, this is still working in progress, many features are
     #   missing.
-    rmse_model(args.plot == 1, "multi_output_tree")
+    rmse_model(args.plot == 1, "multi_output_tree", axs[0, 1])

     # Train with custom objective.
     # - One model per output.
-    custom_rmse_model(args.plot == 1, "one_output_per_tree")
+    custom_rmse_model(args.plot == 1, "one_output_per_tree", axs[1, 0])
     # - One model for all outputs, this is still working in progress, many features are
     #   missing.
-    custom_rmse_model(args.plot == 1, "multi_output_tree")
+    custom_rmse_model(args.plot == 1, "multi_output_tree", axs[1, 1])
+    if args.plot == 1:
+        plt.show()
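For context, the pattern this commit moves the demo to: each plotting helper draws on a named Axes handed to it from a plt.subplots(2, 2) grid, and the figure is shown once at the end, instead of each helper calling plt.show() on the implicit global figure. The sketch below is a minimal, self-contained illustration of that pattern, not code from the demo; the helper name plot_points and the random data are invented for the example.

    # Minimal sketch of the named-subplot pattern; plot_points and the data are illustrative.
    import matplotlib
    import numpy as np
    from matplotlib import pyplot as plt


    def plot_points(y: np.ndarray, name: str, ax: matplotlib.axes.Axes) -> None:
        # Draw on the Axes that was passed in rather than on the implicit current figure.
        ax.scatter(y[:, 0], y[:, 1], s=25, edgecolor="black", label=name)
        ax.legend()


    rng = np.random.default_rng(0)
    _, axs = plt.subplots(2, 2)  # one named Axes per panel
    for ax, name in zip(axs.ravel(), ["a", "b", "c", "d"]):
        plot_points(rng.normal(size=(32, 2)), name, ax)
    plt.show()  # a single show() renders all four panels together

Typing the parameter as Optional[matplotlib.axes.Axes], as the demo does, additionally lets the caller skip plotting by passing a falsy placeholder, which is why the diff substitutes np.zeros(shape=(2, 2)) for the Axes grid when --plot is 0.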
