Skip to content

Commit 308128b

Browse files
committed
llm - openai, gpt4all, llama
1 parent adc2380 commit 308128b

File tree

10 files changed

+181
-69
lines changed

10 files changed

+181
-69
lines changed

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,4 +162,7 @@ cython_debug/
162162
embeddings/*
163163
!embeddings/.gitkeep
164164
projects/*
165-
!projects/.gitkeep
165+
!projects/.gitkeep
166+
models/*
167+
!models/.gitkeep
168+
.DS_Store

app/brain.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,39 @@
11
from fastapi import HTTPException
22
from langchain.text_splitter import CharacterTextSplitter
3+
from langchain.chains import RetrievalQA
4+
from langchain import OpenAI
5+
from langchain.llms import GPT4All, LlamaCpp
36

47
from app.project import Project
58

9+
LLMS = {
10+
"openai": (OpenAI, {"temperature": 0, "model_name": "text-davinci-003"}),
11+
"llamacpp": (LlamaCpp, {"model": "./models/ggml-model-q4_0.bin"}),
12+
"gpt4all": (GPT4All, {"model": "./models/ggml-gpt4all-j-v1.3-groovy.bin", "backend": "gptj", "n_ctx": 1000}),
13+
}
14+
15+
616
class Brain:
717
def __init__(self):
8-
self.projects = []
18+
self.projects = []
19+
20+
self.text_splitter = CharacterTextSplitter(
21+
separator=" ", chunk_size=1024, chunk_overlap=0)
22+
23+
self.llmCache = {}
924

10-
self.text_splitter = CharacterTextSplitter(separator=" ", chunk_size=1024, chunk_overlap=0)
25+
def loadLLM(self, llmModel, **kwargs):
26+
if llmModel in self.llmCache:
27+
return self.llmCache[llmModel]
28+
else:
29+
if llmModel in LLMS:
30+
loader_class, llm_args = LLMS[llmModel]
31+
llm = loader_class(**llm_args, **kwargs)
32+
self.llmCache[llmModel] = llm
33+
return llm
34+
else:
35+
raise HTTPException(
36+
status_code=500, detail='{"error": "Invalid LLM type."}')
1137

1238
def listProjects(self):
1339
return [project.model.name for project in self.projects]
@@ -20,18 +46,31 @@ def createProject(self, projectModel):
2046

2147
def loadProject(self, name):
2248
for project in self.projects:
23-
if project.model.name == name:
24-
return project
49+
if project.model.name == name:
50+
return project
2551

2652
project = Project()
2753
project.load(name)
2854
self.projects.append(project)
2955
return project
30-
31-
56+
3257
def deleteProject(self, name):
3358
for project in self.projects:
34-
if project.model.name == name:
35-
project.delete()
36-
self.projects.remove(project)
59+
if project.model.name == name:
60+
project.delete()
61+
self.projects.remove(project)
62+
63+
def question(self, project, questionModel):
64+
retriever = project.db.as_retriever(
65+
search_type="similarity", search_kwargs={"k": 2}
66+
)
67+
68+
llm = self.loadLLM(questionModel.llm or project.model.llm)
69+
70+
qa = RetrievalQA.from_chain_type(
71+
llm=llm,
72+
chain_type="stuff",
73+
retriever=retriever,
74+
)
3775

76+
return qa.run(questionModel.question)

app/main.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,14 @@
22
import os
33
from tempfile import NamedTemporaryFile
44
from fastapi import FastAPI, HTTPException, Request, UploadFile
5-
from langchain import OpenAI
65
from langchain.document_loaders import (
76
WebBaseLoader,
87
)
98
from langchain.chains import RetrievalQA
109
from dotenv import load_dotenv
1110
from app.brain import Brain
1211

13-
from app.project import IngestModel, ProjectModel, QueryModel
12+
from app.models import IngestModel, ProjectModel, QuestionModel, ChatModel
1413
from app.tools import FindFileLoader, IndexDocuments
1514

1615
load_dotenv()
@@ -118,26 +117,24 @@ def ingestFile(projectName: str, file: UploadFile):
118117
raise HTTPException(
119118
status_code=500, detail='{"error": ' + str(e) + '}')
120119

121-
@app.post("/projects/{projectName}/query")
122-
def queryProject(projectName: str, input: QueryModel):
123-
try:
124-
project = brain.loadProject(projectName)
125-
126-
retriever = project.db.as_retriever(
127-
search_type="similarity", search_kwargs={"k": 2}
128-
)
129-
130-
llm = OpenAI(temperature=0, model_name="text-davinci-003") # type: ignore
131120

132-
qa = RetrievalQA.from_chain_type(
133-
llm=llm,
134-
chain_type="stuff",
135-
retriever=retriever,
136-
)
121+
@app.post("/projects/{projectName}/question")
122+
def questionProject(projectName: str, input: QuestionModel):
123+
try:
124+
project = brain.loadProject(projectName)
125+
answer = brain.question(project, input)
126+
return {"question": input.question, "answer": answer.strip()}
127+
except Exception as e:
128+
logging.error(e)
129+
raise HTTPException(
130+
status_code=500, detail='{"error": ' + str(e) + '}')
137131

138-
answer = qa.run(input.query)
139132

140-
return {"query": input.query, "answer": answer.strip()}
133+
@app.post("/projects/{projectName}/chat")
134+
def chatProject(projectName: str, input: ChatModel):
135+
try:
136+
project = brain.loadProject(projectName)
137+
return "Not implemented yet."
141138
except Exception as e:
142139
raise HTTPException(
143-
status_code=500, detail='{"error": ' + str(e) + '}')
140+
status_code=500, detail='{"error": ' + str(e) + '}')

app/models.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from pydantic import BaseModel
2+
from typing import Union
3+
4+
class IngestModel(BaseModel):
5+
url: str
6+
7+
class QuestionModel(BaseModel):
8+
question: str
9+
llm: Union[str, None] = None
10+
11+
class ChatModel(BaseModel):
12+
message: str
13+
conversation: Union[str, None] = None
14+
llm: Union[str, None] = None
15+
16+
class ProjectModel(BaseModel):
17+
name: str
18+
embeddings: Union[str, None] = None
19+
embeddings_model: Union[str, None] = None
20+
llm: Union[str, None] = None
21+
llm_model: Union[str, None] = None

app/project.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,13 @@
1-
from pydantic import BaseModel
2-
from typing import Union
31
import json
42
import os
53
import shutil
64

75
from langchain.vectorstores import Chroma
6+
from app.models import ProjectModel
87

98
from app.tools import GetEmbedding
109

1110

12-
class IngestModel(BaseModel):
13-
url: str
14-
15-
class QueryModel(BaseModel):
16-
query: str
17-
18-
class ProjectModel(BaseModel):
19-
name: str
20-
embeddings: Union[str, None] = None
21-
embeddings_model: Union[str, None] = None
22-
llm_model: Union[str, None] = None
23-
24-
2511
class Project:
2612

2713
def boot(self, model: ProjectModel):

app/tools.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ def GetEmbedding(name: str, model=None):
2424

2525
return embeddings[name]
2626

27-
2827
def IndexDocuments(brain, project, documents):
2928
texts = brain.text_splitter.split_documents(documents)
3029
texts_final = [doc.page_content for doc in texts]
@@ -40,7 +39,7 @@ def IndexDocuments(brain, project, documents):
4039

4140

4241
def FindFileLoader(temp, ext):
43-
LOADERS_MAP = {
42+
loaders = {
4443
".csv": (CSVLoader, {}),
4544
".doc": (UnstructuredWordDocumentLoader, {}),
4645
".docx": (UnstructuredWordDocumentLoader, {}),
@@ -56,8 +55,8 @@ def FindFileLoader(temp, ext):
5655
".txt": (TextLoader, {"encoding": "utf8"}),
5756
}
5857

59-
if ext in LOADERS_MAP:
60-
loader_class, loader_args = LOADERS_MAP[ext]
58+
if ext in loaders:
59+
loader_class, loader_args = loaders[ext]
6160
return loader_class(temp.name, **loader_args)
6261
else:
6362
raise HTTPException(status_code=500, detail='{"error": "Invalid file type."}')

models/.gitkeep

Whitespace-only changes.

requirements.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,6 @@ python-multipart
66
dotenv
77
pdfminer.six
88
unstructured
9-
pydantic
9+
pydantic
10+
pygpt4all
11+
llama-cpp-python

tests/test_huggingface.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from fastapi.testclient import TestClient
2+
3+
from app.main import app
4+
5+
client = TestClient(app)
6+
7+
8+
def test_get():
9+
response = client.get("/")
10+
assert response.status_code == 200
11+
assert response.json() == "REST AI API, so many 'A's and 'I's, so little time..."
12+
13+
14+
def test_createProjectHF():
15+
response = client.post(
16+
"/projects", json={"name": "test_huggingface", "embeddings": "huggingface"})
17+
assert response.status_code == 200
18+
19+
20+
def test_getProjectHF():
21+
response = client.get("/projects/test_huggingface")
22+
assert response.status_code == 200
23+
24+
25+
def test_ingestURLHF():
26+
response = client.post("/projects/test_huggingface/ingest/url",
27+
json={"url": "https://www.google.com"})
28+
assert response.status_code == 200
29+
30+
31+
def test_getProjectAfterIngestURLHF():
32+
response = client.get("/projects/test_huggingface")
33+
assert response.status_code == 200
34+
assert response.json() == {
35+
"project": "test_huggingface", "embeddings": "huggingface", "documents": 1, "metadatas": 1}
36+
37+
38+
def test_ingestUploadHF():
39+
response = client.post("/projects/test_huggingface/ingest/upload",
40+
files={"file": ("test.txt", open("tests/test.txt", "rb"))})
41+
assert response.status_code == 200
42+
43+
44+
def test_getProjectAfterIngestUploadHF():
45+
response = client.get("/projects/test_huggingface")
46+
assert response.status_code == 200
47+
assert response.json() == {
48+
"project": "test_huggingface", "embeddings": "huggingface", "documents": 2, "metadatas": 2}
49+
50+
51+
def test_questionProjectHF():
52+
response = client.post("/projects/test_huggingface/question",
53+
json={"question": "What is the secret?", "llm": "gpt4all"})
54+
assert response.status_code == 200
55+
assert response.json() == {"question": "What is the secret?",
56+
"answer": "The secret is that ingenuity should be bigger than politics and corporate greed."}
57+
58+
59+
def test_deleteProjectHF():
60+
response = client.delete("/projects/test_huggingface")
61+
assert response.status_code == 200
62+
assert response.json() == {"project": "test_huggingface"}
63+
64+
65+
def test_getProjectsAfterDelete():
66+
response = client.get("/projects")
67+
assert response.status_code == 200
68+
assert response.json() == {"projects": []}

tests/test_main.py renamed to tests/test_openai.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,59 +19,56 @@ def test_getProjects():
1919

2020
def test_createProject():
2121
response = client.post(
22-
"/projects", json={"name": "test", "embeddings": "openai"})
22+
"/projects", json={"name": "test_openai", "embeddings": "openai", "llm": "openai"})
2323
assert response.status_code == 200
2424

25-
2625
def test_getProject():
27-
response = client.get("/projects/test")
26+
response = client.get("/projects/test_openai")
2827
assert response.status_code == 200
2928

3029

3130
def test_ingestURL():
32-
response = client.post("/projects/test/ingest/url",
31+
response = client.post("/projects/test_openai/ingest/url",
3332
json={"url": "https://www.google.com"})
3433
assert response.status_code == 200
3534

3635

3736
def test_getProjectAfterIngestURL():
38-
response = client.get("/projects/test")
37+
response = client.get("/projects/test_openai")
3938
assert response.status_code == 200
4039
assert response.json() == {
41-
"project": "test", "embeddings": "openai", "documents": 1, "metadatas": 1}
40+
"project": "test_openai", "embeddings": "openai", "documents": 1, "metadatas": 1}
4241

4342

4443
def test_ingestUpload():
45-
response = client.post("/projects/test/ingest/upload",
44+
response = client.post("/projects/test_openai/ingest/upload",
4645
files={"file": ("test.txt", open("tests/test.txt", "rb"))})
4746
assert response.status_code == 200
4847

4948

5049
def test_getProjectAfterIngestUpload():
51-
response = client.get("/projects/test")
50+
response = client.get("/projects/test_openai")
5251
assert response.status_code == 200
5352
assert response.json() == {
54-
"project": "test", "embeddings": "openai", "documents": 2, "metadatas": 2}
53+
"project": "test_openai", "embeddings": "openai", "documents": 2, "metadatas": 2}
5554

5655

57-
def test_query():
58-
response = client.post("/projects/test/query",
59-
json={"query": "What is the secret?"})
56+
def test_questionProject():
57+
response = client.post("/projects/test_openai/question",
58+
json={"question": "What is the secret?"})
6059
assert response.status_code == 200
61-
assert response.json() == {"query": "What is the secret?",
60+
assert response.json() == {"question": "What is the secret?",
6261
"answer": "The secret is that ingenuity should be bigger than politics and corporate greed."}
6362

64-
6563
def test_deleteProject():
66-
response = client.delete("/projects/test")
64+
response = client.delete("/projects/test_openai")
6765
assert response.status_code == 200
68-
assert response.json() == {"project": "test"}
69-
66+
assert response.json() == {"project": "test_openai"}
7067

7168
def test_getProjectAfterDelete():
72-
response = client.get("/projects/test")
69+
response = client.get("/projects/test_openai")
7370
assert response.status_code == 404
74-
71+
7572

7673
def test_getProjectsAfterDelete():
7774
response = client.get("/projects")

0 commit comments

Comments (0)