Skip to content

Commit

Permalink
Added support for the 'RandomizedSearchCV' model type
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Dec 23, 2019
1 parent 9f90667 commit f8d60ca
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 9 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ Java library and command-line application for converting [Scikit-Learn](https://
* [`linear_model.TheilSenRegressor`](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.TheilSenRegressor.html)
* Model Selection:
* [`model_selection.GridSearchCV`](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html)
* [`model_selection.RandomizedSearchCV`](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html)
* Multiclass classification:
* [`multiclass.OneVsRestClassifier`](https://scikit-learn.org/stable/modules/generated/sklearn.multiclass.OneVsRestClassifier.html)
* Naive Bayes:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
import org.jpmml.sklearn.PyClassDict;
import sklearn.Estimator;

public class GridSearchCV extends PyClassDict implements Castable {
public class EstimatorSearcher extends PyClassDict implements Castable {

public GridSearchCV(String module, String name){
public EstimatorSearcher(String module, String name){
super(module, name);
}

Expand Down
3 changes: 2 additions & 1 deletion src/main/resources/META-INF/sklearn2pmml.properties
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ sklearn.linear_model.(_sqd_fast|sgd_fast).SquaredHinge = sklearn.linear_model.st
sklearn.linear_model.(_stochastic_gradient|stochastic_gradient).SGDClassifier = sklearn.linear_model.stochastic_gradient.SGDClassifier
sklearn.linear_model.(_stochastic_gradient|stochastic_gradient).SGDRegressor = sklearn.linear_model.LinearRegressor
sklearn.linear_model.(_theil_sen|theil_sen).TheilSenRegressor = sklearn.linear_model.LinearRegressor
sklearn.model_selection._search.GridSearchCV = sklearn.model_selection.GridSearchCV
sklearn.model_selection._search.GridSearchCV = sklearn.model_selection.EstimatorSearcher
sklearn.model_selection._search.RandomizedSearchCV = sklearn.model_selection.EstimatorSearcher
sklearn.multiclass.OneVsRestClassifier =
sklearn.naive_bayes.GaussianNB =
sklearn.neighbors.(_classification|classification).KNeighborsClassifier = sklearn.neighbors.KNeighborsClassifier
Expand Down
8 changes: 4 additions & 4 deletions src/test/resources/csv/NuSVCVersicolor.csv
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ Species
1
1
0
1
0
0
0
0
Expand All @@ -120,9 +120,9 @@ Species
0
1
0
1
0
1
0
0
0
0
1
Expand All @@ -141,7 +141,7 @@ Species
0
0
0
1
0
0
0
0
Expand Down
4 changes: 2 additions & 2 deletions src/test/resources/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from sklearn.impute import MissingIndicator, SimpleImputer
from sklearn.isotonic import IsotonicRegression
from sklearn.linear_model import ARDRegression, BayesianRidge, ElasticNet, ElasticNetCV, HuberRegressor, LarsCV, Lasso, LassoCV, LassoLarsCV, LinearRegression, LogisticRegression, LogisticRegressionCV, OrthogonalMatchingPursuitCV, Ridge, RidgeCV, RidgeClassifier, RidgeClassifierCV, SGDClassifier, SGDRegressor, TheilSenRegressor
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
from sklearn.multiclass import OneVsRestClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
Expand Down Expand Up @@ -344,7 +344,7 @@ def build_versicolor(classifier, name, with_proba = True, **pmml_options):
build_versicolor(SGDClassifier(max_iter = 100, random_state = 13), "SGDVersicolor", with_proba = False)
build_versicolor(SGDClassifier(loss = "log", max_iter = 100, random_state = 13), "SGDLogVersicolor")
build_versicolor(GridSearchCV(SVC(gamma = "auto"), {"C" : [1, 3, 5]}), "SVCVersicolor", with_proba = False)
build_versicolor(NuSVC(gamma = "auto"), "NuSVCVersicolor", with_proba = False)
build_versicolor(RandomizedSearchCV(NuSVC(gamma = "auto"), {"nu" : [0.3, 0.4, 0.5, 0.6]}), "NuSVCVersicolor", with_proba = False)

versicolor_X, versicolor_y = load_versicolor("Versicolor")

Expand Down
Binary file modified src/test/resources/pkl/NuSVCVersicolor.pkl
Binary file not shown.

0 comments on commit f8d60ca

Please sign in to comment.