app.py
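"""Custom PDF Chatbot: a Streamlit app that answers questions about an
uploaded PDF using LangChain, a FAISS vector store, and a local Hugging Face
seq2seq model.

Run with:
    streamlit run app.py
"""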
import streamlit as st
import os
from langchain_community.document_loaders import PDFMinerLoader
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain_community.llms import HuggingFacePipeline
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
import torch
st.title("Custom PDF Chatbot")
# Custom CSS for chat messages
st.markdown("""
<style>
.user-message {
    text-align: right;
    background-color: #3c8ce7;
    color: white;
    padding: 10px;
    border-radius: 10px;
    margin-bottom: 10px;
    display: inline-block;
    width: fit-content;
    max-width: 70%;
    margin-left: auto;
    box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.1);
}
.assistant-message {
    text-align: left;
    background-color: #d16ba5;
    color: white;
    padding: 10px;
    border-radius: 10px;
    margin-bottom: 10px;
    display: inline-block;
    width: fit-content;
    max-width: 70%;
    margin-right: auto;
    box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.1);
}
</style>
""", unsafe_allow_html=True)
def get_file_size(file):
    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    file.seek(0)
    return file_size
# Add a sidebar for model selection and user details
st.sidebar.write("Settings")
st.sidebar.write("-----------")
model_options = ["MBZUAI/LaMini-T5-738M", "google/flan-t5-base", "google/flan-t5-small"]
selected_model = st.sidebar.radio("Choose Model", model_options)
st.sidebar.write("-----------")
uploaded_file = st.sidebar.file_uploader("Upload file", type=["pdf"])
st.sidebar.write("-----------")
st.sidebar.write("About Me")
st.sidebar.write("Name: Deepak Yadav")
st.sidebar.write("Bio: Passionate about AI and machine learning. Enjoys working on innovative projects and sharing knowledge with the community.")
st.sidebar.write("[GitHub](https://github.com/deepak7376)")
st.sidebar.write("[LinkedIn](https://www.linkedin.com/in/dky7376/)")
st.sidebar.write("-----------")
@st.cache_resource
def initialize_qa_chain(filepath, checkpoint):
    # Load the PDF and split it into overlapping chunks
    loader = PDFMinerLoader(filepath)
    documents = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    splits = text_splitter.split_documents(documents)

    # Create embeddings and index them in FAISS
    embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
    vectordb = FAISS.from_documents(splits, embeddings)

    # Initialize the seq2seq model and wrap it in a text2text-generation pipeline
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    base_model = AutoModelForSeq2SeqLM.from_pretrained(
        checkpoint,
        device_map="cpu",
        torch_dtype=torch.float32,
    )
    pipe = pipeline(
        "text2text-generation",
        model=base_model,
        tokenizer=tokenizer,
        max_length=256,
        do_sample=True,
        temperature=0.3,
        top_p=0.95,
    )
    llm = HuggingFacePipeline(pipeline=pipe)

    # Build a RetrievalQA chain that stuffs retrieved chunks into the prompt
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vectordb.as_retriever(),
    )
    return qa_chain
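# A minimal sketch of using the chain outside Streamlit, assuming a local
# file "docs/sample.pdf" exists (the path is hypothetical):
#
#     chain = initialize_qa_chain("docs/sample.pdf", "google/flan-t5-base")
#     print(chain.run("What is this document about?"))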
def process_answer(instruction, qa_chain):
    generated_text = qa_chain.run(instruction)
    return generated_text
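# RetrievalQA exposes a single input key ("query"), so qa_chain.run above
# accepts either the raw question string or a {"query": ...} dict, which is
# how process_answer is called further down.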
if uploaded_file is not None:
    # Persist the uploaded PDF to a local "docs" directory
    os.makedirs("docs", exist_ok=True)
    filepath = os.path.join("docs", uploaded_file.name)
    with open(filepath, "wb") as temp_file:
        temp_file.write(uploaded_file.read())
    with st.spinner("Building embeddings..."):
        qa_chain = initialize_qa_chain(filepath, selected_model)
else:
    qa_chain = None
# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    if message["role"] == "user":
        st.markdown(f"<div class='user-message'>{message['content']}</div>", unsafe_allow_html=True)
    else:
        st.markdown(f"<div class='assistant-message'>{message['content']}</div>", unsafe_allow_html=True)
# React to user input
if prompt := st.chat_input("Ask a question about the uploaded PDF"):
    # Display user message in chat message container
    st.markdown(f"<div class='user-message'>{prompt}</div>", unsafe_allow_html=True)
    # Add user message to chat history
    st.session_state.messages.append({"role": "user", "content": prompt})

    if qa_chain:
        # Generate a response from the retrieval QA chain
        response = process_answer({'query': prompt}, qa_chain)
    else:
        # No document has been indexed yet; ask the user to upload one
        response = "Please upload a PDF file to enable the chatbot."

    # Display assistant response in chat message container
    st.markdown(f"<div class='assistant-message'>{response}</div>", unsafe_allow_html=True)
    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": response})