-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclassify.py
133 lines (124 loc) · 5.49 KB
/
classify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""
Create a class called Classifier.
In the init function, load environment variables and save them to self.
Read davinci_base_prompt.txt and save to self.base_prompt.
"""
import dotenv
import os
import openai
import sqlite3
# Load variables from a local .env file into the process environment at import
# time, so the os.getenv() lookups in Classifier.__init__ can see them.
dotenv.load_dotenv()
class Classifier:
    """Classify comment text with an OpenAI completion model and store the results in SQLite.

    Settings are read from environment variables in ``__init__``:

    * ``OPENAI_API_KEY`` - assigned to ``openai.api_key`` before each request.
    * ``SQLITE_DB_NAME`` - database file name, joined onto the ``data`` directory.
    * ``PROMPT_VERSION`` - integer that selects the prompt template file
      ``prompts/davinci_base_prompt_v<N>.txt`` and the response layout that
      ``extract_openapi_response`` parses.
    """

    # A batch run is aborted once more than this many responses failed to parse.
    _MAX_PARSING_ERRORS = 5

    def __init__(self):
        """Read settings from the environment, load the prompt and ensure the results table exists."""
        self.api_key = os.getenv('OPENAI_API_KEY')
        self.database_path = os.path.join("data", os.getenv('SQLITE_DB_NAME'))
        self.prompt_version = int(os.getenv('PROMPT_VERSION'))
        prompt_path = os.path.join("prompts", f'davinci_base_prompt_v{self.prompt_version}.txt')
        # Context manager so the file handle is released as soon as it is read.
        with open(prompt_path, 'r') as prompt_file:
            self.base_prompt = prompt_file.read()
        self.create_classification_table()
        self.parsing_error_count = 0

    def classify_text(self, text, text_id=None):
        """Send ``text`` to the ``text-davinci-003`` completion model.

        The literal ``{text}`` placeholder in the base prompt is replaced by
        ``text`` before the request is made.

        Args:
            text: The text to classify.
            text_id: Optional identifier that is passed through unchanged.

        Returns:
            A ``(response, text_id)`` tuple, where ``response`` is the object
            returned by ``openai.Completion.create``.

        Raises:
            RuntimeError: If the first choice of the response has empty text.
        """
        # TODO: consider top_p to see alternatives considered
        # TODO: tinker w/ frequency_penalty / presence_penalty
        openai.api_key = self.api_key
        prompt = self.base_prompt.replace("{text}", text)
        response = openai.Completion.create(
            model="text-davinci-003",
            prompt=prompt,
            top_p=1,
            temperature=0,
            max_tokens=500,
            frequency_penalty=0,
            presence_penalty=0,
        )
        if not response.choices[0].text:
            # A real exception type is required here: raising a bare string is
            # itself a TypeError and would hide this message.
            raise RuntimeError("Empty response from OpenAI, possibly due to stop words")
        return response, text_id

    @staticmethod
    def _field_after(label, response_text):
        """Return the text that follows ``label`` up to the next newline.

        Raises:
            IndexError: If ``label`` does not occur in ``response_text``.
        """
        return response_text.split(label)[1].split("\n")[0]

    def extract_openapi_response(self, response, input_text_id=None):
        """Flatten a completion response into a row for the ``classifications`` table.

        Which labelled fields are parsed out of the response text depends on
        ``self.prompt_version``:

        * 2 - ``Possible Delusion``, ``Excerpt`` and ``Dominant Theme``.
        * 3 - ``Possible Delusion`` and ``Excerpt``.
        * 4 - the whole response text is taken as the dominant theme.
        * 5 - as 3, but ``Possible Delusion`` is compared case-insensitively.

        Any other version parses nothing. When an expected label is missing,
        the response text is printed, the row is flagged as a parsing error,
        ``self.parsing_error_count`` is incremented, and any fields parsed
        before the failure are kept.

        Args:
            response: Completion response exposing ``choices``, ``created``,
                ``id``, ``model``, ``object`` and ``usage``.
            input_text_id: Identifier of the classified text, stored as-is.

        Returns:
            A 14-tuple: ``(input_text_id, response_text, possible_delusion,
            excerpt, dominant_theme, parsing_error, created, response_id,
            model, object, prompt_tokens, completion_tokens, total_tokens,
            prompt_version)``.
        """
        response_text = response.choices[0].text
        version = self.prompt_version
        possible_delusion = None
        excerpt = None
        dominant_theme = None
        parsing_error = False
        try:
            if version in (2, 3, 5):
                flag = self._field_after("Possible Delusion: ", response_text)
                if version == 5:
                    # Only prompt version 5 tolerates "True"/"TRUE" etc.
                    flag = flag.lower()
                possible_delusion = flag == "true"
                excerpt = self._field_after("Excerpt: ", response_text)
                if version == 2:
                    dominant_theme = self._field_after("Dominant Theme: ", response_text)
            elif version == 4:
                dominant_theme = response_text
        except IndexError:
            # An expected label was absent from the response text.
            print(response_text)
            parsing_error = True
            self.parsing_error_count += 1
        usage = response.usage
        return (
            input_text_id,
            response_text,
            possible_delusion,
            excerpt,
            dominant_theme,
            parsing_error,
            response.created,
            response.id,
            response.model,
            response.object,
            usage.prompt_tokens,
            usage.completion_tokens,
            usage.total_tokens,
            version,
        )

    def create_classification_table(self):
        """Create the ``classifications`` table in the database if it does not exist yet."""
        conn = sqlite3.connect(self.database_path)
        try:
            conn.execute('''CREATE TABLE IF NOT EXISTS classifications (
                input_text_id INT,
                full_response_text STRING,
                is_possible_delusion BOOLEAN,
                excerpt STRING,
                dominant_theme STRING,
                parsing_error BOOLEAN,
                created_ts INT,
                response_id STRING,
                model STRING,
                object STRING,
                prompt_tokens INT,
                completion_tokens INT,
                total_tokens INT,
                prompt_version INT,
                load_ts DEFAULT CURRENT_TIMESTAMP
            )''')
            conn.commit()
        finally:
            # Close the connection even when the statement fails.
            conn.close()

    def classify_batch_and_save(self, batch_size=100):
        """Classify every pending comment and store one result row per comment.

        Rows are read from the ``comments`` table in batches of ``batch_size``.
        A comment is pending when ``to_classify = 1`` and it has no
        ``classifications`` row for the current prompt version. Batches are
        fetched until none remain.

        Each result is committed as soon as it is inserted, so an error later
        in the run does not discard rows that were already classified.

        Args:
            batch_size: Maximum number of comments fetched per query.

        Raises:
            RuntimeError: If more than 5 responses have failed to parse.
        """
        select_pending = """
            SELECT rowid, comment_text FROM comments
            WHERE to_classify = 1
            AND NOT EXISTS (
                SELECT 1 FROM classifications WHERE input_text_id = comments.rowid
                AND prompt_version = ?
            )
            LIMIT ?
        """
        insert_result = """INSERT INTO classifications
            (input_text_id, full_response_text, is_possible_delusion, excerpt, dominant_theme, parsing_error, created_ts, response_id, model, object, prompt_tokens, completion_tokens, total_tokens, prompt_version)
            VALUES (?,?,?,?,?,?,?,?,?,?,?,?,?,?)"""
        conn = sqlite3.connect(self.database_path)
        try:
            c = conn.cursor()
            while True:
                # Values are bound as parameters rather than formatted into the SQL.
                c.execute(select_pending, (self.prompt_version, batch_size))
                results = c.fetchall()
                print("Executing Batch Size: ", len(results))
                if not results:
                    break
                for row in results:
                    print("Classifying: ", row)
                    text_id, text = row
                    response, text_id = self.classify_text(text, text_id)
                    response_data = self.extract_openapi_response(response, input_text_id=text_id)
                    c.execute(insert_result, response_data)
                    conn.commit()
                    if self.parsing_error_count > self._MAX_PARSING_ERRORS:
                        raise RuntimeError(f"Parsing Error Count > {self._MAX_PARSING_ERRORS}")
        finally:
            conn.close()
if __name__ == "__main__":
    # Script entry point: build a Classifier from the environment and run the
    # batch classification with its default batch size.
    Classifier().classify_batch_and_save()