-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun_google.py
executable file
·66 lines (54 loc) · 2.41 KB
/
run_google.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#!/usr/bin/env -S python3 -u
import re
import os
import csv
import sys
import codecs
import argparse
import time
from collections import defaultdict
import google.generativeai as genai
DEFAULT_SYSTEM_PROMPT = "You are a master of logical thinking. You carefully analyze the premises step by step, take detailed notes and draw intermediate conclusions based on which you can find the final answer to any question."

# Pause after every quiz, in seconds.
# NOTE(review): presumably a guard against API rate limits -- confirm against
# the quota of the account in use.
REQUEST_INTERVAL_SECONDS = 6

# Back-off between failed API attempts: start small, double on every
# consecutive failure, never wait longer than the cap.
RETRY_INITIAL_DELAY_SECONDS = 1
RETRY_MAX_DELAY_SECONDS = 60

# Compiled once instead of on every quiz.  DOTALL lets "." match newlines so
# that an answer the model wrapped onto its own line(s) is still extracted.
ANSWER_PATTERN = re.compile(r'<ANSWER>(.*?)</ANSWER>', re.DOTALL)


def parse_args(argv=None):
    """Parse the command line.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The argparse namespace with ``model`` (required) and
        ``system_prompt`` (``None`` when ``-s`` is absent,
        ``DEFAULT_SYSTEM_PROMPT`` when ``-s`` is given without a value).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", help="Model name.", required=True)
    parser.add_argument("-s", "--system-prompt", help="Use given system prompt. By default, the system prompt is not used. When this option is passed without a value, the default system prompt value is used: " + repr(DEFAULT_SYSTEM_PROMPT), const=DEFAULT_SYSTEM_PROMPT, default=None, nargs='?')
    return parser.parse_args(argv)


def decode_quiz(raw_quiz: str) -> str:
    """Turn backslash escapes (e.g. a literal ``\\n``) in *raw_quiz* into real characters.

    The text is round-tripped through UTF-8 bytes so that non-ASCII
    characters pass through unchanged while the escapes are resolved.
    """
    return codecs.escape_decode(bytes(raw_quiz, "utf-8"))[0].decode("utf-8")


def extract_answer(model_response: str):
    """Return the stripped text of the first ``<ANSWER>...</ANSWER>`` pair.

    Returns:
        The answer string, or ``None`` when the response has no complete
        answer tag pair.
    """
    match = ANSWER_PATTERN.search(model_response)
    if match is None:
        return None
    return match.group(1).strip()


def generate_with_retry(model, quiz, initial_delay=RETRY_INITIAL_DELAY_SECONDS, max_delay=RETRY_MAX_DELAY_SECONDS):
    """Ask *model* for a response to *quiz*, retrying until one is obtained.

    Args:
        model: Object exposing ``generate_content(prompt)`` whose result has
            a ``text`` attribute.
        quiz: Prompt text sent to the model.
        initial_delay: Seconds to wait after the first failure.
        max_delay: Upper bound for the doubled delay.

    Returns:
        The text of the first successful response.

    Every exception is printed and the request is retried without an attempt
    limit.  Consecutive failures are spaced out with capped exponential
    back-off so a persistent error does not turn into a tight request loop.
    ``KeyboardInterrupt`` is not an ``Exception`` subclass, so Ctrl-C still
    aborts the run.
    """
    delay = initial_delay
    while True:
        try:
            # Reading ``.text`` stays inside the ``try``: a failure there is
            # retried exactly like a failed request.
            return model.generate_content(quiz).text
        except Exception as ex:
            print(ex)
            time.sleep(delay)
            delay = min(2 * delay, max_delay)


def format_summary_line(relation_name, num_correct, num_incorrect, num_missing, num_all) -> str:
    """Format the per-relation result line (percentage of correct answers plus raw counts)."""
    percent_correct = 100 * num_correct / num_all
    return f"{relation_name}: {percent_correct:.2f} (C: {num_correct}, I: {num_incorrect}, M: {num_missing} A: {num_all})"


def main():
    """Read quizzes as CSV from stdin, query the model and print per-relation scores.

    Each non-empty CSV row must have four columns: distance, relation name,
    correct answer and quiz text.  A response is counted as correct or
    incorrect by comparing its extracted answer with the correct answer, and
    as missing when no answer tag pair is found.
    """
    args = parse_args()
    system_prompt = args.system_prompt
    model = genai.GenerativeModel(args.model, system_instruction=system_prompt)

    correct_answers = defaultdict(int)
    incorrect_answers = defaultdict(int)
    missing_answers = defaultdict(int)
    all_answers = defaultdict(int)

    quiz_reader = csv.reader(sys.stdin, delimiter=',', quotechar='"')
    for row in quiz_reader:
        if not row:
            # csv.reader yields an empty list for a blank line; skipping it
            # keeps a stray (e.g. trailing) blank line from aborting the run.
            continue
        _distance, relation_name, correct_answer, raw_quiz = row
        quiz = decode_quiz(raw_quiz)

        if system_prompt:
            print(f"System prompt: {system_prompt}")
        print(f"User prompt: {quiz}")

        model_response = generate_with_retry(model, quiz)
        print(f"Response: {model_response}")

        all_answers[relation_name] += 1
        answer = extract_answer(model_response)
        if answer is None:
            missing_answers[relation_name] += 1
        elif answer == correct_answer:
            correct_answers[relation_name] += 1
        else:
            incorrect_answers[relation_name] += 1

        time.sleep(REQUEST_INTERVAL_SECONDS)

    for relation_name, num_all in all_answers.items():
        print(format_summary_line(
            relation_name,
            correct_answers[relation_name],
            incorrect_answers[relation_name],
            missing_answers[relation_name],
            num_all,
        ))


if __name__ == "__main__":
    main()