-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathdatabase.py
100 lines (86 loc) · 3.17 KB
/
database.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# INITIALIZE DATABASE
#####################################################################################
import re
import pandas as pd
from envs import *
from tqdm import tqdm
from haystack.schema import Document
from haystack.nodes import PreProcessor
from haystack.document_stores import InMemoryDocumentStore
from qdrant_haystack import QdrantDocumentStore
def _read_dataframe(path, label):
    """Load ``path`` into a DataFrame, picking the reader from its extension.

    Args:
        path: File path ending in ``.csv`` or ``.json``.
        label: Short name of the file (e.g. "FAQ", "WEB") used in errors.

    Returns:
        pandas.DataFrame with the file's contents.

    Raises:
        ValueError: If ``path`` ends in neither ``.csv`` nor ``.json``.
    """
    if path.endswith(".csv"):
        return pd.read_csv(path)
    if path.endswith(".json"):
        return pd.read_json(path)
    # Without this branch an unsupported extension left the DataFrame
    # unbound and the caller later failed with an opaque UnboundLocalError.
    raise ValueError(f"{label} file must end in '.csv' or '.json', got {path!r}")


def _build_faq_documents(faq_df):
    """Build one Document per FAQ row: content is the query, answer in meta.

    Ids are the row positions 0..len(faq_df)-1.
    """
    documents = []
    for idx, (_, row) in enumerate(tqdm(faq_df.iterrows(), desc="Loading FAQ...")):
        documents.append(
            Document(content=row["query"], id=idx, meta={"answer": row["answer"]})
        )
    return documents


def _build_web_documents(web_df):
    """Build a text Document per web row plus one table Document per table.

    Ids are assigned sequentially across both text and table documents.
    """
    documents = []
    idx = 0
    for _, row in tqdm(web_df.iterrows(), desc="Loading web data..."):
        documents.append(Document(content=row["text"], id=idx))
        idx += 1
        tables = row["tables"]
        # A missing cell comes back from pandas as a float NaN (or None);
        # treat it as "no tables" instead of crashing on len(NaN).
        if tables is None or isinstance(tables, float):
            continue
        # NOTE(review): a CSV cannot store lists, so from a .csv WEB_FILE this
        # value is presumably a string and would be iterated per character —
        # confirm WEB_FILE is JSON or parse the column before indexing.
        for table in tables:
            documents.append(Document(content=table, content_type="table", id=idx))
            idx += 1
    return documents


def initialize_db(args):
    """Create the document store and preprocessor, optionally indexing data.

    Args:
        args: Parsed arguments. ``args.dev`` selects an in-memory store (and,
            when reindexing, only the first 10 FAQ / 20 web rows);
            otherwise a Qdrant store at QDRANTDB_URL is used.
            ``args.reindex`` recreates the Qdrant index and loads FAQ_FILE
            into index "faq" and WEB_FILE into index "web".

    Returns:
        Tuple ``(document_store, processor)``.

    Raises:
        ValueError: If FAQ_FILE or WEB_FILE is not a ``.csv``/``.json`` file.
        KeyError: If a data file lacks its required columns.
    """
    print("[+] Initialize database...")
    if args.dev:
        document_store = InMemoryDocumentStore(
            use_gpu=False, use_bm25=ENABLE_BM25, embedding_dim=EMBEDDING_DIM
        )
    else:
        document_store = QdrantDocumentStore(
            url=QDRANTDB_URL,
            embedding_dim=EMBEDDING_DIM,
            timeout=DB_TIMEOUT,
            embedding_field="embedding",
            hnsw_config={"m": 128, "ef_construct": 100},
            similarity="cosine",
            recreate_index=args.reindex,
        )
    processor = PreProcessor(
        clean_empty_lines=True,
        clean_whitespace=True,
        clean_header_footer=True,
        remove_substrings=None,
        split_by="passage",
        split_length=1,
        split_respect_sentence_boundary=False,
        split_overlap=0,
        max_chars_check=10000,
    )
    # NOTE(review): in dev mode the in-memory store starts empty, so unless
    # args.reindex is set nothing is loaded into it — confirm this is intended.
    if args.reindex:
        # Read and validate both files before writing anything, so a bad
        # WEB_FILE does not leave a half-populated store.
        faq_df = _read_dataframe(FAQ_FILE, "FAQ")
        if not ("query" in faq_df.columns and "answer" in faq_df.columns):
            raise KeyError("FAQ file must have two keys 'query' and 'answer'")
        web_df = _read_dataframe(WEB_FILE, "WEB")
        if not ("text" in web_df.columns and "tables" in web_df.columns):
            raise KeyError("WEB file must have two keys 'text' and 'tables'")
        if args.dev:
            # Index only a small sample to keep dev runs fast.
            faq_df = faq_df.head(10)
            web_df = web_df.head(20)

        faq_documents = processor.process(_build_faq_documents(faq_df))
        document_store.write_documents(
            documents=faq_documents,
            index="faq",
            batch_size=DB_BATCH_SIZE,
        )

        web_documents = processor.process(_build_web_documents(web_df))
        document_store.write_documents(
            documents=web_documents,
            index="web",
            batch_size=DB_BATCH_SIZE,
        )
    return document_store, processor