diff --git a/poetry.lock b/poetry.lock index 34a214e..3fb9461 100644 --- a/poetry.lock +++ b/poetry.lock @@ -582,13 +582,13 @@ i18n = ["Babel (>=2.7)"] [[package]] name = "joblib" -version = "1.2.0" +version = "1.3.2" description = "Lightweight pipelining with Python functions" optional = false python-versions = ">=3.7" files = [ - {file = "joblib-1.2.0-py3-none-any.whl", hash = "sha256:091138ed78f800342968c523bdde947e7a305b8594b910a0fea2ab83c3c6d385"}, - {file = "joblib-1.2.0.tar.gz", hash = "sha256:e1cee4a79e4af22881164f218d4311f60074197fb707e082e803b61f6d137018"}, + {file = "joblib-1.3.2-py3-none-any.whl", hash = "sha256:ef4331c65f239985f3f2220ecc87db222f08fd22097a3dd5698f693875f8cbb9"}, + {file = "joblib-1.3.2.tar.gz", hash = "sha256:92f865e621e17784e7955080b6d042489e3b8e294949cc44c6eac304f59772b1"}, ] [[package]] @@ -691,23 +691,6 @@ dask = ["dask[array,dataframe,distributed] (>=2.0.0)", "pandas (>=0.24.0)"] pandas = ["pandas (>=0.24.0)"] scikit-learn = ["scikit-learn (!=0.22.0)"] -[[package]] -name = "lightgbm-callbacks" -version = "0.1.5" -description = "A collection of LightGBM callbacks." -optional = false -python-versions = ">=3.8,<4.0" -files = [ - {file = "lightgbm_callbacks-0.1.5-py3-none-any.whl", hash = "sha256:e177e6b7d75f688adf6a349b493c5b373b59da708c468f6fe687d1b71b8cf0ce"}, - {file = "lightgbm_callbacks-0.1.5.tar.gz", hash = "sha256:762f53b7eed2c426fbaffb6eaf802f98c6e7125759b638714ec3b8a2c5c6a86f"}, -] - -[package.dependencies] -lightgbm = ">=4.0.0,<5.0.0" -scikit-learn = ">=1.3.1,<2.0.0" -tqdm = ">=4.65.0,<5.0.0" -typing-extensions = ">=4.5.0,<5.0.0" - [[package]] name = "markdown-it-py" version = "3.0.0" @@ -1405,37 +1388,37 @@ use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] [[package]] name = "scikit-learn" -version = "1.3.1" +version = "1.3.2" description = "A set of python modules for machine learning and data mining" optional = false python-versions = ">=3.8" files = [ - {file = "scikit-learn-1.3.1.tar.gz", hash = "sha256:1a231cced3ee3fa04756b4a7ab532dc9417acd581a330adff5f2c01ac2831fcf"}, - {file = "scikit_learn-1.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3153612ff8d36fa4e35ef8b897167119213698ea78f3fd130b4068e6f8d2da5a"}, - {file = "scikit_learn-1.3.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:6bb9490fdb8e7e00f1354621689187bef3cab289c9b869688f805bf724434755"}, - {file = "scikit_learn-1.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a7135a03af71138669f19bc96e7d0cc8081aed4b3565cc3b131135d65fc642ba"}, - {file = "scikit_learn-1.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d8dee8c1f40eeba49a85fe378bdf70a07bb64aba1a08fda1e0f48d27edfc3e6"}, - {file = "scikit_learn-1.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:4d379f2b34096105a96bd857b88601dffe7389bd55750f6f29aaa37bc6272eb5"}, - {file = "scikit_learn-1.3.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:14e8775eba072ab10866a7e0596bc9906873e22c4c370a651223372eb62de180"}, - {file = "scikit_learn-1.3.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:58b0c2490eff8355dc26e884487bf8edaccf2ba48d09b194fb2f3a026dd64f9d"}, - {file = "scikit_learn-1.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f66eddfda9d45dd6cadcd706b65669ce1df84b8549875691b1f403730bdef217"}, - {file = "scikit_learn-1.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6448c37741145b241eeac617028ba6ec2119e1339b1385c9720dae31367f2be"}, - {file = "scikit_learn-1.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:c413c2c850241998168bbb3bd1bb59ff03b1195a53864f0b80ab092071af6028"}, - {file = "scikit_learn-1.3.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ef540e09873e31569bc8b02c8a9f745ee04d8e1263255a15c9969f6f5caa627f"}, - {file = "scikit_learn-1.3.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:9147a3a4df4d401e618713880be023e36109c85d8569b3bf5377e6cd3fecdeac"}, - {file = "scikit_learn-1.3.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2cd3634695ad192bf71645702b3df498bd1e246fc2d529effdb45a06ab028b4"}, - {file = "scikit_learn-1.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c275a06c5190c5ce00af0acbb61c06374087949f643ef32d355ece12c4db043"}, - {file = "scikit_learn-1.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:0e1aa8f206d0de814b81b41d60c1ce31f7f2c7354597af38fae46d9c47c45122"}, - {file = "scikit_learn-1.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:52b77cc08bd555969ec5150788ed50276f5ef83abb72e6f469c5b91a0009bbca"}, - {file = "scikit_learn-1.3.1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:a683394bc3f80b7c312c27f9b14ebea7766b1f0a34faf1a2e9158d80e860ec26"}, - {file = "scikit_learn-1.3.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a15d964d9eb181c79c190d3dbc2fff7338786bf017e9039571418a1d53dab236"}, - {file = "scikit_learn-1.3.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ce9233cdf0cdcf0858a5849d306490bf6de71fa7603a3835124e386e62f2311"}, - {file = "scikit_learn-1.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:1ec668ce003a5b3d12d020d2cde0abd64b262ac5f098b5c84cf9657deb9996a8"}, - {file = "scikit_learn-1.3.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ccbbedae99325628c1d1cbe3916b7ef58a1ce949672d8d39c8b190e10219fd32"}, - {file = "scikit_learn-1.3.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:845f81c7ceb4ea6bac64ab1c9f2ce8bef0a84d0f21f3bece2126adcc213dfecd"}, - {file = "scikit_learn-1.3.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8454d57a22d856f1fbf3091bd86f9ebd4bff89088819886dc0c72f47a6c30652"}, - {file = "scikit_learn-1.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d993fb70a1d78c9798b8f2f28705bfbfcd546b661f9e2e67aa85f81052b9c53"}, - {file = "scikit_learn-1.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:66f7bb1fec37d65f4ef85953e1df5d3c98a0f0141d394dcdaead5a6de9170347"}, + {file = "scikit-learn-1.3.2.tar.gz", hash = "sha256:a2f54c76accc15a34bfb9066e6c7a56c1e7235dda5762b990792330b52ccfb05"}, + {file = "scikit_learn-1.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e326c0eb5cf4d6ba40f93776a20e9a7a69524c4db0757e7ce24ba222471ee8a1"}, + {file = "scikit_learn-1.3.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:535805c2a01ccb40ca4ab7d081d771aea67e535153e35a1fd99418fcedd1648a"}, + {file = "scikit_learn-1.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1215e5e58e9880b554b01187b8c9390bf4dc4692eedeaf542d3273f4785e342c"}, + {file = "scikit_learn-1.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ee107923a623b9f517754ea2f69ea3b62fc898a3641766cb7deb2f2ce450161"}, + {file = "scikit_learn-1.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:35a22e8015048c628ad099da9df5ab3004cdbf81edc75b396fd0cff8699ac58c"}, + {file = "scikit_learn-1.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6fb6bc98f234fda43163ddbe36df8bcde1d13ee176c6dc9b92bb7d3fc842eb66"}, + {file = "scikit_learn-1.3.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:18424efee518a1cde7b0b53a422cde2f6625197de6af36da0b57ec502f126157"}, + {file = "scikit_learn-1.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3271552a5eb16f208a6f7f617b8cc6d1f137b52c8a1ef8edf547db0259b2c9fb"}, + {file = "scikit_learn-1.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc4144a5004a676d5022b798d9e573b05139e77f271253a4703eed295bde0433"}, + {file = "scikit_learn-1.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:67f37d708f042a9b8d59551cf94d30431e01374e00dc2645fa186059c6c5d78b"}, + {file = "scikit_learn-1.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8db94cd8a2e038b37a80a04df8783e09caac77cbe052146432e67800e430c028"}, + {file = "scikit_learn-1.3.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:61a6efd384258789aa89415a410dcdb39a50e19d3d8410bd29be365bcdd512d5"}, + {file = "scikit_learn-1.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb06f8dce3f5ddc5dee1715a9b9f19f20d295bed8e3cd4fa51e1d050347de525"}, + {file = "scikit_learn-1.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5b2de18d86f630d68fe1f87af690d451388bb186480afc719e5f770590c2ef6c"}, + {file = "scikit_learn-1.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:0402638c9a7c219ee52c94cbebc8fcb5eb9fe9c773717965c1f4185588ad3107"}, + {file = "scikit_learn-1.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:a19f90f95ba93c1a7f7924906d0576a84da7f3b2282ac3bfb7a08a32801add93"}, + {file = "scikit_learn-1.3.2-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:b8692e395a03a60cd927125eef3a8e3424d86dde9b2370d544f0ea35f78a8073"}, + {file = "scikit_learn-1.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:15e1e94cc23d04d39da797ee34236ce2375ddea158b10bee3c343647d615581d"}, + {file = "scikit_learn-1.3.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:785a2213086b7b1abf037aeadbbd6d67159feb3e30263434139c98425e3dcfcf"}, + {file = "scikit_learn-1.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:64381066f8aa63c2710e6b56edc9f0894cc7bf59bd71b8ce5613a4559b6145e0"}, + {file = "scikit_learn-1.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6c43290337f7a4b969d207e620658372ba3c1ffb611f8bc2b6f031dc5c6d1d03"}, + {file = "scikit_learn-1.3.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:dc9002fc200bed597d5d34e90c752b74df516d592db162f756cc52836b38fe0e"}, + {file = "scikit_learn-1.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d08ada33e955c54355d909b9c06a4789a729977f165b8bae6f225ff0a60ec4a"}, + {file = "scikit_learn-1.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:763f0ae4b79b0ff9cca0bf3716bcc9915bdacff3cebea15ec79652d1cc4fa5c9"}, + {file = "scikit_learn-1.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:ed932ea780517b00dae7431e031faae6b49b20eb6950918eb83bd043237950e0"}, ] [package.dependencies] @@ -1713,13 +1696,13 @@ doc = ["reno", "sphinx", "tornado (>=4.5)"] [[package]] name = "threadpoolctl" -version = "3.1.0" +version = "3.2.0" description = "threadpoolctl" optional = false -python-versions = ">=3.6" +python-versions = ">=3.8" files = [ - {file = "threadpoolctl-3.1.0-py3-none-any.whl", hash = "sha256:8b99adda265feb6773280df41eece7b2e6561b772d21ffd52e372f999024907b"}, - {file = "threadpoolctl-3.1.0.tar.gz", hash = "sha256:a335baacfaa4400ae1f0d8e3a58d6674d2f8828e3716bb2802c44955ad391380"}, + {file = "threadpoolctl-3.2.0-py3-none-any.whl", hash = "sha256:2b7818516e423bdaebb97c723f86a7c6b0a83d3f3b0970328d66f4d9104dc032"}, + {file = "threadpoolctl-3.2.0.tar.gz", hash = "sha256:c96a0ba3bdddeaca37dc4cc7344aafad41cdb8c313f74fdfe387a867bba93355"}, ] [[package]] @@ -1773,26 +1756,6 @@ typing-extensions = "*" [package.extras] opt-einsum = ["opt-einsum (>=3.3)"] -[[package]] -name = "tqdm" -version = "4.65.0" -description = "Fast, Extensible Progress Meter" -optional = false -python-versions = ">=3.7" -files = [ - {file = "tqdm-4.65.0-py3-none-any.whl", hash = "sha256:c4f53a17fe37e132815abceec022631be8ffe1b9381c2e6e30aa70edc99e9671"}, - {file = "tqdm-4.65.0.tar.gz", hash = "sha256:1871fb68a86b8fb3b59ca4cdd3dcccbc7e6d613eeed31f4c332531977b89beb5"}, -] - -[package.dependencies] -colorama = {version = "*", markers = "platform_system == \"Windows\""} - -[package.extras] -dev = ["py-make (>=0.1.0)", "twine", "wheel"] -notebook = ["ipywidgets (>=6)"] -slack = ["slack-sdk"] -telegram = ["requests"] - [[package]] name = "typing-extensions" version = "4.6.2" @@ -1897,4 +1860,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "c6bfcc2b39d1c83d8f98052080ac13c0846c68579b3afde7c1045a2f1f87f4e7" +content-hash = "160a9f8a69ceafa066ea9f0fa2175a764094247e147fef4940b071b785e6358c" diff --git a/pyproject.toml b/pyproject.toml index 0f34283..60130a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ lightgbm = "^4.0.0" catboost = "^1.2" pyhumps = "^3.8.0" attrs = "^23.1.0" -lightgbm-callbacks = "^0.1.1" +scikit-learn = "^1.3.2" [tool.poetry.group.dev.dependencies] pre-commit = ">=3" diff --git a/src/boost_loss/regression/sklearn.py b/src/boost_loss/regression/sklearn.py index 3fa95eb..8918967 100644 --- a/src/boost_loss/regression/sklearn.py +++ b/src/boost_loss/regression/sklearn.py @@ -54,9 +54,11 @@ def __init__( n_jobs: int | None = 1, verbose: int = 0, random_state: int | None = None, - m_type: Literal["mean", "median"] = "median", - var_type: Literal["var", "std", "range", "mae", "mse"] = "var", + m_type: Literal["mean", "median"] = "mean", + var_type: Literal["var", "std", "range", "mae", "mse"] = "std", target_transformer: BaseEstimator | Any | None = None, + recursive: bool = True, + recursive_strict: bool = False, ) -> None: """Estimator that estimates the distribution by simply using multiple estimators with different `t`. @@ -99,6 +101,11 @@ def __init__( target_transformer : BaseEstimator | Any | None, optional The transformer to use for transforming the target, by default None If `None`, no `TransformedTargetRegressor` is used. + recursive : bool, optional + Whether to recursively patch the estimator, by default True + recursive_strict : bool, optional + Whether to recursively patch the estimator's attributes, + lists, tuples, sets, and frozensets as well, by default False Raises ------ @@ -118,7 +125,9 @@ def __init__( self.m_type = m_type self.var_type = var_type self.target_transformer = target_transformer - self.random = np.random.RandomState(random_state) + self.recursive = recursive + self.recursive_strict = recursive_strict + self.random_state_ = np.random.RandomState(random_state) def fit(self, X: Any, y: Any, **fit_params: Any) -> Self: """Fit each estimator with different `t`. @@ -149,13 +158,17 @@ def fit(self, X: Any, y: Any, **fit_params: Any) -> Self: self.estimator, AsymmetricLoss(self.loss, t=t), target_transformer=self.target_transformer, + recursive=self.recursive, + recursive_strict=self.recursive_strict, ) for t in self.ts_ ] - if self.random is not None: + if self.random_state_ is not None: + # set different random state for each estimator + # otherwise, estimators will be identical for estimator in estimators_: _recursively_set_random_state( - estimator, self.random.randint(0, np.iinfo(np.int32).max) + estimator, self.random_state_.randint(0, np.iinfo(np.int32).max) ) parallel_result = Parallel(n_jobs=self.n_jobs, verbose=self.verbose)( [delayed(estimator.fit)(X, y, **fit_params) for estimator in estimators_] diff --git a/src/boost_loss/sklearn.py b/src/boost_loss/sklearn.py index d2bdcf0..38f3db1 100644 --- a/src/boost_loss/sklearn.py +++ b/src/boost_loss/sklearn.py @@ -25,6 +25,7 @@ def apply_custom_loss( copy: bool = ..., target_transformer: None = ..., recursive: bool = ..., + recursive_strict: bool = ..., ) -> TEstimator: ... @@ -37,6 +38,7 @@ def apply_custom_loss( copy: bool = ..., target_transformer: BaseEstimator = ..., recursive: bool = ..., + recursive_strict: bool = ..., ) -> TransformedTargetRegressor: ... @@ -67,6 +69,10 @@ def apply_custom_loss( recursive : bool, optional Whether to recursively search for estimators inside the estimator and apply the custom loss to all of them, by default True + recursive_strict : bool, optional + Whether to recursively search for estimators inside the estimator's + attributes, lists, tuples, sets, and frozensets as well, + by default False Returns ------- @@ -96,10 +102,37 @@ def fit(X: Any, y: Any, **fit_params: Any) -> Any: estimator.set_params( **{ key: apply_custom_loss( - value, loss, copy=False, target_transformer=None + value, + loss, + copy=False, + target_transformer=None, + recursive=False, + recursive_strict=False, ) } ) + if recursive_strict: + if hasattr(estimator, "__dict__"): + for _, value in estimator.__dict__.items(): + apply_custom_loss( + value, + loss, + copy=False, + target_transformer=None, + recursive=True, + recursive_strict=True, + ) + elif isinstance(estimator, (list, tuple, set, frozenset)): + # https://github.com/scikit-learn/scikit-learn/blob/364c77e047ca08a95862becf40a04fe9d4cd2c98/sklearn/base.py#L66 + for value in estimator: + apply_custom_loss( + value, + loss, + copy=False, + target_transformer=None, + recursive=True, + recursive_strict=True, + ) if target_transformer is None: return estimator @@ -145,6 +178,25 @@ def predict_std(X: Any, **predict_params: Any) -> NDArray[Any]: return dist.scale setattr(estimator, "predict_std", predict_std) + + original_predict = estimator.predict + + def predict( + X: Any, + *, + return_std: bool = False, + **predict_params: Any, + ) -> NDArray[Any] | tuple[NDArray[Any], NDArray[Any]]: + if return_std: + dist = self.pred_dist(X, **predict_params) + if not isinstance(dist, Normal): + raise NotImplementedError + return dist.mean, dist.scale + else: + return original_predict(X, **predict_params) + + setattr(estimator, "predict", predict) + return estimator @@ -166,33 +218,6 @@ def patch_catboost(estimator: cb.CatBoost) -> cb.CatBoost: """ original_predict = estimator.predict - @functools.wraps(original_predict) - def predict( - data: Any, - prediction_type: Literal[ - "Probability", "Class", "RawFormulaVal", "Exponent", "LogProbability" - ] = "RawFormulaVal", - ntree_start: int = 0, - ntree_end: int = 0, - thread_count: int = -1, - verbose: bool | None = None, - task_type: str = "CPU", - ) -> NDArray[Any]: - prediction = original_predict( - data, - prediction_type, - ntree_start, - ntree_end, - thread_count, - verbose, - task_type, - ) - if prediction.ndim == 2: - return prediction[:, 0] - return prediction - - setattr(estimator, "predict", predict) - self = estimator def predict_var( @@ -229,6 +254,45 @@ def predict_var( ) setattr(estimator, "predict_var", predict_var) + + @functools.wraps(original_predict) + def predict( + data: Any, + prediction_type: Literal[ + "Probability", "Class", "RawFormulaVal", "Exponent", "LogProbability" + ] = "RawFormulaVal", + ntree_start: int = 0, + ntree_end: int = 0, + thread_count: int = -1, + verbose: bool | None = None, + task_type: str = "CPU", + return_std: bool = False, + ) -> NDArray[Any]: + prediction = original_predict( + data, + prediction_type, + ntree_start, + ntree_end, + thread_count, + verbose, + task_type, + ) + if prediction.ndim == 2: + return prediction[:, 0] + if return_std: + # see virtual_ensembles_predict() for details + return prediction, np.sqrt( + predict_var( + data, + ntree_end=ntree_end, # 0 + thread_count=thread_count, # -1 + verbose=verbose, # None + ) + ) + return prediction + + setattr(estimator, "predict", predict) + return estimator @@ -275,7 +339,7 @@ def patch( if recursive and hasattr(estimator, "get_params"): for _, value in estimator.get_params(deep=True).items(): - patch(value, copy=False, recursive=False, recursive_strict=recursive_strict) + patch(value, copy=False, recursive=False, recursive_strict=False) if recursive_strict: if hasattr(estimator, "__dict__"): for _, value in estimator.__dict__.items(): diff --git a/tests/regression/test_sklearn.py b/tests/regression/test_sklearn.py index e69de29..eb28867 100644 --- a/tests/regression/test_sklearn.py +++ b/tests/regression/test_sklearn.py @@ -0,0 +1,51 @@ +import importlib.util + +import pytest + +if importlib.util.find_spec("seaborn") is None: + pytest.skip("Skipping tests that require seaborn", allow_module_level=True) +else: + from pathlib import Path + from typing import Any + + import matplotlib.pyplot as plt + import numpy as np + import pytest + import seaborn as sns + from catboost import CatBoostRegressor + from lightgbm import LGBMRegressor + from pandas import DataFrame + from scipy.stats import norm + from sklearn.datasets import make_regression + from xgboost import XGBRegressor + + from boost_loss.regression import L1Loss + from boost_loss.regression.sklearn import VarianceEstimator + + @pytest.mark.parametrize( + "estimator", + [ + CatBoostRegressor(n_estimators=100), + LGBMRegressor(), + XGBRegressor(base_score=0.5), + ], + ) + def test_normal(estimator: Any) -> None: + X, y = make_regression(n_samples=100, n_features=1, random_state=0) + y = np.random.standard_normal(size=y.shape) + ve = VarianceEstimator(estimator=estimator, loss=L1Loss(), ts=5) + ve.fit(X, y) + Y = ve.predict_raw(X) + Y = DataFrame(Y.T, columns=[f"{t:g}/{norm.ppf(t):g}" for t in ve.ts_]) + sns.violinplot(data=Y) + plt.title( + f"Violin plot of predictions for estimator {estimator.__class__.__name__}" + ) + path = ( + Path(__file__).parent + / ".cache" + / f"test_normal_{estimator.__class__.__name__}.png" + ) + path.parent.mkdir(exist_ok=True) + plt.savefig(path) + plt.close()