Skip to content

Commit

Permalink
Prep xgboost study
Browse files Browse the repository at this point in the history
  • Loading branch information
timovdk committed Feb 5, 2025
1 parent cbe49a6 commit a65d435
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 4 deletions.
1 change: 1 addition & 0 deletions asreview2-optuna/ansible/ansible_optuna_playbook.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
name:
- python3-venv
- python3-pip
- libgomp1
state: present

- name: Create Python virtual environment
Expand Down
23 changes: 23 additions & 0 deletions asreview2-optuna/classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
SVM,
)

from xgboost import XGBClassifier


def naive_bayes_params(trial: optuna.trial.FrozenTrial):
# Use logarithmic normal distribution for alpha (alpha effect is non-linear)
Expand Down Expand Up @@ -34,17 +36,38 @@ def random_forest_params(trial: optuna.trial.FrozenTrial):
return {"n_estimators": n_estimators, "max_features": max_features}


def xgboost_params(trial: optuna.trial.FrozenTrial):
    """Suggest the XGBoost hyperparameters for a single Optuna trial.

    Both values are requested with ``trial.suggest_int`` and without
    ``log=True``, so the search range is a plain linear integer range.

    Returns a dict with the keys ``"n_estimators"`` and ``"max_depth"``.
    """
    return {
        # Searched between 50 and 500 (bounds as passed to suggest_int).
        "n_estimators": trial.suggest_int("xgboost__n_estimators", 50, 500),
        # Searched between 2 and 20 (bounds as passed to suggest_int).
        "max_depth": trial.suggest_int("xgboost__max_depth", 2, 20),
    }


# Registry: classifier key -> function that suggests that classifier's
# hyperparameters from an Optuna trial.
classifier_params = {
    "nb": naive_bayes_params,
    "log": logistic_params,
    "svm": svm_params,
    "rf": random_forest_params,
    "xgboost": xgboost_params,
}

class XGBoost(XGBClassifier):
    """Thin ``XGBClassifier`` subclass carrying ``name``/``label`` metadata.

    The constructor forwards every keyword argument unchanged to
    ``XGBClassifier``; no behaviour of the base estimator is overridden.
    """

    # Short identifier; same string as the "xgboost" registry key in this
    # module. NOTE(review): presumably mirrors the `name`/`label` attributes
    # of the other classifiers registered here -- confirm against callers.
    name = "xgboost"
    # Human-readable display name.
    label = "XGBoost"

    def __init__(self, **kwargs) -> None:
        """Create the classifier, passing ``kwargs`` straight to the base."""
        super().__init__(**kwargs)

# Registry: classifier key -> classifier class. The keys are the same ones
# used by `classifier_params`.
classifiers = {
    "nb": NaiveBayes,
    "log": Logistic,
    "svm": SVM,
    "rf": RandomForest,
    "xgboost": XGBoost,
}
6 changes: 3 additions & 3 deletions asreview2-optuna/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
VERSION = 1
METRIC = "ndcg" # Options: "loss", "ndcg"
STUDY_SET = "full"
CLASSIFIER_TYPE = "svm" # Options: "nb", "log", "svm", "rf"
FEATURE_EXTRACTOR_TYPE = "tfidf" # Options: "tfidf", "onehot", "labse", "bge-m3", "stella", "mxbai"
CLASSIFIER_TYPE = "xgboost" # Options: "nb", "log", "svm", "rf", "xgboost"
FEATURE_EXTRACTOR_TYPE = "mxbai" # Options: "tfidf", "onehot", "labse", "bge-m3", "stella", "mxbai"
PICKLE_FOLDER_PATH = Path("synergy-dataset", f"pickles_{FEATURE_EXTRACTOR_TYPE}")
PRE_PROCESSED_FMS = True # False = on the fly
PARALLELIZE_OBJECTIVE = True
Expand Down Expand Up @@ -248,7 +248,7 @@ def download_pickles(report_order):

if __name__ == "__main__":
# list of studies
studies = pd.read_json(f"synergy_studies_{STUDY_SET}.jsonl", lines=True)
studies = pd.read_json(f"synergy_studies_{STUDY_SET}.jsonl", lines=True).head(1)
report_order = sorted(set(studies["dataset_id"]))

if PRE_PROCESSED_FMS:
Expand Down
3 changes: 2 additions & 1 deletion asreview2-optuna/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ synergy_dataset
numpy
pandas
psycopg2-binary
sentence-transformers
sentence-transformers
xgboost

0 comments on commit a65d435

Please sign in to comment.