Add API interfaces for train, predict and evaluate processes #160

Merged 6 commits on Dec 4, 2023
2 changes: 1 addition & 1 deletion .gitignore
@@ -184,4 +184,4 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.idea/
8 changes: 8 additions & 0 deletions dbgpt_hub/eval/__init__.py
@@ -0,0 +1,8 @@
"""
dbgpt_hub.eval
==============
"""

from .evaluation_api import start_evaluate

__all__ = ["start_evaluate"]
37 changes: 37 additions & 0 deletions dbgpt_hub/eval/evaluation.py
@@ -12,6 +12,7 @@
import subprocess
import json

from typing import Optional, Dict, Any
from process_sql import get_schema, Schema, get_sql
from exec_eval import eval_exec_match
from func_timeout import func_timeout, FunctionTimedOut
@@ -1152,6 +1153,42 @@ def build_foreign_key_map_from_json(table):
return tables


def evaluate_api(args: Dict[str, Any]):
# Prepare the prediction file path; when "natsql" is set, insert "2sql" before the file extension
if args["natsql"]:
pred_file_path = (
args["input"].rsplit(".", 1)[0] + "2sql." + args["input"].rsplit(".", 1)[1]
)
gold_file_path = args["gold_natsql"]
table_info_path = args["table_natsql"]
else:
pred_file_path = args["input"]
gold_file_path = args["gold"]
table_info_path = args["table"]

# only evaluating exact match needs this argument
kmaps = None
if args["etype"] in ["all", "match"]:
assert (
args["table"] is not None
), "table argument must be non-None if exact set match is evaluated"
kmaps = build_foreign_key_map_from_json(args["table"])

# Print args
print(f"params as fllows \n {args}")

evaluate(
gold_file_path,
pred_file_path,
args["db"],
args["etype"],
kmaps,
args["plug_value"],
args["keep_distinct"],
args["progress_bar_for_each_datapoint"],
)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
…
32 changes: 32 additions & 0 deletions dbgpt_hub/eval/evaluation_api.py
@@ -0,0 +1,32 @@
from typing import Optional, Dict, Any

from dbgpt_hub.eval import evaluation


def start_evaluate(
args: Optional[Dict[str, Any]] = None,
):
# Arguments for evaluation
if args is None:
args = {
"input": "./dbgpt_hub/output/pred/pred_sql_dev_skeleton.sql",
"gold": "./dbgpt_hub/data/eval_data/gold.txt",
"gold_natsql": "./dbgpt_hub/data/eval_data/gold_natsql2sql.txt",
"db": "./dbgpt_hub/data/spider/database",
"table": "./dbgpt_hub/data/eval_data/tables.json",
"table_natsql": "./dbgpt_hub/data/eval_data/tables_for_natsql2sql.json",
"etype": "exec",
"plug_value": True,
"keep_distict": False,
"progress_bar_for_each_datapoint": False,
"natsql": False,
}

# Execute evaluation
evaluation.evaluate_api(args)


if __name__ == "__main__":
start_evaluate()
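
A minimal usage sketch of the new evaluation entry point: it calls `start_evaluate` with an explicit args dict instead of relying on the built-in defaults. Every key mirrors the default dict above; only the `input` path differs, and it is illustrative.

```python
from dbgpt_hub.eval import start_evaluate

# All keys mirror the defaults in start_evaluate; only "input" is
# changed here, and the path is illustrative.
start_evaluate(
    {
        "input": "./dbgpt_hub/output/pred/pred_sql.sql",
        "gold": "./dbgpt_hub/data/eval_data/gold.txt",
        "gold_natsql": "./dbgpt_hub/data/eval_data/gold_natsql2sql.txt",
        "db": "./dbgpt_hub/data/spider/database",
        "table": "./dbgpt_hub/data/eval_data/tables.json",
        "table_natsql": "./dbgpt_hub/data/eval_data/tables_for_natsql2sql.json",
        "etype": "exec",
        "plug_value": True,
        "keep_distinct": False,
        "progress_bar_for_each_datapoint": False,
        "natsql": False,
    }
)
```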
8 changes: 8 additions & 0 deletions dbgpt_hub/predict/__init__.py
@@ -0,0 +1,8 @@
"""
dbgpt_hub.predict
=================
"""

from .predict_api import start_predict

__all__ = ["start_predict"]
55 changes: 33 additions & 22 deletions dbgpt_hub/predict/predict.py
@@ -1,23 +1,21 @@
import os
import json
import sys

ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)

from tqdm import tqdm
from typing import List, Dict, Optional, Any

from dbgpt_hub.data_process.data_utils import extract_sql_prompt_dataset
from dbgpt_hub.llm_base.chat_model import ChatModel


def prepare_dataset(
predict_file_path: Optional[str] = None,
) -> List[Dict]:
with open(predict_file_path, "r") as fp:
data = json.load(fp)
predict_data = [extract_sql_prompt_dataset(item) for item in data]
return predict_data
@@ -33,21 +31,34 @@ def inference(model: ChatModel, predict_data: List[Dict], **input_kwargs):
return res


def predict(args: Optional[Dict[str, Any]] = None):
predict_file_path = ""
if args is None:
predict_file_path = os.path.join(
ROOT_PATH, "dbgpt_hub/data/eval_data/dev_sql.json"
)
predict_out_dir = os.path.join(
os.path.join(ROOT_PATH, "dbgpt_hub/output/"), "pred"
)
if not os.path.exists(predict_out_dir):
os.mkdir(predict_out_dir)
predict_output_filename = os.path.join(predict_out_dir, "pred_sql.sql")
print(f"predict_output_filename \t{predict_output_filename}")
else:
predict_file_path = os.path.join(ROOT_PATH, args["predict_file_path"])
predict_out_dir = os.path.join(
os.path.join(ROOT_PATH, args["predict_out_dir"]), "pred"
)
if not os.path.exists(predict_out_dir):
os.mkdir(predict_out_dir)
predict_output_filename = os.path.join(predict_out_dir, args["predicted_out_filename"])
print(f"predict_output_filename \t{predict_output_filename}")

predict_data = prepare_dataset(predict_file_path=predict_file_path)
model = ChatModel()
result = inference(model, predict_data)

with open(predict_output_filename, "w") as f:
for p in result:
try:
f.write(p.replace("\n", " ") + "\n")
@@ -56,4 +67,4 @@ def main():


if __name__ == "__main__":
predict()
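
For reference, a sketch of calling the reworked `predict` directly with explicit paths. The three keys are exactly the ones `predict()` reads above; the values are placeholders, and model/adapter settings are still resolved separately by `ChatModel()`.

```python
from dbgpt_hub.predict import predict

# Keys match the ones predict() reads; values are placeholders.
# Model and adapter settings are resolved by ChatModel() itself.
predict.predict(
    {
        "predict_file_path": "dbgpt_hub/data/eval_data/dev_sql.json",
        "predict_out_dir": "dbgpt_hub/output/",
        "predicted_out_filename": "pred_sql.sql",
    }
)
```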
31 changes: 31 additions & 0 deletions dbgpt_hub/predict/predict_api.py
@@ -0,0 +1,31 @@
import os
from dbgpt_hub.predict import predict
from typing import Optional, Dict, Any


def start_predict(
args: Optional[Dict[str, Any]] = None, cuda_visible_devices: Optional[str] = "0"
):
# Setting CUDA Device
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices

# Default Arguments
if args is None:
args = {
"model_name_or_path": "codellama/CodeLlama-13b-Instruct-hf",
"template": "llama2",
"finetuning_type": "lora",
"checkpoint_dir": "dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora",
"predict_file_path": "dbgpt_hub/data/eval_data/dev_sql.json",
"predict_out_dir": "dbgpt_hub/output/",
"predicted_out_filename": "pred_sql.sql",
}

# Execute prediction
predict.predict(args)


if __name__ == "__main__":
start_predict()
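
A usage sketch for the new predict API: the args dict has the same shape as the defaults above, with the GPU pinned via `cuda_visible_devices`. The paths and device index are illustrative.

```python
from dbgpt_hub.predict import start_predict

# Same shape as the default args above; paths and the device index
# are illustrative.
start_predict(
    args={
        "model_name_or_path": "codellama/CodeLlama-13b-Instruct-hf",
        "template": "llama2",
        "finetuning_type": "lora",
        "checkpoint_dir": "dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora",
        "predict_file_path": "dbgpt_hub/data/eval_data/dev_sql.json",
        "predict_out_dir": "dbgpt_hub/output/",
        "predicted_out_filename": "pred_sql.sql",
    },
    cuda_visible_devices="1",
)
```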
8 changes: 8 additions & 0 deletions dbgpt_hub/train/__init__.py
@@ -0,0 +1,8 @@
"""
dbgpt_hub.train
===============
"""

from .sft_train_api import start_sft

__all__ = ["start_sft"]
47 changes: 47 additions & 0 deletions dbgpt_hub/train/sft_train_api.py
@@ -0,0 +1,47 @@
import os

from typing import Optional, Dict, Any
from dbgpt_hub.train import sft_train


def start_sft(
args: Optional[Dict[str, Any]] = None, cuda_visible_devices: Optional[str] = "0"
):
# Setting CUDA Device
os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices

# Default Arguments
if args is None:
args = {
"model_name_or_path": "codellama/CodeLlama-13b-Instruct-hf",
"do_train": True,
"dataset": "example_text2sql_train",
"max_source_length": 2048,
"max_target_length": 512,
"finetuning_type": "lora",
"lora_target": "q_proj,v_proj",
"template": "llama2",
"lora_rank": 64,
"lora_alpha": 32,
"output_dir": "dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora",
"overwrite_cache": True,
"overwrite_output_dir": True,
"per_device_train_batch_size": 1,
"gradient_accumulation_steps": 16,
"lr_scheduler_type": "cosine_with_restarts",
"logging_steps": 50,
"save_steps": 2000,
"learning_rate": 2e-4,
"num_train_epochs": 8,
"plot_loss": True,
"bf16": True,
}

# Run SFT
sft_train.train(args)


if __name__ == "__main__":
start_sft()
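
A usage sketch for the training API. Note that a caller-supplied `args` replaces the defaults wholesale (there is no merge), so every required key must be included; the hyperparameter values below only adjust `num_train_epochs` relative to the defaults and are illustrative.

```python
from dbgpt_hub.train import start_sft

# Passing args replaces the defaults wholesale, so every required key
# must be supplied; num_train_epochs is changed here as an example.
start_sft(
    args={
        "model_name_or_path": "codellama/CodeLlama-13b-Instruct-hf",
        "do_train": True,
        "dataset": "example_text2sql_train",
        "max_source_length": 2048,
        "max_target_length": 512,
        "finetuning_type": "lora",
        "lora_target": "q_proj,v_proj",
        "template": "llama2",
        "lora_rank": 64,
        "lora_alpha": 32,
        "output_dir": "dbgpt_hub/output/adapter/CodeLlama-13b-sql-lora",
        "overwrite_cache": True,
        "overwrite_output_dir": True,
        "per_device_train_batch_size": 1,
        "gradient_accumulation_steps": 16,
        "lr_scheduler_type": "cosine_with_restarts",
        "logging_steps": 50,
        "save_steps": 2000,
        "learning_rate": 2e-4,
        "num_train_epochs": 4,
        "plot_loss": True,
        "bf16": True,
    },
    cuda_visible_devices="0",
)
```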