Skip to content

Commit

Permalink
feat: fix `recursive_strict` behavior for `apply_custom_loss` and `patch`, set default `var_type` to `"std"`, add more parameters to `VarianceEstimator` (#105)
Browse files Browse the repository at this point in the history
  • Loading branch information
34j authored Nov 4, 2023
1 parent 61187f4 commit 24e90ac
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 107 deletions.
107 changes: 35 additions & 72 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ lightgbm = "^4.0.0"
catboost = "^1.2"
pyhumps = "^3.8.0"
attrs = "^23.1.0"
lightgbm-callbacks = "^0.1.1"
scikit-learn = "^1.3.2"

[tool.poetry.group.dev.dependencies]
pre-commit = ">=3"
Expand Down
23 changes: 18 additions & 5 deletions src/boost_loss/regression/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,11 @@ def __init__(
n_jobs: int | None = 1,
verbose: int = 0,
random_state: int | None = None,
m_type: Literal["mean", "median"] = "median",
var_type: Literal["var", "std", "range", "mae", "mse"] = "var",
m_type: Literal["mean", "median"] = "mean",
var_type: Literal["var", "std", "range", "mae", "mse"] = "std",
target_transformer: BaseEstimator | Any | None = None,
recursive: bool = True,
recursive_strict: bool = False,
) -> None:
"""Estimator that estimates the distribution by simply using multiple estimators
with different `t`.
Expand Down Expand Up @@ -99,6 +101,11 @@ def __init__(
target_transformer : BaseEstimator | Any | None, optional
The transformer to use for transforming the target, by default None
If `None`, no `TransformedTargetRegressor` is used.
recursive : bool, optional
Whether to recursively patch the estimator, by default True
recursive_strict : bool, optional
Whether to recursively patch the estimator's attributes,
lists, tuples, sets, and frozensets as well, by default False
Raises
------
Expand All @@ -118,7 +125,9 @@ def __init__(
self.m_type = m_type
self.var_type = var_type
self.target_transformer = target_transformer
self.random = np.random.RandomState(random_state)
self.recursive = recursive
self.recursive_strict = recursive_strict
self.random_state_ = np.random.RandomState(random_state)

def fit(self, X: Any, y: Any, **fit_params: Any) -> Self:
"""Fit each estimator with different `t`.
Expand Down Expand Up @@ -149,13 +158,17 @@ def fit(self, X: Any, y: Any, **fit_params: Any) -> Self:
self.estimator,
AsymmetricLoss(self.loss, t=t),
target_transformer=self.target_transformer,
recursive=self.recursive,
recursive_strict=self.recursive_strict,
)
for t in self.ts_
]
if self.random is not None:
if self.random_state_ is not None:
# set different random state for each estimator
# otherwise, estimators will be identical
for estimator in estimators_:
_recursively_set_random_state(
estimator, self.random.randint(0, np.iinfo(np.int32).max)
estimator, self.random_state_.randint(0, np.iinfo(np.int32).max)
)
parallel_result = Parallel(n_jobs=self.n_jobs, verbose=self.verbose)(
[delayed(estimator.fit)(X, y, **fit_params) for estimator in estimators_]
Expand Down
Loading

0 comments on commit 24e90ac

Please sign in to comment.