@@ -120,6 +120,13 @@ def append_system_message(self, message: str):
120
120
),
121
121
)
122
122
123
+ def append_ca_message (self , message : str ):
124
+ self .ca_messages .append (
125
+ SystemMessage (
126
+ content = message ,
127
+ ),
128
+ )
129
+
123
130
def append_user_message (self , message : str ):
124
131
self .messages .append (
125
132
HumanMessage (
@@ -133,19 +140,11 @@ def setup(self, context: str):
133
140
"""
134
141
for msg in self .prompts ["primary_model_prompts" ]:
135
142
if msg :
136
- self .messages .append (
137
- SystemMessage (
138
- content = msg ,
139
- ),
140
- )
143
+ self .append_system_message (msg )
141
144
142
145
for msg in self .prompts ["correcting_agent_prompts" ]:
143
146
if msg :
144
- self .ca_messages .append (
145
- SystemMessage (
146
- content = msg ,
147
- ),
148
- )
147
+ self .append_ca_message (msg )
149
148
150
149
self .context = context
151
150
msg = f"The topic of the research is { context } ."
@@ -384,6 +383,56 @@ def load_models(self):
384
383
# names.append(name)
385
384
# return names
386
385
386
+ def append_system_message (self , message : str ):
387
+ """
388
+ We override the system message addition because Xinference does not
389
+ accept multiple system messages. We concatenate them if there are
390
+ multiple.
391
+
392
+ Args:
393
+ message (str): The message to append.
394
+ """
395
+ # if there is not already a system message in self.messages
396
+ if not any (isinstance (m , SystemMessage ) for m in self .messages ):
397
+ self .messages .append (
398
+ SystemMessage (
399
+ content = message ,
400
+ ),
401
+ )
402
+ else :
403
+ # if there is a system message, append to the last one
404
+ for i , msg in enumerate (self .messages ):
405
+ if isinstance (msg , SystemMessage ):
406
+ self .messages [i ].content += f"\n { message } "
407
+ break
408
+
409
+ def append_ca_message (self , message : str ):
410
+ """
411
+
412
+ We also override the system message addition for the correcting agent,
413
+ likewise because Xinference does not accept multiple system messages. We
414
+ concatenate them if there are multiple.
415
+
416
+ TODO this currently assumes that the correcting agent is the same model
417
+ as the primary one.
418
+
419
+ Args:
420
+ message (str): The message to append.
421
+ """
422
+ # if there is not already a system message in self.messages
423
+ if not any (isinstance (m , SystemMessage ) for m in self .ca_messages ):
424
+ self .ca_messages .append (
425
+ SystemMessage (
426
+ content = message ,
427
+ ),
428
+ )
429
+ else :
430
+ # if there is a system message, append to the last one
431
+ for i , msg in enumerate (self .ca_messages ):
432
+ if isinstance (msg , SystemMessage ):
433
+ self .ca_messages [i ].content += f"\n { message } "
434
+ break
435
+
387
436
def _primary_query (self ):
388
437
"""
389
438
0 commit comments