1616"""
1717
1818import argparse
19- from typing import Dict , List , Tuple
19+ from typing import Dict , List , Optional , Tuple
2020
21+ import matplotlib
2122import numpy as np
2223from matplotlib import pyplot as plt
2324
2425import xgboost as xgb
2526
2627
27- def plot_predt (y : np .ndarray , y_predt : np .ndarray , name : str ) -> None :
28+ def plot_predt (
29+ y : np .ndarray , y_predt : np .ndarray , name : str , ax : matplotlib .axes .Axes
30+ ) -> None :
2831 s = 25
29- plt .scatter (y [:, 0 ], y [:, 1 ], c = "navy" , s = s , edgecolor = "black" , label = "data" )
30- plt .scatter (
31- y_predt [:, 0 ], y_predt [:, 1 ], c = "cornflowerblue" , s = s , edgecolor = "black"
32- )
33- plt .xlim ([- 1 , 2 ])
34- plt .ylim ([- 1 , 2 ])
35- plt .show ()
32+ ax .scatter (y [:, 0 ], y [:, 1 ], c = "navy" , s = s , edgecolor = "black" , label = name )
33+ ax .scatter (y_predt [:, 0 ], y_predt [:, 1 ], c = "cornflowerblue" , s = s , edgecolor = "black" )
34+ ax .legend ()
3635
3736
3837def gen_circle () -> Tuple [np .ndarray , np .ndarray ]:
@@ -46,7 +45,9 @@ def gen_circle() -> Tuple[np.ndarray, np.ndarray]:
4645 return X , y
4746
4847
49- def rmse_model (plot_result : bool , strategy : str ) -> None :
48+ def rmse_model (
49+ plot_result : bool , strategy : str , ax : Optional [matplotlib .axes .Axes ]
50+ ) -> None :
5051 """Draw a circle with 2-dim coordinate as target variables."""
5152 X , y = gen_circle ()
5253 # Train a regressor on it
@@ -61,11 +62,13 @@ def rmse_model(plot_result: bool, strategy: str) -> None:
6162 reg .fit (X , y , eval_set = [(X , y )])
6263
6364 y_predt = reg .predict (X )
64- if plot_result :
65- plot_predt (y , y_predt , "multi" )
65+ if ax :
66+ plot_predt (y , y_predt , f"RMSE- { strategy } " , ax )
6667
6768
68- def custom_rmse_model (plot_result : bool , strategy : str ) -> None :
69+ def custom_rmse_model (
70+ plot_result : bool , strategy : str , ax : Optional [matplotlib .axes .Axes ]
71+ ) -> None :
6972 """Train using Python implementation of Squared Error."""
7073
7174 def gradient (predt : np .ndarray , dtrain : xgb .DMatrix ) -> np .ndarray :
@@ -111,8 +114,8 @@ def rmse(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]:
111114 )
112115
113116 y_predt = booster .inplace_predict (X )
114- if plot_result :
115- plot_predt (y , y_predt , "multi" )
117+ if ax :
118+ plot_predt (y , y_predt , f"PyRMSE- { strategy } " , ax )
116119
117120 np .testing .assert_allclose (
118121 results ["Train" ]["rmse" ], results ["Train" ]["PyRMSE" ], rtol = 1e-2
@@ -123,17 +126,24 @@ def rmse(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]:
123126 parser = argparse .ArgumentParser ()
124127 parser .add_argument ("--plot" , choices = [0 , 1 ], type = int , default = 1 )
125128 args = parser .parse_args ()
129+ if args .plot == 1 :
130+ _ , axs = plt .subplots (2 , 2 )
131+ else :
132+ axs = np .zeros (shape = (2 , 2 ))
133+ assert isinstance (axs , np .ndarray )
126134
127135 # Train with builtin RMSE objective
128136 # - One model per output.
129- rmse_model (args .plot == 1 , "one_output_per_tree" )
137+ rmse_model (args .plot == 1 , "one_output_per_tree" , axs [ 0 , 0 ] )
130138 # - One model for all outputs, this is still working in progress, many features are
131139 # missing.
132- rmse_model (args .plot == 1 , "multi_output_tree" )
140+ rmse_model (args .plot == 1 , "multi_output_tree" , axs [ 0 , 1 ] )
133141
134142 # Train with custom objective.
135143 # - One model per output.
136- custom_rmse_model (args .plot == 1 , "one_output_per_tree" )
144+ custom_rmse_model (args .plot == 1 , "one_output_per_tree" , axs [ 1 , 0 ] )
137145 # - One model for all outputs, this is still working in progress, many features are
138146 # missing.
139- custom_rmse_model (args .plot == 1 , "multi_output_tree" )
147+ custom_rmse_model (args .plot == 1 , "multi_output_tree" , axs [1 , 1 ])
148+ if args .plot == 1 :
149+ plt .show ()
0 commit comments