-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into ASReview2-rf
- Loading branch information
Showing
9 changed files
with
378 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
221 changes: 221 additions & 0 deletions
221
asreview2-optuna/completed_runs/optuna_output_analysis.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,221 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import optuna\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"from collections import defaultdict\n", | ||
"import pandas as pd\n", | ||
"import synergy_dataset as sd\n", | ||
"from IPython.display import display\n", | ||
"\n", | ||
"# Path to your SQLite3 database\n", | ||
"db_path = \"sqlite:///svm_db.sqlite3\" # Replace with your database path\n", | ||
"\n", | ||
"# Get all study summaries\n", | ||
"study_summaries = optuna.get_all_study_summaries(storage=db_path)\n", | ||
"\n", | ||
"for summary in study_summaries:\n", | ||
" print(f\"- {summary.study_name}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"study_name = \"ASReview2 2024-12-20 at 14.49.22\"\n", | ||
"study = optuna.load_study(study_name=study_name, storage=db_path)\n", | ||
"print(study.trials[0].params)\n", | ||
"\n", | ||
"dataset_names = []\n", | ||
"for i in sd.iter_datasets():\n", | ||
" if i.name != \"Chou_2004\":\n", | ||
" dataset_names.append(i.name)\n", | ||
"\n", | ||
"dataset_names.sort()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Prepare data for visualization\n", | ||
"data = []\n", | ||
"\n", | ||
"for trial in study.trials:\n", | ||
" if trial.intermediate_values:\n", | ||
" for dataset_id, value in enumerate(trial.intermediate_values.values()):\n", | ||
" params = trial.params # Extract trial parameters\n", | ||
" # Record dataset_id, loss (intermediate value), and parameters\n", | ||
" data.append({\n", | ||
" \"dataset_id\": dataset_id,\n", | ||
" \"loss\": value,\n", | ||
" \"ratio\": params.get(\"ratio\", None),\n", | ||
" \"c\": params.get(\"log__C\", None)\n", | ||
" })\n", | ||
"\n", | ||
"# Convert to pandas DataFrame\n", | ||
"df = pd.DataFrame(data)\n", | ||
"\n", | ||
"# Initialize variables to store the best trial per dataset\n", | ||
"num_datasets = len(study.trials[0].intermediate_values) # Assuming all trials have the same number of datasets\n", | ||
"best_trials_per_dataset = [None] * num_datasets # Store best trial numbers\n", | ||
"best_losses_per_dataset = [float(\"inf\")] * num_datasets # Store best loss values\n", | ||
"best_params_per_dataset = [None] * num_datasets # Store best trial parameters\n", | ||
"\n", | ||
"# Loop through all trials to find the best trial for each dataset\n", | ||
"for trial in study.trials:\n", | ||
" if trial.intermediate_values:\n", | ||
" # Iterate through each dataset (position in the intermediate_values list)\n", | ||
" for dataset_id, loss in enumerate(trial.intermediate_values.values()):\n", | ||
" if loss < best_losses_per_dataset[dataset_id]:\n", | ||
" # Update the best trial info for this dataset\n", | ||
" best_losses_per_dataset[dataset_id] = loss\n", | ||
" best_trials_per_dataset[dataset_id] = trial.number\n", | ||
" best_params_per_dataset[dataset_id] = trial.params" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Convert the dictionary to a pandas DataFrame\n", | ||
"df = pd.DataFrame(list(study.best_trial.intermediate_values.items()), columns=[\"Dataset\", \"Mean Loss\"])\n", | ||
"# Rename the rows to indicate the dataset number\n", | ||
"df.index = [dataset_names[i] for i in range(len(best_params_per_dataset))]\n", | ||
"df.drop(\"Dataset\", inplace=True, axis=1)\n", | ||
"\n", | ||
"display(df)\n", | ||
"\n", | ||
"# Plot the values (optional)\n", | ||
"df.plot(kind=\"bar\", figsize=(10, 6), legend=False)\n", | ||
"plt.title(\"Mean Losses per Dataset\")\n", | ||
"plt.xlabel(\"Dataset\")\n", | ||
"plt.ylabel(\"Mean Loss\")\n", | ||
"plt.tight_layout()\n", | ||
"plt.show()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"\n", | ||
"# Create a pandas DataFrame\n", | ||
"df = pd.DataFrame(best_params_per_dataset)\n", | ||
"\n", | ||
"# Rename the rows to indicate the dataset number\n", | ||
"df.index = [dataset_names[i] for i in range(len(best_params_per_dataset))]\n", | ||
"\n", | ||
"display(df)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Create a pandas DataFrame\n", | ||
"df = pd.DataFrame(best_params_per_dataset)\n", | ||
"\n", | ||
"# Plot each parameter separately\n", | ||
"num_params = len(df.columns)\n", | ||
"fig, axes = plt.subplots(num_params, 1, figsize=(8, num_params * 2.5), sharex=False)\n", | ||
"\n", | ||
"for idx, param in enumerate(df.columns):\n", | ||
" ax = axes[idx]\n", | ||
" ax.plot(dataset_names, df[param], marker='o', linestyle='-', color='b', alpha=0.8, label=param)\n", | ||
" ax.set_title(param, fontsize=10)\n", | ||
" ax.set_ylabel(\"Value\", fontsize=8)\n", | ||
" ax.grid(axis=\"y\", linestyle=\"--\", alpha=0.6)\n", | ||
" ax.tick_params(axis=\"y\", labelsize=8)\n", | ||
" ax.legend(fontsize=8, loc=\"upper left\")\n", | ||
" \n", | ||
" # Set dataset names as x-tick labels for each plot\n", | ||
" ax.set_xticks(dataset_names) # Setting positions explicitly\n", | ||
" ax.set_xticklabels(dataset_names, fontsize=8, rotation=90) # Setting labels\n", | ||
"\n", | ||
"# Add x-axis label only to the bottom subplot\n", | ||
"axes[-1].set_xlabel(\"Datasets\", fontsize=10)\n", | ||
"\n", | ||
"# Adjust layout for better spacing\n", | ||
"plt.tight_layout()\n", | ||
"\n", | ||
"# Save or show the plot\n", | ||
"plt.savefig(\"parameter_comparison_lineplots_all_xticks_fixed.pdf\", bbox_inches=\"tight\", dpi=300)\n", | ||
"plt.show()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Extract intermediate values grouped by dataset_id\n", | ||
"dataset_intermediate_values = defaultdict(list)\n", | ||
"\n", | ||
"for trial in study.trials:\n", | ||
" if trial.intermediate_values:\n", | ||
" # Distribute intermediate values by dataset_id (index in the list)\n", | ||
" for dataset_id, value in enumerate(trial.intermediate_values.values()):\n", | ||
" dataset_intermediate_values[dataset_id].append(value)\n", | ||
"\n", | ||
"# Prepare data for boxplots\n", | ||
"datasets = list(dataset_intermediate_values.keys())\n", | ||
"boxplot_data = [dataset_intermediate_values[dataset_id] for dataset_id in datasets]\n", | ||
"\n", | ||
"# Plot boxplots\n", | ||
"plt.figure(figsize=(12, 6))\n", | ||
"plt.boxplot(boxplot_data, labels=dataset_names, \n", | ||
" showmeans=True, patch_artist=True)\n", | ||
"plt.xlabel(\"Dataset\")\n", | ||
"plt.ylabel(\"Loss\")\n", | ||
"plt.title(f\"Boxplot of Losses for Each Dataset {study_name}\")\n", | ||
"plt.grid(axis=\"y\", linestyle=\"--\", alpha=0.7)\n", | ||
"plt.xticks(rotation=90) # Rotate dataset names for better readability\n", | ||
"plt.tight_layout()\n", | ||
"plt.ylim((0, 0.3))\n", | ||
"\n", | ||
"# Show the plot\n", | ||
"plt.tight_layout()\n", | ||
"plt.savefig(f\"boxplot_per_dataset_{study_name}.pdf\")\n", | ||
"plt.show()\n" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "asreview-2.0", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.8" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import pickle | ||
from pathlib import Path | ||
|
||
import pandas as pd | ||
import synergy_dataset as sd # Assuming this is your custom dataset handler | ||
import torch | ||
from sentence_transformers import SentenceTransformer | ||
from tqdm import tqdm | ||
|
||
# When True, embeddings are regenerated even if a pickle file already exists.
FORCE = False

# Folder to save embeddings
folder_pickle_files = Path("synergy-dataset", "pickles_labse")
folder_pickle_files.mkdir(parents=True, exist_ok=True)

# Load LaBSE model
model = SentenceTransformer("sentence-transformers/LaBSE")

# Check if CUDA is available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Materialize the dataset list so tqdm can infer the total from its length
# instead of relying on a hard-coded count.
# NOTE(review): assumes sd.iter_datasets() yields lightweight handles whose
# data is only loaded by to_frame() — confirm against synergy_dataset.
datasets = list(sd.iter_datasets())

# Loop through datasets
for dataset in tqdm(datasets):
    # Moran_2021 is replaced by a corrected CSV, so it gets its own file name.
    dataset_name = (
        dataset.name if dataset.name != "Moran_2021" else "Moran_2021_corrected"
    )
    pickle_file_path = folder_pickle_files / f"{dataset_name}.pkl"

    # Check the cache first: the output path depends only on the dataset name,
    # so already-embedded datasets are skipped before any data is loaded.
    if not FORCE and pickle_file_path.exists():
        print(f"Skipping {dataset_name}, pickle file already exists.")
        continue

    if dataset.name == "Moran_2021":
        df = pd.read_csv("../datasets/Moran_2021_corrected_shuffled_raw.csv")
    else:
        # Convert dataset to a DataFrame and reset index
        df = dataset.to_frame().reset_index()

    # Combine 'title' and 'abstract' text; missing values become empty strings
    combined_texts = (
        df["title"].fillna("") + " " + df["abstract"].fillna("")
    ).tolist()

    # Generate embeddings
    X = model.encode(
        combined_texts, batch_size=64, show_progress_bar=False, device=device
    )

    # Save embeddings and labels as a pickle file: a (X, labels) tuple
    with open(pickle_file_path, "wb") as f:
        pickle.dump(
            (
                X,
                df["label_included"].tolist(),
            ),
            f,
        )
Oops, something went wrong.