@@ -54,9 +54,11 @@ def __init__(
5454 n_jobs : int | None = 1 ,
5555 verbose : int = 0 ,
5656 random_state : int | None = None ,
57- m_type : Literal ["mean" , "median" ] = "median " ,
58- var_type : Literal ["var" , "std" , "range" , "mae" , "mse" ] = "var " ,
57+ m_type : Literal ["mean" , "median" ] = "mean " ,
58+ var_type : Literal ["var" , "std" , "range" , "mae" , "mse" ] = "std " ,
5959 target_transformer : BaseEstimator | Any | None = None ,
60+ recursive : bool = True ,
61+ recursive_strict : bool = False ,
6062 ) -> None :
6163 """Estimator that estimates the distribution by simply using multiple estimators
6264 with different `t`.
@@ -99,6 +101,11 @@ def __init__(
99101 target_transformer : BaseEstimator | Any | None, optional
100102 The transformer to use for transforming the target, by default None
101103 If `None`, no `TransformedTargetRegressor` is used.
104+ recursive : bool, optional
105+ Whether to recursively patch the estimator, by default True
106+ recursive_strict : bool, optional
107+ Whether to recursively patch the estimator's attributes,
108+ lists, tuples, sets, and frozensets as well, by default False
102109
103110 Raises
104111 ------
@@ -118,7 +125,9 @@ def __init__(
118125 self .m_type = m_type
119126 self .var_type = var_type
120127 self .target_transformer = target_transformer
121- self .random = np .random .RandomState (random_state )
128+ self .recursive = recursive
129+ self .recursive_strict = recursive_strict
130+ self .random_state_ = np .random .RandomState (random_state )
122131
123132 def fit (self , X : Any , y : Any , ** fit_params : Any ) -> Self :
124133 """Fit each estimator with different `t`.
@@ -149,13 +158,17 @@ def fit(self, X: Any, y: Any, **fit_params: Any) -> Self:
149158 self .estimator ,
150159 AsymmetricLoss (self .loss , t = t ),
151160 target_transformer = self .target_transformer ,
161+ recursive = self .recursive ,
162+ recursive_strict = self .recursive_strict ,
152163 )
153164 for t in self .ts_
154165 ]
155- if self .random is not None :
166+ if self .random_state_ is not None :
167+ # set different random state for each estimator
168+ # otherwise, estimators will be identical
156169 for estimator in estimators_ :
157170 _recursively_set_random_state (
158- estimator , self .random .randint (0 , np .iinfo (np .int32 ).max )
171+ estimator , self .random_state_ .randint (0 , np .iinfo (np .int32 ).max )
159172 )
160173 parallel_result = Parallel (n_jobs = self .n_jobs , verbose = self .verbose )(
161174 [delayed (estimator .fit )(X , y , ** fit_params ) for estimator in estimators_ ]
0 commit comments