diff --git a/asreview2-optuna/classifiers.py b/asreview2-optuna/classifiers.py
index 65b6ab2..6a3f16c 100644
--- a/asreview2-optuna/classifiers.py
+++ b/asreview2-optuna/classifiers.py
@@ -3,10 +3,11 @@
 from asreview.models.classifiers import (
     NaiveBayesClassifier,
     LogisticClassifier,
-    SVMClassifier,
     RandomForestClassifier,
 )
 
+from sklearn.svm import LinearSVC
+
 
 def naive_bayes_params(trial: optuna.trial.FrozenTrial):
     # Use logarithmic normal distribution for alpha (alpha effect is non-linear)
@@ -22,17 +23,11 @@ def logistic_params(trial: optuna.trial.FrozenTrial):
 
 def svm_params(trial: optuna.trial.FrozenTrial):
     # Use logarithmic normal distribution for C (C effect is non-linear)
-    C = trial.suggest_float("svm__C", 0.05, 10, log=True)
-
-    # Use categorical for kernel
-    kernel = "linear"#trial.suggest_categorical("svm__kernel", ["linear", "rbf"])
+    C = trial.suggest_float("svm__C", 0.01, 10, log=True)
 
-    # Only set gamma to a value if we use rbf kernel
-    gamma = "scale"
-    if kernel == "rbf":
-        # Use logarithmic normal distribution for gamma (gamma effect is non-linear)
-        gamma = trial.suggest_float("svm__gamma", 1e-4, 10, log=True)
-    return {"C": C, "kernel": kernel, "gamma": gamma}
+    loss = trial.suggest_categorical("svm__loss", ["hinge", "squared_hinge"])
+
+    return {"C": C, "loss": loss}
 
 
 def random_forest_params(trial: optuna.trial.FrozenTrial):
@@ -52,9 +47,41 @@ def random_forest_params(trial: optuna.trial.FrozenTrial):
     }
 
 
+class LinearSVMClassifier(LinearSVC):
+    """Linear support vector machine classifier.
+
+    Based on the sklearn implementation of the linear support vector
+    machine, sklearn.svm.LinearSVC.
+
+    C and loss are explicit __init__ parameters instead of being passed
+    through **kwargs: scikit-learn builds get_params() (and with it
+    clone() and set_params()) from the named __init__ parameters only,
+    so a loss hidden in **kwargs would be dropped when cloning.
+
+    Parameters
+    ----------
+    C : float
+        Regularization parameter passed to LinearSVC.
+    loss : str
+        Loss function passed to LinearSVC ("hinge" or "squared_hinge").
+    **kwargs
+        Remaining keyword arguments, forwarded to LinearSVC unchanged.
+    """
+
+    name = "svm"
+    label = "Support vector machine"
+
+    def __init__(self, C=15.4, loss="squared_hinge", **kwargs):
+        super().__init__(
+            C=C,
+            loss=loss,
+            **kwargs,
+        )
+
+
 classifiers = {
     "nb": NaiveBayesClassifier,
     "log": LogisticClassifier,
-    "svm": SVMClassifier,
+    "svm": LinearSVMClassifier,
     "rf": RandomForestClassifier,
 }
diff --git a/asreview2-optuna/main.py b/asreview2-optuna/main.py
index f95b9eb..5339eef 100644
--- a/asreview2-optuna/main.py
+++ b/asreview2-optuna/main.py
@@ -19,7 +19,7 @@ from feature_extractors import feature_extractor_params, feature_extractors
 
 
 # Study variables
-VERSION = 1
+VERSION = 2
 STUDY_SET = "demo"
 PICKLE_FOLDER_PATH = Path("synergy-dataset", "pickles")
 CLASSIFIER_TYPE = "svm"  # Options: "nb", "log", "svm", "rf"