Commit 24e90ac

feat: fix recursive_strict behavior for apply_custom_loss and patch, set default var_type to "std", add more parameters to VarianceEstimator (#105)
1 parent 61187f4 commit 24e90ac

5 files changed: +198 −107 lines


poetry.lock

Lines changed: 35 additions & 72 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ lightgbm = "^4.0.0"
 catboost = "^1.2"
 pyhumps = "^3.8.0"
 attrs = "^23.1.0"
-lightgbm-callbacks = "^0.1.1"
+scikit-learn = "^1.3.2"
 
 [tool.poetry.group.dev.dependencies]
 pre-commit = ">=3"
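
Assuming this project manages dependencies with Poetry (as pyproject.toml and poetry.lock indicate), the equivalent CLI steps for this swap would be roughly the following sketch; running them would also regenerate poetry.lock, matching the lockfile churn above:

poetry remove lightgbm-callbacks
poetry add scikit-learn@^1.3.2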

src/boost_loss/regression/sklearn.py

Lines changed: 18 additions & 5 deletions
@@ -54,9 +54,11 @@ def __init__(
         n_jobs: int | None = 1,
         verbose: int = 0,
         random_state: int | None = None,
-        m_type: Literal["mean", "median"] = "median",
-        var_type: Literal["var", "std", "range", "mae", "mse"] = "var",
+        m_type: Literal["mean", "median"] = "mean",
+        var_type: Literal["var", "std", "range", "mae", "mse"] = "std",
         target_transformer: BaseEstimator | Any | None = None,
+        recursive: bool = True,
+        recursive_strict: bool = False,
     ) -> None:
         """Estimator that estimates the distribution by simply using multiple estimators
         with different `t`.
@@ -99,6 +101,11 @@ def __init__(
         target_transformer : BaseEstimator | Any | None, optional
             The transformer to use for transforming the target, by default None
             If `None`, no `TransformedTargetRegressor` is used.
+        recursive : bool, optional
+            Whether to recursively patch the estimator, by default True
+        recursive_strict : bool, optional
+            Whether to recursively patch the estimator's attributes,
+            lists, tuples, sets, and frozensets as well, by default False
 
         Raises
         ------
@@ -118,7 +125,9 @@ def __init__(
         self.m_type = m_type
         self.var_type = var_type
         self.target_transformer = target_transformer
-        self.random = np.random.RandomState(random_state)
+        self.recursive = recursive
+        self.recursive_strict = recursive_strict
+        self.random_state_ = np.random.RandomState(random_state)
 
     def fit(self, X: Any, y: Any, **fit_params: Any) -> Self:
         """Fit each estimator with different `t`.
@@ -149,13 +158,17 @@ def fit(self, X: Any, y: Any, **fit_params: Any) -> Self:
                 self.estimator,
                 AsymmetricLoss(self.loss, t=t),
                 target_transformer=self.target_transformer,
+                recursive=self.recursive,
+                recursive_strict=self.recursive_strict,
             )
             for t in self.ts_
         ]
-        if self.random is not None:
+        if self.random_state_ is not None:
+            # set different random state for each estimator
+            # otherwise, estimators will be identical
             for estimator in estimators_:
                 _recursively_set_random_state(
-                    estimator, self.random.randint(0, np.iinfo(np.int32).max)
+                    estimator, self.random_state_.randint(0, np.iinfo(np.int32).max)
                 )
         parallel_result = Parallel(n_jobs=self.n_jobs, verbose=self.verbose)(
             [delayed(estimator.fit)(X, y, **fit_params) for estimator in estimators_]
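
The comment added in this hunk captures the point of the change: the shared `RandomState` is consumed once per ensemble member, so every estimator gets a distinct 32-bit seed; seeding them all identically would make their fits, and hence the variance estimate, degenerate. A minimal numpy-only illustration of the pattern:

# Minimal illustration of why each estimator needs its own seed.
import numpy as np

rng = np.random.RandomState(42)

# Same seed for every member -> identical draws (identical estimators):
same = [np.random.RandomState(0).rand(3) for _ in range(2)]
assert np.allclose(same[0], same[1])

# One fresh 32-bit seed per member, drawn from the shared RandomState:
seeds = [rng.randint(0, np.iinfo(np.int32).max) for _ in range(2)]
distinct = [np.random.RandomState(s).rand(3) for s in seeds]
assert not np.allclose(distinct[0], distinct[1])  # diverse ensemble members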
