
Updated requirements.txt and .gitignore #99


Open · wants to merge 9 commits into master
14 changes: 11 additions & 3 deletions .gitignore
@@ -1,5 +1,13 @@
# Data Files
data/
*.jsonl
annotated/

# Environments
venv/
my_env/

# Other
.py[cod]
__pycache__/
.DS_Store
75 changes: 63 additions & 12 deletions evaluate.py
@@ -5,38 +5,89 @@
from lib.dbengine import DBEngine
from lib.query import Query
from lib.common import count_lines
import copy
import pandas as pd


def evaluate_wikisql():
    parser = ArgumentParser()
    parser.add_argument('source_file', help='source file for the prediction')
    parser.add_argument('db_file', help='source database for the prediction')
    parser.add_argument('pred_file', help='predictions by the model')
    parser.add_argument('csv_file_location', help='where to write the CSV of incorrect predictions')
    parser.add_argument('--ordered', action='store_true', help='whether the exact match should consider the order of conditions')
    args = parser.parse_args()

    engine = DBEngine(args.db_file)
    exact_match = []
    exact_match_ddb = []
    incorrect_answer = []
    incorrect_pred = []
    correct_answer = []
    with open(args.source_file) as fs, open(args.pred_file) as fp:
        grades = []
        grades_ddb = []
        count = 0
        for ls, lp in tqdm(zip(fs, fp), total=count_lines(args.source_file)):
            eg = json.loads(ls)
            ep = json.loads(lp)
            ddb = copy.deepcopy(ep)  # Copy the prediction so the ddb variant stays independent
            qg = Query.from_dict(eg['sql'], ordered=args.ordered)
            gold = engine.execute_query(eg['table_id'], qg, lower=True)
            pred = ep.get('error', None)
            pred_ddb = pred
            qp = None
            qp_ddb = None
            if not ep.get('error', None):
                try:
                    qp_ddb = Query.from_dict(ddb['ddb_query'], ordered=args.ordered)
                    pred_ddb = engine.execute_query(eg['table_id'], qp_ddb, lower=True)
                except Exception as e:
                    pred_ddb = repr(e)
                try:
                    qp = Query.from_dict(ep['query'], ordered=args.ordered)
                    pred = engine.execute_query(eg['table_id'], qp, lower=True)
                except Exception as e:
                    pred = repr(e)

            # This is the correct output
            correct = pred == gold  # This compares the query output
            correct_ddb = pred_ddb == gold
            # This is the correct query
            match = qp == qg  # qg is the gold query and qp is the predicted query
            match_ddb = qp_ddb == qg
            grades.append(correct)  # Query output
            grades_ddb.append(correct_ddb)
            exact_match.append(match)  # SQL query itself
            exact_match_ddb.append(match_ddb)
            if not match:
                incorrect_answer.append(f'dev_{count}')
                incorrect_pred.append(qp)
                correct_answer.append(qg)
            if correct != correct_ddb:
                print('Question num: ', str(count))

            count += 1
    result_list_dict = {
        'Incorrect answer question #': incorrect_answer,
        'Incorrect answer value prediction': incorrect_pred,
        'Correct answer value': correct_answer
    }
    result_list_df = pd.DataFrame(result_list_dict)
    result_list_df.to_csv(args.csv_file_location)
    output = json.dumps({
        'incorrect_ex_questions': [i for i, x in enumerate(grades) if x == 0],
        'incorrect_lf_questions': [i for i, x in enumerate(exact_match) if x == 0],
        'ex_accuracy': sum(grades) / len(grades),  # Compare query output
        'lf_accuracy': sum(exact_match) / len(exact_match),  # Compare SQL query itself
        'ddb_ex_accuracy': sum(grades_ddb) / len(grades_ddb),
        'ddb_lf_accuracy': sum(exact_match_ddb) / len(exact_match_ddb),
    }, indent=2)
    print(output)

    return output


if __name__ == '__main__':
    evaluate_wikisql()
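The CSV written by `evaluate_wikisql()` makes the logical-form errors easy to audit after a run. A minimal sketch of loading it back with pandas (the `errors.csv` file name is hypothetical; the columns are the ones the script writes):

```python
import pandas as pd

# Load the per-question error CSV produced by evaluate_wikisql()
# and eyeball the first few logical-form mismatches.
errors = pd.read_csv('errors.csv', index_col=0)
print(errors[['Incorrect answer question #', 'Incorrect answer value prediction']].head())
```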
4 changes: 4 additions & 0 deletions lib/__init__.py
@@ -0,0 +1,4 @@
from .common import *
from .dbengine import *
from .table import *
from .query import *
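These wildcard re-exports let downstream scripts import from the package root instead of the individual modules. A small sketch of what this enables (assuming `DBEngine` and `Query` keep the names used in evaluate.py; the database path is illustrative):

```python
# Package-level imports now resolve thanks to lib/__init__.py's re-exports.
from lib import DBEngine, Query

engine = DBEngine('data/dev.db')  # illustrative path, not part of the PR
```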
5 changes: 4 additions & 1 deletion lib/dbengine.py
@@ -26,7 +26,10 @@ def execute(self, table_id, select_index, aggregation_index, conditions, lower=T
        for tup in schema_str.split(', '):
            c, t = tup.split()
            schema[c] = t
        if select_index == '*':
            select = select_index
        else:
            select = 'col{}'.format(select_index)
        agg = Query.agg_ops[aggregation_index]
        if agg:
            select = '{}({})'.format(agg, select)
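The new branch leaves a literal `*` untouched and only formats numeric indices as `col{i}`. A standalone sketch of the resulting select-clause construction (assuming WikiSQL's usual `Query.agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']`):

```python
agg_ops = ['', 'MAX', 'MIN', 'COUNT', 'SUM', 'AVG']  # assumed Query.agg_ops

def build_select(select_index, aggregation_index):
    # '*' passes through; integer indices become col{i}
    select = select_index if select_index == '*' else 'col{}'.format(select_index)
    agg = agg_ops[aggregation_index]
    return '{}({})'.format(agg, select) if agg else select

print(build_select('*', 3))  # COUNT(*)
print(build_select(0, 0))    # col0
```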
1 change: 1 addition & 0 deletions requirements.txt
@@ -2,3 +2,4 @@ tqdm
records
babel
tabulate
SQLAlchemy==1.3
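The pin to SQLAlchemy 1.3 is presumably about keeping `records` (and therefore `DBEngine`) on the pre-1.4 API. A quick sanity check, not part of the PR, to confirm the pin took effect in the active environment:

```python
# Verify the installed SQLAlchemy matches the pin in requirements.txt.
import sqlalchemy
assert sqlalchemy.__version__.startswith('1.3'), sqlalchemy.__version__
```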
65 changes: 65 additions & 0 deletions temp.py
@@ -0,0 +1,65 @@
#!/usr/bin/env python
import copy
import json
from argparse import ArgumentParser
from tqdm import tqdm
from lib.dbengine import DBEngine
from lib.query import Query
from lib.common import count_lines

if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('source_file', help='source file for the prediction')
    parser.add_argument('db_file', help='source database for the prediction')
    parser.add_argument('pred_file', help='predictions by the model')
    parser.add_argument('--ordered', action='store_true', help='whether the exact match should consider the order of conditions')
    args = parser.parse_args()

    engine = DBEngine(args.db_file)
    exact_match = []
    exact_match_ddb = []
    with open(args.source_file) as fs, open(args.pred_file) as fp:
        grades = []
        grades_ddb = []
        for ls, lp in tqdm(zip(fs, fp), total=count_lines(args.source_file)):
            eg = json.loads(ls)
            ep = json.loads(lp)
            ddb = copy.deepcopy(ep)  # Copy the prediction so the remap below does not mutate ep
            qg = Query.from_dict(eg['sql'], ordered=args.ordered)
            gold = engine.execute_query(eg['table_id'], qg, lower=True)
            pred = ep.get('error', None)
            pred_ddb = pred
            qp = None
            qp_ddb = None
            if not ep.get('error', None):
                try:
                    # If SELECT * is used with an agg function, then set to the correctly selected column
                    if ep['query']['sel'] == '*' and eg['sql']['agg'] > 0:
                        ddb['query']['sel'] = eg['sql']['sel']

                    qp = Query.from_dict(ep['query'], ordered=args.ordered)
                    pred = engine.execute_query(eg['table_id'], qp, lower=True)

                    qp_ddb = Query.from_dict(ddb['query'], ordered=args.ordered)
                    pred_ddb = engine.execute_query(eg['table_id'], qp_ddb, lower=True)
                except Exception as e:
                    pred = repr(e)
                    pred_ddb = repr(e)
            # This is the correct output
            correct = pred == gold
            correct_ddb = pred_ddb == gold
            # This is the correct query
            match = qp == qg
            match_ddb = qp_ddb == qg
            grades.append(correct)
            grades_ddb.append(correct_ddb)
            exact_match.append(match)
            exact_match_ddb.append(match_ddb)
    print('Here are the incorrect questions:', [i for i, x in enumerate(grades, start=1) if x == 0])
    print(json.dumps({
        'ex_accuracy': sum(grades) / len(grades),
        'lf_accuracy': sum(exact_match) / len(exact_match),
        'ddb_ex_accuracy': sum(grades_ddb) / len(grades_ddb),
        'ddb_lf_accuracy': sum(exact_match_ddb) / len(exact_match_ddb),
    }, indent=2))
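The SELECT-* remap in temp.py is easiest to see on a concrete prediction. A tiny worked example (the dicts below are made up for illustration):

```python
import copy

eg = {'sql': {'sel': 2, 'agg': 3}}      # gold: aggregation over column 2
ep = {'query': {'sel': '*', 'agg': 3}}  # prediction: aggregation over *
ddb = copy.deepcopy(ep)
# An agg over '*' is rewritten to the gold column so it can execute on col{i}.
if ep['query']['sel'] == '*' and eg['sql']['agg'] > 0:
    ddb['query']['sel'] = eg['sql']['sel']
print(ep['query'])   # {'sel': '*', 'agg': 3} -- unchanged
print(ddb['query'])  # {'sel': 2, 'agg': 3}  -- remapped
```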