
Commit f484cc2

committed Sep 17, 2024
inference chat mode
1 parent f87243d

3 files changed: +66 −2 lines changed
 

.vscode/settings.json  (+2 −1)
@@ -4,5 +4,6 @@
     ],
     "python.testing.unittestEnabled": false,
     "python.testing.pytestEnabled": true,
-    "python.venvPath": "~/.cache/pypoetry/virtualenvs"
+    "python.venvPath": "~/.cache/pypoetry/virtualenvs",
+    "makefile.configureOnOpen": false
 }
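
Two small editor-settings tweaks: the existing "python.venvPath" entry gains a trailing comma so a new key can follow it, and "makefile.configureOnOpen": false keeps the VS Code Makefile Tools extension from auto-configuring (and prompting about) the workspace each time it is opened.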

app/projects/inference.py  (+62 −1)
@@ -2,6 +2,7 @@
 from fastapi import HTTPException
 from requests import Session
 from app import tools
+from app.chat import Chat
 from app.guard import Guard
 from app.models.models import ChatModel, QuestionModel, User
 from app.project import Project
@@ -13,7 +14,67 @@
 class Inference(ProjectBase):

     def chat(self, project: Project, chatModel: ChatModel, user: User, db: Session):
-        raise HTTPException(status_code=400, detail='{"error": "Chat mode not available for this project type."}')
+        chat = Chat(chatModel)
+        output = {
+            "question": chatModel.question,
+            "type": "inference",
+            "sources": [],
+            "guard": False,
+            "tokens": {
+                "input": 0,
+                "output": 0
+            },
+            "project": project.model.name,
+            "id": chat.id
+        }
+
+        if project.model.guard:
+            guard = Guard(project.model.guard, self.brain, db)
+            if guard.verify(chatModel.question):
+                output["answer"] = project.model.censorship or self.brain.defaultCensorship
+                output["guard"] = True
+                output["tokens"] = {
+                    "input": tools.tokens_from_string(output["question"]),
+                    "output": tools.tokens_from_string(output["answer"])
+                }
+                yield output
+
+        model = self.brain.getLLM(project.model.llm, db)
+
+        sysTemplate = project.model.system or self.brain.defaultSystem
+        model.llm.system_prompt = sysTemplate
+
+        if not chat.memory.get_all():
+            chat.memory.chat_store.add_message(chat.memory.chat_store_key, ChatMessage(role="system", content=sysTemplate))
+
+        chat.memory.chat_store.add_message(chat.memory.chat_store_key, ChatMessage(role="user", content=chatModel.question))
+        messages = chat.memory.get_all()
+
+        try:
+            if chatModel.stream:
+                respgen = model.llm.stream_chat(messages)
+                response = ""
+                for text in respgen:
+                    response += text.delta
+                    yield "data: " + json.dumps({"text": text.delta}) + "\n\n"
+                output["answer"] = response
+                chat.memory.chat_store.add_message(chat.memory.chat_store_key, ChatMessage(role="assistant", content=response))
+                yield "data: " + json.dumps(output) + "\n"
+                yield "event: close\n\n"
+            else:
+                resp = model.llm.chat(messages)
+                output["answer"] = resp.message.content.strip()
+                output["tokens"] = {
+                    "input": tools.tokens_from_string(output["question"]),
+                    "output": tools.tokens_from_string(output["answer"])
+                }
+                chat.memory.chat_store.add_message(chat.memory.chat_store_key, ChatMessage(role="assistant", content=resp.message.content.strip()))
+                yield output
+        except Exception as e:
+            if chatModel.stream:
+                yield "data: Inference failed\n"
+                yield "event: error\n\n"
+            raise e

     def question(self, project: Project, questionModel: QuestionModel, user: User, db: Session):
         output = {
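
With this change Inference.chat() becomes a generator. In streaming mode it emits Server-Sent-Events-style chunks: one "data:" event per token delta ({"text": ...}), a final "data:" event carrying the full output dict, then "event: close" (or "data: Inference failed" plus "event: error" on failure). In non-streaming mode it yields a single output dict. Below is a minimal client-side sketch of consuming the streaming form; the endpoint path, request payload, and error handling are assumptions for illustration, not part of this commit, and only the "data:"/"event:" framing comes from the code above.

# Hypothetical consumer, not part of this commit. Assumes the chat generator is
# exposed at POST {base_url}/projects/{project}/chat as a text/event-stream
# response; only the "data:" / "event:" framing is taken from the diff above.
import json
import requests

def stream_chat(base_url, project_name, question):
    payload = {"question": question, "stream": True}  # ChatModel-shaped body (assumed)
    with requests.post(f"{base_url}/projects/{project_name}/chat",
                       json=payload, stream=True) as resp:
        resp.raise_for_status()
        final = None
        for line in resp.iter_lines(decode_unicode=True):
            if not line:
                continue  # blank lines separate SSE events
            if line.startswith("event:"):
                if line.strip() == "event: close":
                    break  # server signalled the end of the stream
                continue  # e.g. "event: error"; its "data:" line was already handled
            if line.startswith("data: "):
                body = line[len("data: "):]
                try:
                    event = json.loads(body)
                except json.JSONDecodeError:
                    raise RuntimeError(body)  # e.g. "Inference failed"
                if "text" in event:
                    print(event["text"], end="", flush=True)  # incremental token delta
                else:
                    final = event  # full output dict: answer, tokens, guard, id, ...
        return final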

download.py  (+2)
@@ -0,0 +1,2 @@
+import nltk
+nltk.download('averaged_perceptron_tagger')
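
download.py looks like a one-off setup helper: running python download.py fetches NLTK's averaged perceptron tagger into the local NLTK data directory so it is available at runtime, presumably for a document-parsing dependency elsewhere in the app.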
