-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathengine.py
307 lines (257 loc) · 10.4 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
import os
import random
from enum import Enum
from typing import List, Tuple
import streamlit as st
from openai import OpenAI
# Default user-prompt template sent to the spymaster before each hint.
# It is filled in by `Spymaster.prompt` with: {SLF} the current team's words,
# {OPP} the opponent's words, {NTR} the neutral words and {KLL} the killer word.
DEFAULT_SPYMASTER_PROMPT = """The words to guess on your team are: {SLF}.
The words on your opponent's team NOT to guess are: {OPP}.
The neutral words not to guess are: {NTR}.
The forbidden word REALLY NOT to guess is: {KLL}.
Give me your best hint."""
# Default system message; the required "WORD - NUMBER" answer format is what
# `Spymaster.give_hint` parses.
DEFAULT_SPYMASTER_INSTRUCT = """You are playing Codenames as a bold and creative spymaster giving hints.
Your answers should be in the format WORD - NUMBER."""
# Sampling temperature passed to the chat completion call.
DEFAULT_SPYMASTER_TEMPERATURE = 0.9
# Language code -> display name. The codes presumably match the word-list file
# stems returned by `get_lang_options` (e.g. "en" for words_lists/en.txt).
FULL_LANGUAGES = {
    "cz": "Czech",
    "de": "German",
    "en": "English",
    "es": "Spanish",
    "fr": "French",
    "it": "Italian",
}
@st.cache_data
def get_lang_options() -> List[str]:
    """List the language codes for which a word list exists.

    A language code is the file name, minus its ``.txt`` suffix, of every
    ``.txt`` file in the ``words_lists`` directory, in ``os.listdir`` order.
    """
    suffix = ".txt"
    options: List[str] = []
    for file_name in os.listdir("words_lists"):
        if file_name.endswith(suffix):
            options.append(file_name[: -len(suffix)])
    return options
@st.cache_data
def get_default_words_list(lang: str = "en") -> str:
    """Return the default word list for ``lang`` as a single upper-cased string.

    The whole file ``words_lists/<lang>.txt`` is read and upper-cased. It is
    returned as one string (not split into a list); ``generate_board`` splits
    it into words with ``splitlines``.

    :param lang: Language code, i.e. the stem of a ``.txt`` file in ``words_lists``
    :return: The upper-cased file contents
    """
    path = os.path.join("words_lists", f"{lang}.txt")
    # Non-English lists may contain accented characters: decode explicitly
    # instead of relying on the platform default encoding.
    # NOTE(review): assumes the word-list files are stored as UTF-8 — confirm.
    with open(path, "r", encoding="utf-8") as open_file:
        return open_file.read().upper()
@st.cache_resource
def get_openai_client(api_key: str) -> Tuple[OpenAI, List[str]]:
    """Create an OpenAI client and list the ids of the models it can access.

    Cached with ``st.cache_resource`` per ``api_key``, so the client is built
    and the models are listed once per key.

    :param api_key: OpenAI API key
    :return: ``(client, available_models)`` where ``available_models`` holds
        the ``id`` of every entry returned by ``client.models.list()``
    """
    client = OpenAI(api_key=api_key)
    # Performs an API request; any error it raises propagates to the caller.
    available_models = [model.id for model in client.models.list()]
    return client, available_models
@st.cache_data
def generate_board(
    words_list: str, side_length: int = 5, random_seed: int = 42
) -> Tuple[List[str], List[int]]:
    """Generate a board of ``side_length**2`` words and their team assignment.

    :param words_list: Candidate words, one per line; blank lines are ignored
    :param side_length: Side of the square board (at least 5)
    :param random_seed: Seed making the board reproducible
    :return: ``(words, team_assignment)`` where ``team_assignment[i]`` is the
        assignment of ``words[i]``: -1 killer, 0 neutral, 1 opponent words
        (OPP), 2 own words (SLF) of the team that plays first
    :raises ValueError: If the board has 16 cells or fewer, or if there are
        fewer candidate words than cells
    """
    n_cards = side_length**2
    # The fixed 1 killer / 7 neutral / 8 opponent split needs >= 17 cells so
    # that the starting team gets at least one word.
    if n_cards <= 16:
        raise ValueError(f"side_length must be at least 5, got {side_length}")
    candidates = [word.strip() for word in words_list.splitlines() if word.strip()]
    if len(candidates) < n_cards:
        raise ValueError(
            f"Need at least {n_cards} words for a {side_length}x{side_length} "
            f"board, got {len(candidates)}"
        )
    # A local generator gives the same boards as seeding the global `random`
    # module, without clobbering the global random state.
    rng = random.Random(random_seed)
    rng.shuffle(candidates)
    # TODO: scale the neutral/opponent counts with larger boards; for now all
    # extra cells go to the starting team (assignment 2).
    team_assignment = [-1] * 1 + [0] * 7 + [1] * 8 + [2] * (n_cards - 16)
    rng.shuffle(team_assignment)
    return candidates[:n_cards], team_assignment
def generate_spymaster_prompt(
    words: List[str], team_assignment: List[int]
) -> Tuple[List[str], List[str], List[str], List[str]]:
    """Separate the words on the board by their team assignment.

    Despite its name, this builds no prompt text: it only buckets ``words``
    using the parallel ``team_assignment`` list. Words whose assignment is not
    -1, 0, 1 or 2 are dropped, and trailing items of the longer list are ignored.

    :param words: Words on the board
    :param team_assignment: Assignment of each word: -1 killer, 0 neutral,
        1 opponent's words, 2 own words
    :return: ``(SLF, OPP, NTR, KLL)``: own, opponent, neutral and killer
        words, each in board order
    """
    SLF, OPP, NTR, KLL = [], [], [], []
    buckets = {2: SLF, 1: OPP, 0: NTR, -1: KLL}
    for word, assignment in zip(words, team_assignment):
        bucket = buckets.get(assignment)
        if bucket is not None:
            bucket.append(word)
    return SLF, OPP, NTR, KLL
class MessageType(Enum):
    """Tag attached to every entry of a Spymaster chat history.

    ``Spymaster.get_history`` only renders ``Guess`` and ``Hint`` entries;
    ``Prompt`` and ``Instruct`` entries are only part of the model context.
    """

    # User message carrying the formatted board prompt
    Prompt = 0
    # User message recording a word picked on the board
    Guess = 1
    # Assistant message holding an accepted "WORD - NUMBER" hint
    Hint = 2
    # System message carrying the spymaster instructions
    Instruct = 3
class Spymaster:
    """LLM-backed Codenames spymaster holding the game state of both teams.

    Team indices are 0 and 1. Team 1 (shown in blue) starts and plays the
    ``slf`` words (assignment 2); team 0 (shown in red) plays the ``opp``
    words (assignment 1). Each team has its own chat history, seeded with the
    system instruct.
    """

    def __init__(
        self, client, model_name: str, use_last_prompt_only: bool = False
    ) -> None:
        """
        :param client: OpenAI client used to request hints
        :param model_name: Chat model used for completions
        :param use_last_prompt_only: If True, send only the system instruct and
            the latest prompt to the model instead of the whole chat history
        """
        self.client = client
        self.model_name = model_name
        self._prompt = DEFAULT_SPYMASTER_PROMPT
        # Active hint word; None means a new hint must be requested
        self.current_hint_word = None
        # Number announced with the active hint (for display)
        self.og_hint_num = 0
        # Guesses left minus one; going below 0 ends the turn
        self.current_hint_num = 0
        self.temperature = DEFAULT_SPYMASTER_TEMPERATURE
        # One history per team index: list of (MessageType, chat message dict)
        self.chat_history = [
            [
                (
                    MessageType.Instruct,
                    {"role": "system", "content": DEFAULT_SPYMASTER_INSTRUCT},
                ),
            ]
            for _ in range(2)
        ]
        self.use_last_prompt_only = use_last_prompt_only
        self.current_team = 1
        # NOTE: self.slf / self.opp / self.ntr / self.kll are only created by
        # update_words(), which must be called before playing.

    def words(self, team: int) -> List[str]:
        """Return the words of team index `team` that are still on the board"""
        return self.slf if team == 1 else self.opp

    def get_history(self, team: int) -> str:
        """Return the guesses and hints of `team` formatted as markdown.

        Hints are bold, guesses are bullet points; prompts and the system
        instruct are left out.
        """
        return "\n\n".join(
            f"**{x['content']}**"
            if msg_type == MessageType.Hint
            else f" * {x['content']}"
            for msg_type, x in self.chat_history[team]
            if msg_type in [MessageType.Guess, MessageType.Hint]
        )

    def update_words(self, words: List[str], team_assignment: List[int]) -> None:
        """Split the board words into own / opponent / neutral / killer lists"""
        self.slf, self.opp, self.ntr, self.kll = generate_spymaster_prompt(
            words, team_assignment
        )

    def update_prompt(self, prompt: str) -> None:
        """Replace the prompt template (placeholders: SLF, OPP, NTR, KLL)"""
        self._prompt = prompt

    def update_instruct(self, instruct: str) -> None:
        """Replace the system instruct at the head of both teams' histories"""
        for lst in self.chat_history:
            lst[0][1]["content"] = instruct

    def update_temperature(self, t: float) -> None:
        """Set the sampling temperature passed to the model"""
        self.temperature = t

    def use_whole_history(self, enabled: bool) -> None:
        """Whether to send the whole chat history to the model or not"""
        self.use_last_prompt_only = not enabled

    @property
    def prompt(self) -> str:
        """Prompt template formatted with the words currently on the board"""
        return self._prompt.format(
            SLF=", ".join(self.words(self.current_team)),
            OPP=", ".join(self.words(1 - self.current_team)),
            NTR=", ".join(self.ntr),
            KLL=", ".join(self.kll),
        )

    def remove(self, word: str, team: int) -> None:
        """Action of guessing the given `word` which is assigned to the given `team`

        The turn ends on a neutral or opponent word, or once the guesses for
        the current hint are used up. It is kept when the guess ends the game,
        so that `play` reports the result to the team that made the guess.

        :param word: Word clicked on the board
        :param team: Word's team assignment. -1 for the killer card (instant
            loss), 0 for a neutral card, 1 for the red team's words and 2 for
            the blue team's words
        """
        self.chat_history[self.current_team].append(
            (
                MessageType.Guess,
                {"role": "user", "content": f"Your teammate picked {word}"},
            ),
        )
        # Guessed the killer card :( - the turn is kept so play() reports the loss
        if team == -1:
            self.kll.remove(word)
        # Guessed a neutral or opponent word: end turn
        elif team == 0 or team != (self.current_team + 1):
            (self.ntr if team == 0 else self.words(1 - self.current_team)).remove(word)
            # Unless that was the opponent's last word: keep the turn so that
            # play() reports the loss to the team that actually guessed it
            if self.words(1 - self.current_team):
                self.end_turn()
        # Guessed a correct word: we only continue if we have left over guesses + 1
        else:
            self.current_hint_num -= 1
            self.words(self.current_team).remove(word)
            # Keep the turn when the team just found its last word, so that
            # play() reports its win instead of a loss for the other team
            if self.current_hint_num < 0 and self.words(self.current_team):
                self.end_turn()

    def end_turn(self) -> None:
        """End turn immediately"""
        self.current_hint_word = None
        self.current_team = 1 - self.current_team

    @staticmethod
    def _parse_hint(content) -> Tuple[str, int]:
        """Parse a `WORD - NUMBER` answer into (upper-cased word, number).

        The split happens on the last "-" so hyphenated hint words survive;
        dots are stripped from the number. A None content is treated as empty.

        :raises ValueError: If there is no "-", the word is empty or the number
            is not an integer
        """
        word, sep, number = (content or "").rpartition("-")
        word = word.strip().upper()
        if not sep or not word:
            raise ValueError(f"Expected 'WORD - NUMBER', got {content!r}")
        return word, int(number.strip().replace(".", ""))

    def _is_on_board(self, word: str) -> bool:
        """Whether `word` is one of the words still on the board"""
        return any(word in group for group in (self.slf, self.opp, self.ntr, self.kll))

    def give_hint(self, num_retries: int = 2, debug: bool = False) -> None:
        """Generates hint by prompting the language model

        The formatted prompt is appended to the current team's history, then
        the model is queried up to `num_retries + 1` times until it answers a
        well-formed `WORD - NUMBER` hint with NUMBER >= 1 and WORD not on the
        board. On the last attempt any well-formed hint is accepted, even an
        invalid one. If no answer parses, `current_hint_word` is left
        unchanged and `current_hint_num` is -1.

        :param num_retries: Number of retries in case the generated hint is
            badly formatted or invalid
        :param debug: If True, print the chat history before querying
        """
        history = self.chat_history[self.current_team]
        history.append((MessageType.Prompt, {"role": "user", "content": self.prompt}))
        if debug:
            print(
                "\n\n".join(
                    f"{msg_type} - {x['content']}" for msg_type, x in history
                )
            )
        self.current_hint_num = -1
        # `retries_left` counts down to 0, which marks the last attempt
        for retries_left in range(num_retries, -1, -1):
            messages = (
                [history[0][1], history[-1][1]]
                if self.use_last_prompt_only
                else [msg for _, msg in history]
            )
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=self.temperature,
            )
            try:
                hint_word, hint_num = self._parse_hint(
                    completion.choices[0].message.content
                )
            except ValueError:
                continue  # Badly formatted answer: ask again
            # Need to give at least one number and not give a word on the board
            is_valid = hint_num >= 1 and not self._is_on_board(hint_word)
            if is_valid or retries_left == 0:
                self.current_hint_word = hint_word
                self.current_hint_num = self.og_hint_num = hint_num
                history.append(
                    (
                        MessageType.Hint,
                        {
                            "role": "assistant",
                            "content": f"{hint_word} - {hint_num}",
                        },
                    )
                )
                return

    def play(self) -> Tuple[str, int]:
        """Return the hint-box text for the current game state and a status.

        :return: ``(message, status)`` where status is -1 if the current team
            lost, 1 if it won and 0 if the game continues; in the last case a
            new hint is requested from the model when none is active
        """
        fmt = ":blue[{}]" if self.current_team == 1 else ":red[{}]"
        # Check if we lost by guessing the killer card in the previous action
        if len(self.kll) == 0:
            return fmt.format("You found the assassin. You lost ☠️"), -1
        # Check if we lost by guessing the opponent's last word
        if len(self.words(1 - self.current_team)) == 0:
            return fmt.format("You guessed for the other team. You lost ☠️"), -1
        # Check if we won
        if len(self.words(self.current_team)) == 0:
            return fmt.format("You guessed all your cards. You win 🪩 !"), 1
        # Otherwise, give a hint and continue
        if self.current_hint_word is None:
            self.give_hint()
        return (
            fmt.format(f"{self.current_hint_word} - {self.og_hint_num}")
            + " \n"
            + fmt.format(f"({self.current_hint_num + 1} guesses left)"),
            0,
        )
@st.cache_resource
def init_spymaster(
    _client: OpenAI, model_name: str, words: List[str], team_assignment: List[int]
) -> Spymaster:
    """Build a Spymaster for `model_name` and load the board into it.

    Thanks to ``st.cache_resource``, calls with the same ``model_name``,
    ``words`` and ``team_assignment`` return the same (mutable) instance. The
    leading underscore on ``_client`` keeps Streamlit from hashing it.
    NOTE: cache_resource objects are shared across reruns and sessions.
    """
    new_spymaster = Spymaster(_client, model_name)
    new_spymaster.update_words(words, team_assignment)
    return new_spymaster