Skip to content

【STEP4】RAG機能の実装 #21

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,13 @@
chainlit==1.0.101
langchain==0.1.5
langchain-openai
openai==1.10.0
pyautogen==0.2.10
pymupdf
spacy
chromadb==0.4.22
python-docx
openpyxl
pandas
unstructured==0.6.7
docx2txt
131 changes: 114 additions & 17 deletions src/demo.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,121 @@
import os

import chainlit as cl
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma

from document_loader import excel_loader, pdf_loader, word_loader

# MIME types accepted by the file-upload prompt in on_chat_start.
# NOTE: "application/octet-stream" is a generic type some browsers/OSes send
# for any binary file, so the extension check (ALLOWED_EXTENSIONS) is the
# real gate on what gets processed.
ALLOWED_MIME_TYPES = [
"application/pdf",
"application/octet-stream",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
]

# File extensions this app can actually load (see the loader dispatch in
# on_chat_start and the loaders in document_loader.py).
ALLOWED_EXTENSIONS = [
".pdf",
".docx",
".xlsx",
]


@cl.on_chat_start
async def on_chat_start():
    """Called once when a new chat session starts.

    Asks the user to upload a document, splits it into chunks, embeds the
    chunks into a Chroma vector store, and stores a conversational
    retrieval chain in the user session for use by on_message.
    """
    files = None

    # AskFileMessage resolves to None on timeout (raise_on_timeout=False),
    # so loop until a file is actually uploaded.
    while files is None:
        files = await cl.AskFileMessage(
            max_size_mb=20,
            content="ファイルを選択してください(.pdf、.docx、.xlsxに対応しています)",
            accept=ALLOWED_MIME_TYPES,
            raise_on_timeout=False,
        ).send()

    file = files[0]
    ext = os.path.splitext(file.name)[1]

    # "application/octet-stream" in ALLOWED_MIME_TYPES lets arbitrary files
    # through the picker, so the extension is the real gate. Originally an
    # unsupported extension fell through with `documents` unbound and
    # crashed with NameError at split_documents; report and stop instead.
    if ext not in ALLOWED_EXTENSIONS:
        await cl.Message(
            content="対応していないファイル形式です(.pdf、.docx、.xlsxに対応しています)"
        ).send()
        return

    # Dispatch to the matching loader from document_loader.py.
    loaders = {
        ".pdf": pdf_loader,
        ".docx": word_loader,
        ".xlsx": excel_loader,
    }
    documents = loaders[ext](file.path)

    # Split the loaded document into ~500-character chunks for embedding.
    text_splitter = CharacterTextSplitter(chunk_size=500)
    splitted_documents = text_splitter.split_documents(documents)

    # OpenAI model that vectorizes the text chunks.
    embeddings = OpenAIEmbeddings(model="text-embedding-ada-002")

    # Initialize Chroma with the embedding function and index the chunks.
    database = Chroma(embedding_function=embeddings)
    database.add_documents(splitted_documents)

    cl.user_session.set("data", database)
    await cl.Message(content="アップロードが完了しました!").send()

    llm = ChatOpenAI(model="gpt-4o", temperature=0)
    message_history = ChatMessageHistory()
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        chat_memory=message_history,
        return_messages=True,
    )

    # Conversational RAG chain: retriever over the vector store plus chat
    # memory; source documents are returned so on_message can cite them.
    retriever = database.as_retriever()
    qa = ConversationalRetrievalChain.from_llm(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        verbose=True,
        memory=memory,
        return_source_documents=True,
    )
    cl.user_session.set("qa", qa)


@cl.step
def tool():
    """Return the fixed demo response emitted by this tool step."""
    response = "Response from the tool!"
    return response
@cl.on_message
async def on_message(input_message: cl.Message):
    """Called every time the user sends a message.

    Runs the conversational retrieval chain on the user's input and replies
    with the answer plus the retrieved source snippets as text elements.
    """
    qa = cl.user_session.get("qa")

    # The chain is only created after a successful upload in on_chat_start;
    # guard against messages sent before that (qa would be None).
    if qa is None:
        await cl.Message(content="先にファイルをアップロードしてください。").send()
        return

    # Run the QA chain asynchronously, streaming intermediate LangChain
    # steps to the Chainlit UI via the callback handler.
    res = await qa.acall(
        input_message.content,
        callbacks=[cl.AsyncLangchainCallbackHandler()],
    )
    answer = res["answer"]
    source_documents = res["source_documents"]
    text_elements = []

    if source_documents:
        # Attach each retrieved chunk as a named text element and cite it.
        for source_idx, source_doc in enumerate(source_documents):
            source_name = f"source_{source_idx}"
            text_elements.append(
                cl.Text(content=source_doc.page_content, name=source_name)
            )
        source_names = [text_el.name for text_el in text_elements]
        answer += f"\n参照元: {', '.join(source_names)}"
    else:
        # Originally this message was nested under `if source_documents:`
        # and therefore unreachable; surface it when nothing was retrieved.
        answer += "\n参照先が見つかりませんでした。"

    await cl.Message(content=answer, elements=text_elements).send()
17 changes: 17 additions & 0 deletions src/document_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from langchain.document_loaders import (Docx2txtLoader, PyMuPDFLoader,
UnstructuredExcelLoader)


def pdf_loader(file_path):
    """Load a PDF file into LangChain documents using PyMuPDF."""
    loader = PyMuPDFLoader(file_path)
    return loader.load()


def word_loader(file_path):
    """Load a .docx file into LangChain documents via docx2txt."""
    loader = Docx2txtLoader(file_path)
    return loader.load()


def excel_loader(file_path):
    """Load an .xlsx file into LangChain documents, one element per item."""
    loader = UnstructuredExcelLoader(file_path, mode="elements")
    return loader.load()