-
Notifications
You must be signed in to change notification settings - Fork 198
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add API interface for model baseline result
- Loading branch information
Showing
4 changed files
with
237 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
""" | ||
dbgpt_hub.baseline | ||
============== | ||
""" | ||
|
||
from .show_result_api import show_all | ||
from .show_result_api import show_model | ||
|
||
__all__ = ["show_all", "show_model"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
{ | ||
"spider": { | ||
"llama2-7b-hf": { | ||
"base": { | ||
"alpaca":{ | ||
"instruction": "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\\n\\\"\\n##Instruction:\\ndepartment_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\\nThe head_ID of management is the foreign key of head_ID of head.\\nThe department_ID of management is the foreign key of Department_ID of department.\\n\\n", | ||
"acc": { | ||
"ex":{ | ||
"easy": 0.1, | ||
"medium": 0.1, | ||
"hard": 0.1, | ||
"extra": 0.1, | ||
"all": 0.1 | ||
}, | ||
"em":{ | ||
"easy": 0.1, | ||
"medium": 0.1, | ||
"hard": 0.1, | ||
"extra": 0.1, | ||
"all": 0.1 | ||
} | ||
} | ||
}, | ||
"openai":{ | ||
"instruction": "openai-instruction", | ||
"acc": { | ||
"ex":{ | ||
"easy": 0.887, | ||
"medium": 0.711, | ||
"hard": 0.575, | ||
"extra": 0.380, | ||
"all": 0.677 | ||
}, | ||
"em":{ | ||
"easy": 0.887, | ||
"medium": 0.711, | ||
"hard": 0.575, | ||
"extra": 0.380, | ||
"all": 0.677 | ||
} | ||
} | ||
} | ||
}, | ||
"lora": { | ||
"alpaca":{ | ||
"instruction": "test", | ||
"acc": { | ||
"ex":{ | ||
"easy": 0.887, | ||
"medium": 0.711, | ||
"hard": 0.575, | ||
"extra": 0.380, | ||
"all": 0.677 | ||
}, | ||
"em":{ | ||
"easy": 0.887, | ||
"medium": 0.711, | ||
"hard": 0.575, | ||
"extra": 0.380, | ||
"all": 0.677 | ||
} | ||
} | ||
} | ||
}, | ||
"qlora": { | ||
|
||
} | ||
, | ||
"ppo":{} | ||
}, | ||
"llama2-7b-chat-hf": { | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import os | ||
import sys | ||
import json | ||
from typing import Optional, Dict, Any | ||
from prettytable import PrettyTable | ||
|
||
|
||
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | ||
sys.path.append(ROOT_PATH) | ||
|
||
baseline_file = "./dbgpt_hub/baseline/baseline.json" | ||
# read json | ||
with open(baseline_file, 'r') as file: | ||
baseline_json = json.load(file) | ||
|
||
def print_models_info(dataset, model, method, prompt): | ||
print_table_models = PrettyTable() | ||
models_header = ['dataset', 'model', 'method', 'prompt'] | ||
models_info = [dataset, model, method, prompt] | ||
print_table_models.field_names = models_header | ||
print_table_models.add_rows([models_info]) | ||
return print_table_models | ||
|
||
|
||
def print_scores_info(acc_data): | ||
print_table_scores = PrettyTable() | ||
scores_header = ['etype', 'easy', 'medium', 'hard', 'extra', 'all'] | ||
print_table_scores.field_names = scores_header | ||
eytpe = "ex" | ||
ex_score = [acc_data[eytpe][key] for key in acc_data[eytpe].keys()] | ||
ex_score.insert(0, eytpe) | ||
eytpe = "em" | ||
em_score = [acc_data[eytpe][key] for key in acc_data[eytpe].keys()] | ||
em_score.insert(0, eytpe) | ||
print_table_scores.add_rows( | ||
[ | ||
ex_score, | ||
em_score | ||
] | ||
) | ||
return print_table_scores | ||
|
||
def show_model( | ||
dataset, | ||
model, | ||
method, | ||
prompt | ||
): | ||
|
||
# 1.get res | ||
acc_data = baseline_json[dataset][model][method][prompt]['acc'] | ||
|
||
# 2.print models info | ||
print_table_models = print_models_info(dataset, model, method, prompt) | ||
print(print_table_models) | ||
|
||
# 3.print scores info | ||
print_table_scores = print_scores_info(acc_data) | ||
print(print_table_scores) | ||
|
||
def show_model_api(args: Optional[Dict[str, Any]] = None): | ||
dataset = args["dataset"] | ||
model = args["model"] | ||
method = args["method"] | ||
prompt = args["prompt"] | ||
|
||
show_model( | ||
dataset, | ||
model, | ||
method, | ||
prompt | ||
) | ||
|
||
def show_all(): | ||
datasets = baseline_json.keys() | ||
for dataset in datasets: | ||
models = baseline_json[dataset].keys() | ||
for model in models: | ||
methods = baseline_json[dataset][model].keys() | ||
for method in methods: | ||
prompts = baseline_json[dataset][model][method].keys() | ||
for prompt in prompts: | ||
# 1.get scores info | ||
acc_data = baseline_json[dataset][model][method][prompt]['acc'] | ||
|
||
# 2.print models info | ||
print_table_models = print_models_info(dataset, model, method, prompt) | ||
print(print_table_models) | ||
|
||
# 3.print scores info | ||
print_table_scores = print_scores_info(acc_data) | ||
print(print_table_scores) | ||
|
||
|
||
|
||
|
||
def show_all_api(): | ||
show_all() | ||
|
||
|
||
|
||
# def update(): | ||
# # todo : 更新baseline.json | ||
# # | ||
|
||
|
||
if __name__ == "__main__": | ||
# args | ||
show_args = { | ||
"dataset" : "spider", | ||
"model" : "llama2-7b-hf", | ||
"method" : "lora", | ||
"prompt" : "alpaca", | ||
} | ||
show_model(show_args) | ||
|
||
show_all() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from typing import Optional, Dict, Any | ||
|
||
from dbgpt_hub.baseline import show_result | ||
|
||
|
||
def show_all(): | ||
show_result.show_all_api() | ||
|
||
def show_model( | ||
args: Optional[Dict[str, Any]] = None | ||
): | ||
# Arguments for show result | ||
if args is None: | ||
args = { | ||
"dataset":"spider", | ||
"model":"llama2-7b-hf", | ||
"sft":"lora", | ||
"prompt":"alpaca", | ||
} | ||
else: | ||
args = args | ||
|
||
show_result.show_model_api(args) | ||
|
||
|
||
|
||
if __name__ == "__main__": | ||
show_all() | ||
|
||
show_args = { | ||
"dataset" : "spider", | ||
"model" : "llama2-7b-hf", | ||
"method" : "lora", | ||
"prompt" : "alpaca" | ||
} | ||
show_model(show_args) |