Skip to content

Commit 03b4ae6

Browse files
authored
Merge pull request #67 from biocypher/xinference-system-msg
adjustment to message appending to account for ..
2 parents fa9cf75 + 8b42346 commit 03b4ae6

File tree

1 file changed

+59
-10
lines changed

1 file changed

+59
-10
lines changed

biochatter/llm_connect.py

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,13 @@ def append_system_message(self, message: str):
120120
),
121121
)
122122

123+
def append_ca_message(self, message: str):
124+
self.ca_messages.append(
125+
SystemMessage(
126+
content=message,
127+
),
128+
)
129+
123130
def append_user_message(self, message: str):
124131
self.messages.append(
125132
HumanMessage(
@@ -133,19 +140,11 @@ def setup(self, context: str):
133140
"""
134141
for msg in self.prompts["primary_model_prompts"]:
135142
if msg:
136-
self.messages.append(
137-
SystemMessage(
138-
content=msg,
139-
),
140-
)
143+
self.append_system_message(msg)
141144

142145
for msg in self.prompts["correcting_agent_prompts"]:
143146
if msg:
144-
self.ca_messages.append(
145-
SystemMessage(
146-
content=msg,
147-
),
148-
)
147+
self.append_ca_message(msg)
149148

150149
self.context = context
151150
msg = f"The topic of the research is {context}."
@@ -384,6 +383,56 @@ def load_models(self):
384383
# names.append(name)
385384
# return names
386385

386+
def append_system_message(self, message: str):
387+
"""
388+
We override the system message addition because Xinference does not
389+
accept multiple system messages. We concatenate them if there are
390+
multiple.
391+
392+
Args:
393+
message (str): The message to append.
394+
"""
395+
# if there is not already a system message in self.messages
396+
if not any(isinstance(m, SystemMessage) for m in self.messages):
397+
self.messages.append(
398+
SystemMessage(
399+
content=message,
400+
),
401+
)
402+
else:
403+
# if there is a system message, append to the last one
404+
for i, msg in enumerate(self.messages):
405+
if isinstance(msg, SystemMessage):
406+
self.messages[i].content += f"\n{message}"
407+
break
408+
409+
def append_ca_message(self, message: str):
410+
"""
411+
412+
We also override the system message addition for the correcting agent,
413+
likewise because Xinference does not accept multiple system messages. We
414+
concatenate them if there are multiple.
415+
416+
TODO this currently assumes that the correcting agent is the same model
417+
as the primary one.
418+
419+
Args:
420+
message (str): The message to append.
421+
"""
422+
# if there is not already a system message in self.messages
423+
if not any(isinstance(m, SystemMessage) for m in self.ca_messages):
424+
self.ca_messages.append(
425+
SystemMessage(
426+
content=message,
427+
),
428+
)
429+
else:
430+
# if there is a system message, append to the last one
431+
for i, msg in enumerate(self.ca_messages):
432+
if isinstance(msg, SystemMessage):
433+
self.ca_messages[i].content += f"\n{message}"
434+
break
435+
387436
def _primary_query(self):
388437
"""
389438

0 commit comments

Comments
 (0)