diff --git a/coremltools/converters/xgboost/_tree_ensemble.py b/coremltools/converters/xgboost/_tree_ensemble.py
index c29b22e6d..801f5b594 100644
--- a/coremltools/converters/xgboost/_tree_ensemble.py
+++ b/coremltools/converters/xgboost/_tree_ensemble.py
@@ -39,7 +39,10 @@ def recurse_json(
     if "leaf" not in xgb_tree_json:
         branch_mode = "BranchOnValueLessThan"
         split_name = xgb_tree_json["split"]
-        feature_index = split_name if not feature_map else feature_map[split_name]
+        if split_name in feature_map:
+            feature_index = feature_map[split_name]
+        else:
+            feature_index = int(split_name)
 
         # xgboost internally uses float32, but the parsing from json pulls it out
         # as a 64bit double. To trigger the internal float32 detection in the
@@ -157,7 +160,6 @@ def convert_tree_ensemble(
     import json
     import os
 
-    feature_map = None
     if isinstance(
         model, (_xgboost.core.Booster, _xgboost.XGBRegressor, _xgboost.XGBClassifier)
     ):
@@ -202,15 +204,13 @@ def convert_tree_ensemble(
             raise ValueError(
                 "The XGBoost model does not have feature names. They must be provided in convert method."
            )
-        feature_names = model.feature_names
+        # Use user-given feature names if they exist
         if feature_names is None:
             feature_names = model.feature_names
-
+        feature_map = {f: i for i, f in enumerate(feature_names)}
+
         xgb_model_str = model.get_dump(with_stats=True, dump_format="json")
 
-        if model.feature_names:
-            feature_map = {f: i for i, f in enumerate(model.feature_names)}
-
     # Path on the file system where the XGboost model exists.
     elif isinstance(model, str):
         if not os.path.exists(model):
diff --git a/coremltools/test/xgboost_tests/test_boosted_trees_regression.py b/coremltools/test/xgboost_tests/test_boosted_trees_regression.py
index 7e6c7c620..6c4498a05 100644
--- a/coremltools/test/xgboost_tests/test_boosted_trees_regression.py
+++ b/coremltools/test/xgboost_tests/test_boosted_trees_regression.py
@@ -208,3 +208,18 @@ def test_conversion_bad_inputs(self):
         with self.assertRaises(TypeError):
             model = OneHotEncoder()
             spec = xgb_converter.convert(model, "data", "out")
+
+    def test_conversion_model_without_feature_names(self):
+        # Train a model without feature names
+        dtrain = xgboost.DMatrix(
+            self.scikit_data.data,
+            label=self.scikit_data.target
+        )
+        model = xgboost.train({}, dtrain, 1)
+
+        spec = xgb_converter.convert(model, feature_names=self.feature_names).get_spec()
+
+        self.assertEqual(
+            sorted(self.feature_names),
+            sorted(map(lambda x: x.name, spec.description.input))
+        )
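
For reference, a minimal usage sketch of the behavior this patch enables: converting a Booster that was trained on plain numpy arrays (so it carries no feature names of its own) by passing `feature_names` to the converter. This snippet is not part of the change; the synthetic data and the names `"a"`, `"b"`, `"c"` are placeholders.

```python
import numpy as np
import xgboost

from coremltools.converters import xgboost as xgb_converter

# A Booster trained on a plain numpy array has no feature names;
# its tree dump refers to features by integer index.
X = np.random.rand(100, 3)
y = np.random.rand(100)
booster = xgboost.train({}, xgboost.DMatrix(X, label=y), num_boost_round=2)

# With this change, names supplied at conversion time are used to build the
# feature map, and integer split indices are accepted as a fallback.
mlmodel = xgb_converter.convert(booster, feature_names=["a", "b", "c"])

# One spec input per provided feature name.
print([inp.name for inp in mlmodel.get_spec().description.input])
```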