@@ -54,9 +54,11 @@ def __init__(
54
54
n_jobs : int | None = 1 ,
55
55
verbose : int = 0 ,
56
56
random_state : int | None = None ,
57
- m_type : Literal ["mean" , "median" ] = "median " ,
58
- var_type : Literal ["var" , "std" , "range" , "mae" , "mse" ] = "var " ,
57
+ m_type : Literal ["mean" , "median" ] = "mean " ,
58
+ var_type : Literal ["var" , "std" , "range" , "mae" , "mse" ] = "std " ,
59
59
target_transformer : BaseEstimator | Any | None = None ,
60
+ recursive : bool = True ,
61
+ recursive_strict : bool = False ,
60
62
) -> None :
61
63
"""Estimator that estimates the distribution by simply using multiple estimators
62
64
with different `t`.
@@ -99,6 +101,11 @@ def __init__(
99
101
target_transformer : BaseEstimator | Any | None, optional
100
102
The transformer to use for transforming the target, by default None
101
103
If `None`, no `TransformedTargetRegressor` is used.
104
+ recursive : bool, optional
105
+ Whether to recursively patch the estimator, by default True
106
+ recursive_strict : bool, optional
107
+ Whether to recursively patch the estimator's attributes,
108
+ lists, tuples, sets, and frozensets as well, by default False
102
109
103
110
Raises
104
111
------
@@ -118,7 +125,9 @@ def __init__(
118
125
self .m_type = m_type
119
126
self .var_type = var_type
120
127
self .target_transformer = target_transformer
121
- self .random = np .random .RandomState (random_state )
128
+ self .recursive = recursive
129
+ self .recursive_strict = recursive_strict
130
+ self .random_state_ = np .random .RandomState (random_state )
122
131
123
132
def fit (self , X : Any , y : Any , ** fit_params : Any ) -> Self :
124
133
"""Fit each estimator with different `t`.
@@ -149,13 +158,17 @@ def fit(self, X: Any, y: Any, **fit_params: Any) -> Self:
149
158
self .estimator ,
150
159
AsymmetricLoss (self .loss , t = t ),
151
160
target_transformer = self .target_transformer ,
161
+ recursive = self .recursive ,
162
+ recursive_strict = self .recursive_strict ,
152
163
)
153
164
for t in self .ts_
154
165
]
155
- if self .random is not None :
166
+ if self .random_state_ is not None :
167
+ # set different random state for each estimator
168
+ # otherwise, estimators will be identical
156
169
for estimator in estimators_ :
157
170
_recursively_set_random_state (
158
- estimator , self .random .randint (0 , np .iinfo (np .int32 ).max )
171
+ estimator , self .random_state_ .randint (0 , np .iinfo (np .int32 ).max )
159
172
)
160
173
parallel_result = Parallel (n_jobs = self .n_jobs , verbose = self .verbose )(
161
174
[delayed (estimator .fit )(X , y , ** fit_params ) for estimator in estimators_ ]
0 commit comments