# rag_app.py
"""
Module for retrieval and generation of responses using Qdrant and GPT-3.5-turbo.
This module provides the function to retrieve context from Qdrant and generate
an answer using GPT-3.5-turbo based on the retrieved context.
"""
import os
from dotenv import load_dotenv
from langchain.chains import RetrievalQA
from langchain.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_qdrant import Qdrant
from qdrant_client import QdrantClient
# Load environment variables from a .env file
load_dotenv()
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
QDRANT_API_KEY = os.getenv('QDRANT_API_KEY')
QDRANT_URL = os.getenv('QDRANT_URL')
COLLECTION_NAME = os.getenv('COLLECTION_NAME')
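# For reference, the .env file is expected to define the four variables read
# above; the values below are placeholders (assumed format, not real
# credentials or endpoints):
#   OPENAI_API_KEY=sk-...
#   QDRANT_API_KEY=...
#   QDRANT_URL=https://<your-cluster>.qdrant.io
#   COLLECTION_NAME=<your-collection>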
# Initialize Qdrant client
client = QdrantClient(
    url=QDRANT_URL,
    api_key=QDRANT_API_KEY
)
# Initialize the embedding model used to embed queries against the stored documents
embedding = OpenAIEmbeddings(openai_api_key=OPENAI_API_KEY)
def rag_retrieve_and_generate(query, collection_name):
"""
Retrieve context from Qdrant and generate an answer using GPT-3.5-turbo.
Args:
----
query (str): The question to retrieve context for.
collection_name (str): The name of the Qdrant collection to search in.
Returns:
-------
str: The generated answer based on retrieved context.
"""
# Initialize vector store
vectorstore = Qdrant(client=client,
collection_name=collection_name,
embeddings=embedding,
vector_name="content")
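    # Assumption: vector_name="content" must match a named vector configured
    # when the collection was created; adjust it if your collection uses a
    # different name or the default unnamed vector.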
    # Define the prompt template
    template = """
    You are an assistant for question-answering tasks. \
    Use the following pieces of retrieved context to answer the question. \
    If you don't know the answer, just say that you don't know. \
    Question: {question}
    Context: {context}
    Answer:
    """
    # Initialize retriever
    retriever = vectorstore.as_retriever()
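    # as_retriever() defaults to plain top-k similarity search (k=4 in
    # LangChain unless overridden via search_kwargs).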
    # Create prompt using the template
    prompt = ChatPromptTemplate.from_template(template)

    # Initialize the LLM (GPT-3.5-turbo)
    llm35 = ChatOpenAI(temperature=0.0,
                       model="gpt-3.5-turbo",
                       max_tokens=512)
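    # temperature=0.0 keeps answers (near-)deterministic and max_tokens=512
    # caps the answer length; the "stuff" chain type below simply concatenates
    # all retrieved documents into the prompt's {context} slot.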
    # Create a retrieval QA chain
    qa_d35 = RetrievalQA.from_chain_type(
        llm=llm35,
        chain_type="stuff",
        chain_type_kwargs={"prompt": prompt},
        retriever=retriever)

    # Invoke the chain with the query to get the result
    result = qa_d35.invoke({"query": query})["result"]
    return result
if __name__ == "__main__":
    # Example usage
    collection_name = COLLECTION_NAME
    query = "What is the attention mechanism?"
    print(f"Response: {rag_retrieve_and_generate(query, collection_name)}")
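    # Note: this example run assumes the .env variables above are set and that
    # the target collection already contains embedded documents.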