connected the pipeline (#65) (#93)
* connected the pipeline

* minor changes to api code
rosequ authored Dec 8, 2017
1 parent 68e0ef4 commit 85f35bb
Showing 2 changed files with 65 additions and 38 deletions.
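Taken together, the two files below wire the retrieval step and the SM-CNN re-ranker into one flow: api.py builds a RetrieveSentences argument list from its Flask config, pulls candidate passages and the term-IDF JSON from the retriever, and hands both to SMModelBridge for re-ranking. A minimal sketch of that same flow outside Flask, assuming the constructors accept an argparse-style namespace with the flags introduced in this commit (all literal values below are placeholders):

import argparse

from anserini_dependency.RetrieveSentences import RetrieveSentences
from sm_cnn.bridge import SMModelBridge

# Placeholder arguments mirroring the flags parsed in get_answers() below.
args = argparse.Namespace(
    index="lucene-index",        # path to the Lucene index (placeholder)
    embeddings="", topics="",
    query="who won the world series in 2016", hits=20, scorer="Idf", k=5,
    model="sm_model.pt",         # path to the saved SM-CNN model (placeholder)
    dataset="../../data/TrecQA/", cuda=False, gpu=-1, seed=3435,
)

rs = RetrieveSentences(args)                 # opens the index once
idf_json = rs.getTermIdfJSON()               # term -> IDF weights, serialized as JSON
passages = rs.getRankedPassages(args.query, args.index, args.hits, args.k)

bridge = SMModelBridge(args)                 # loads the saved torch model
scored = bridge.rerank_candidate_answers(args.query, passages, idf_json)
best_first = sorted(scored, key=lambda x: x[0], reverse=True)

This is the same sequence get_answers() runs per request, with the heavy objects cached in module-level globals.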
75 changes: 50 additions & 25 deletions anserini_dependency/api.py
@@ -6,11 +6,13 @@
from flask import Flask, jsonify, request
# FIXME: separate this out to a classifier class where we can switch out the models

from RetrieveSentences import RetrieveSentences
from anserini_dependency.RetrieveSentences import RetrieveSentences
from sm_cnn.bridge import SMModelBridge

app = Flask(__name__)
rs = None
rs = None
smmodel = None
idf_json = None

@app.route("/", methods=['GET'])
def hello():
@@ -41,19 +43,36 @@ def wit_ai_config():
def get_answers(question, num_hits, k):

parser = argparse.ArgumentParser(description='Retrieve Sentences')
parser.add_argument("-index", help="Lucene index", required=True)
parser.add_argument("-embeddings", help="Path of the word2vec index", default="")
parser.add_argument("-topics", help="topics file", default="")
parser.add_argument("-query", help="a single query", default="")
parser.add_argument("-hits", help="max number of hits to return", default=100)
parser.add_argument("-scorer", help="passage scores", default="Idf")
parser.add_argument("-k", help="top-k passages to be retrieved", default=1)
args_raw = parser.parse_args(["-query", question, "-hits", str(num_hits), "-scorer",
"Idf", "-k", str(k), "-index", app.config['Flask']['index']])
parser.add_argument("--index", help="Lucene index", required=True)
parser.add_argument("--embeddings", help="Path of the word2vec index", default="")
parser.add_argument("--topics", help="topics file", default="")
parser.add_argument("--query", help="a single query", default="")
parser.add_argument("--hits", help="max number of hits to return", default=100)
parser.add_argument("--scorer", help="passage scores", default="Idf")
parser.add_argument("--k", help="top-k passages to be retrieved", default=1)
parser.add_argument('--model', help="the path to the saved model file")
parser.add_argument('--dataset', help="the QA dataset folder {TrecQA|WikiQA}", default='../../data/TrecQA/')
parser.add_argument('--no_cuda', action='store_false', help='do not use cuda', dest='cuda')
parser.add_argument('--gpu', type=int, default=0) # Use -1 for CPU
parser.add_argument('--seed', type=int, default=3435)

arg_list = ["--query", question, "--hits", str(num_hits), "--scorer", "Idf", "--k", str(k),
"--index", app.config['Flask']['index'], "--model", app.config['Flask']['model'],
"--gpu", str(app.config['Flask']['gpu']), "--seed", str(app.config['Flask']['seed']),
"--dataset", app.config['Flask']['dataset']]

if not app.config['Flask']['cuda']:
arg_list.append("--no_cuda")

args_raw = parser.parse_args(arg_list)

global rs
global smmodel
global idf_json

if rs == None:
rs = RetrieveSentences(args_raw)
idf_json = rs.getTermIdfJSON()
candidate_passages_scores = rs.getRankedPassages(question, app.config['Flask']['index'], num_hits, k)

candidate_sent_scores = []
@@ -64,18 +83,10 @@ def get_answers(question, num_hits, k):
candidate_passages_sm.append(ps_split[0])
candidate_sent_scores.append((float(ps_split[1]), ps_split[0]))

if app.config['Flask']['model'] == "sm":
path_to_castorini = os.getcwd() + "/.."
model = SMModelBridge(path_to_castorini + '/models/sm_model/sm_model.fixed_ext_feats_paper.puncts_stay',
path_to_castorini + '/data/word2vec/aquaint+wiki.txt.gz.ndim=50.cache',
app.config['Flask']['index'])

idf_json = rs.getTermIdfJSON()
flags = {
"punctuation": "", # ignoring for now you can {keep|remove} punctuation
"dash_words": "" # ignoring for now. you can {keep|split} words-with-hyphens
}
answers_list = model.rerank_candidate_answers(question, candidate_passages_sm, idf_json, flags)
if app.config['Flask']['reranker'] == "sm":
if smmodel == None:
smmodel = SMModelBridge(args_raw)
answers_list = smmodel.rerank_candidate_answers(question, candidate_passages_scores, idf_json)
sorted_answers = sorted(answers_list, key=lambda x: x[0], reverse=True)
else:
# the re-ranking model chosen is idf
@@ -86,14 +97,23 @@ def get_answers(question, num_hits, k):
for score, sent in sorted_answers:
answers.append({'passage': sent, 'score': score})

print(answers)
return answers

if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Start the Flask API at the specified host, port')
parser.add_argument('--config', help='config to use', required=False, type=str, default='config.cfg')
parser.add_argument("--debug", help="print debug info", action="store_true")
parser.add_argument("--model", help="[idf|sm]", default="idf")
parser.add_argument("--reranker", help="[idf|sm]", default="idf")
parser.add_argument('--model', help="the path to the saved model file")
parser.add_argument('--no_cuda', action='store_false', help='do not use cuda', dest='cuda')
parser.add_argument('--gpu', type=int, default=0) # Use -1 for CPU
parser.add_argument('--seed', type=int, default=3435)
parser.add_argument('--dataset', help="the QA dataset folder {TrecQA|WikiQA}", default='../../data/TrecQA/')

args = parser.parse_args()
if not args.cuda:
args.gpu = -1

if not os.path.isfile(args.config):
print("The configuration file ({}) does not exist!".format(args.config))
@@ -110,13 +130,18 @@ def get_answers(question, num_hits, k):
for key, value in config.items(name):
app.config[name][key] = value

app.config['Flask']['reranker'] = args.reranker
app.config['Flask']['model'] = args.model
app.config['Flask']['cuda'] = args.cuda
app.config['Flask']['gpu'] = str(args.gpu)
app.config['Flask']['seed'] = str(args.seed)
app.config['Flask']['dataset'] = str(args.dataset)

print("Config: {}".format(args.config))
print("Index: {}".format(app.config['Flask']['index']))
print("Host: {}".format(app.config['Flask']['host']))
print("Port: {}".format(app.config['Flask']['port']))
print("Re-ranking Model: {}".format(app.config['Flask']['model']))
print("Re-ranking Model: {}".format(app.config['Flask']['reranker']))
print("Debug info: {}".format(args.debug))

app.run(debug=args.debug, host=app.config['Flask']['host'], port=int(app.config['Flask']['port']))
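The rewritten get_answers() keeps the retriever, the SM-CNN bridge, and the IDF JSON in module-level globals, so the Lucene index and the torch model are loaded on the first request and reused afterwards (and the bridge is only built at all when the configured reranker is "sm"). A stripped-down sketch of that caching pattern, with hypothetical helper and variable names standing in for the globals rs, smmodel, and idf_json above:

from anserini_dependency.RetrieveSentences import RetrieveSentences
from sm_cnn.bridge import SMModelBridge

_retriever = None    # plays the role of the global `rs`
_idf_json = None     # plays the role of the global `idf_json`
_reranker = None     # plays the role of the global `smmodel`

def get_pipeline(args, use_sm_reranker):
    """Build the heavy objects once per process; later calls are cheap."""
    global _retriever, _idf_json, _reranker
    if _retriever is None:
        _retriever = RetrieveSentences(args)       # opens the Lucene index
        _idf_json = _retriever.getTermIdfJSON()    # computed once, reused on every request
    if use_sm_reranker and _reranker is None:
        _reranker = SMModelBridge(args)            # loads the saved torch model
    return _retriever, _idf_json, _reranker

Computing idf_json on the API side is also what lets bridge.py below drop its own RetrieveSentences instance and take the IDF weights as an argument instead.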
28 changes: 15 additions & 13 deletions sm_cnn/bridge.py
@@ -10,11 +10,11 @@
from nltk.tokenize import TreebankWordTokenizer
from torchtext import data

from sm_cnn import model
from sm_cnn.external_features import compute_overlap, compute_idf_weighted_overlap, stopped
from sm_cnn.trec_dataset import TrecDataset
from sm_cnn.wiki_dataset import WikiDataset
from anserini_dependency.RetrieveSentences import RetrieveSentences
from sm_cnn import model

sys.modules['model'] = model

@@ -38,9 +38,8 @@ def __init__(self, args):
self.QUESTION = data.Field(batch_first=True)
self.ANSWER = data.Field(batch_first=True)
self.LABEL = data.Field(sequential=False)
self.EXTERNAL = data.Field(sequential=False, tensor_type=torch.FloatTensor, batch_first=True, use_vocab=False,
preprocessing=data.Pipeline(lambda arr, _, train: [float(y) for y in arr]))

self.EXTERNAL = data.Field(sequential=True, tensor_type=torch.FloatTensor, batch_first=True, use_vocab=False,
postprocessing=data.Pipeline(lambda arr, _, train: [float(y) for y in arr]))

if 'TrecQA' in args.dataset:
train, dev, test = TrecDataset.splits(self.QID, self.QUESTION, self.ANSWER, self.EXTERNAL, self.LABEL)
@@ -54,31 +53,32 @@ def __init__(self, args):
self.QUESTION.build_vocab(train, dev, test)
self.ANSWER.build_vocab(train, dev, test)
self.LABEL.build_vocab(train, dev, test)
self.retrieveSentencesObj = RetrieveSentences(args)
self.idf_json = self.retrieveSentencesObj.getTermIdfJSON()

if args.cuda:
self.model = torch.load(args.model, map_location=lambda storage, location: storage.cuda(args.gpu))
else:
self.model = torch.load(args.model, map_location=lambda storage, location: storage)

self.gpu = args.gpu

def parse(self, sentence):
s_toks = TreebankWordTokenizer().tokenize(sentence)
sentence = ' '.join(s_toks).lower()
return sentence

def rerank_candidate_answers(self, question, answers):
def rerank_candidate_answers(self, question, answers, idf_json):
# run through the model
scores_sentences = []
question = self.parse(question)
term_idfs = json.loads(self.idf_json)
term_idfs = json.loads(idf_json)
term_idfs = dict((k, float(v)) for k, v in term_idfs.items())

for term in question.split():
if term not in term_idfs:
term_idfs[term] = 0.0

for answer in answers:
answer = answer.split('\t')[0]
answer = self.parse(answer)
for term in answer.split():
if term not in term_idfs:
@@ -96,12 +96,12 @@ def rerank_candidate_answers(self, question, answers):

fields = [('question', self.QUESTION), ('answer', self.ANSWER), ('ext_feat', self.EXTERNAL)]
example = data.Example.fromlist([question, answer, ext_feats], fields)
this_question = self.QUESTION.numericalize(self.QUESTION.pad([example.question]), args.gpu)
this_answer = self.ANSWER.numericalize(self.ANSWER.pad([example.answer]), args.gpu)
this_external = self.EXTERNAL.numericalize(self.EXTERNAL.pad([example.ext_feat]), args.gpu)
this_question = self.QUESTION.numericalize(self.QUESTION.pad([example.question]), self.gpu)
this_answer = self.ANSWER.numericalize(self.ANSWER.pad([example.answer]), self.gpu)
this_external = self.EXTERNAL.numericalize(self.EXTERNAL.pad([example.ext_feat]), self.gpu)
self.model.eval()
scores = self.model(this_question, this_answer, this_external)
scores_sentences.append((scores[:, 2].cpu().data.numpy(), answer))
scores_sentences.append((scores[:, 2].cpu().data.numpy()[0].tolist(), answer))

return scores_sentences

@@ -127,6 +127,8 @@ def rerank_candidate_answers(self, question, answers):
if not args.cuda:
args.gpu = -1

retrieveSentencesObj = RetrieveSentences(args)
idf_json = retrieveSentencesObj.getTermIdfJSON()
smmodel = SMModelBridge(args)

train_set, dev_set, test_set = 'train', 'dev', 'test'
@@ -152,7 +154,7 @@ def rerank_candidate_answers(self, question, answers):
num_answers = q_counts[question]
q_answers = answers[answers_offset: answers_offset + num_answers]
answers_offset += num_answers
sentence_scores = smmodel.rerank_candidate_answers(question, q_answers)
sentence_scores = smmodel.rerank_candidate_answers(question, q_answers, idf_json)

for score, sentence in sentence_scores:
print('{} Q0 {} 0 {} sm_cnn_bridge.{}.run'.format(
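Two small but load-bearing details in the bridge changes above: rerank_candidate_answers() now receives idf_json from the caller instead of building its own RetrieveSentences, and the model output is unwrapped with scores[:, 2].cpu().data.numpy()[0].tolist(), which turns each score into a plain Python float. That second change is what lets api.py package the results as simple dicts; a tiny illustration with invented values:

# Shape of the re-ranker output after this commit: (float_score, answer_text) pairs.
scores_sentences = [(0.93, "passage text one"), (0.41, "passage text two")]

# api.py then sorts best-first and builds the response entries it returns.
answers = [{'passage': sent, 'score': score}
           for score, sent in sorted(scores_sentences, key=lambda x: x[0], reverse=True)]
# -> [{'passage': 'passage text one', 'score': 0.93},
#     {'passage': 'passage text two', 'score': 0.41}]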