Skip to content

Commit 0e9bc36

Browse files
authored
Infer the enable_categorical during model load. (#11816)
1 parent 3a62712 commit 0e9bc36

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

python-package/xgboost/sklearn.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
_py_version,
6565
)
6666
from .data import (
67+
CAT_T,
6768
_is_cudf_df,
6869
_is_cudf_ser,
6970
_is_cupy_alike,
@@ -1133,7 +1134,6 @@ def load_model(self, fname: ModelIn) -> None:
11331134
f"{self._get_type()}, got: {t}"
11341135
)
11351136

1136-
self.feature_types = self.get_booster().feature_types
11371137
self.get_booster().set_attr(scikit_learn=None)
11381138
config = json.loads(self.get_booster().save_config())
11391139
self._load_model_attributes(config)
@@ -1152,6 +1152,9 @@ def _load_model_attributes(self, config: dict) -> None:
11521152
config["learner"]["learner_model_param"]["base_score"]
11531153
)
11541154
self.feature_types = booster.feature_types
1155+
self.enable_categorical = self.feature_types is not None and any(
1156+
ft == CAT_T for ft in self.feature_types
1157+
)
11551158

11561159
if is_classifier(self):
11571160
self.n_classes_ = int(config["learner"]["learner_model_param"]["num_class"])

tests/python/test_with_sklearn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1321,6 +1321,7 @@ def test_categorical():
13211321
reg = xgb.XGBRegressor()
13221322
reg.load_model(path)
13231323
assert reg.feature_types == ft
1324+
assert reg.enable_categorical is True
13241325

13251326
onehot, y = tm.make_categorical(
13261327
n_samples=32, n_features=2, n_categories=3, onehot=True

0 commit comments

Comments
 (0)