Skip to content

Commit c4240f0

Browse files
authored
Add support for NeMo Retriever Text Reranking NIM in O-RAN chatbot (#187)
* Add support for NeMo Retriever Text Reranking NIM in oran chatbot * Add default reranker and NIM reranker configurations for oran chatbot
1 parent 9852997 commit c4240f0

File tree

2 files changed

+43
-25
lines changed

2 files changed

+43
-25
lines changed

community/oran-chatbot-multimodal/Multimodal_Assistant.py

Lines changed: 38 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from retriever.retriever import Retriever, get_relevant_docs, get_relevant_docs_mq
3636
from utils.feedback import feedback_kwargs
3737

38-
from langchain_nvidia_ai_endpoints import ChatNVIDIA
38+
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIARerank
3939
from langchain_core.messages import HumanMessage
4040
from langchain_core.output_parsers import StrOutputParser
4141
from langchain_core.prompts import ChatPromptTemplate
@@ -400,52 +400,66 @@ def load_config(cfg_arg):
400400
if rag_type == 1:
401401
augmented_queries = augment_multiple_query(transformed_query["text"])
402402
queries = [transformed_query["text"]] + augmented_queries[2:]
403-
print("Queries are = ", queries)
403+
# print("Queries are = ", queries)
404404
retrieved_documents = []
405405
retrieved_metadatas = []
406+
relevant_docs = []
406407
for query in queries:
407408
ret_docs,cons,srcs = get_relevant_docs(CORE_DIR, query)
408409
for doc in ret_docs:
409410
retrieved_documents.append(doc.page_content)
410411
retrieved_metadatas.append(doc.metadata['source'])
412+
relevant_docs.append(doc)
411413
print("length of retrieved docs: ", len(retrieved_documents))
412414
#Remove all duplicated documents and retain the original metadata
413415
unique_documents = []
414416
unique_documents_metadata = []
415-
for document,source in zip(retrieved_documents,retrieved_metadatas):
417+
unique_relevant_documents = []
418+
for idx, (document,source) in enumerate(zip(retrieved_documents,retrieved_metadatas)):
416419
if document not in unique_documents:
417420
unique_documents.append(document)
418421
unique_documents_metadata.append(source)
422+
unique_relevant_documents.append(relevant_docs[idx])
419423

420424
if len(retrieved_documents) == 0:
421425
context = ""
422426
print("not context found context")
423427
else:
424428
print("length of unique docs: ", len(unique_documents))
425-
#Instantiate the cross-encoder model and get scores for each retrieved document
426-
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') # ('BAAI/bge-reranker-large')('cross-encoder/ms-marco-MiniLM-L-6-v2')
427-
pairs = [[prompt, doc] for doc in unique_documents]
428-
scores = cross_encoder.predict(pairs)
429-
#Sort the scores from highest to least
430-
order_ids = np.argsort(scores)[::-1]
431-
# print(order_ids)
429+
#Instantiate the re-ranker model and get scores for each retrieved document
432430
new_updated_documents = []
433431
new_updated_sources = []
434-
#Get the top 6 scores
435-
if len(order_ids)>=10:
436-
for i in range(10):
437-
new_updated_documents.append(unique_documents[order_ids[i]])
438-
new_updated_sources.append(unique_documents_metadata[order_ids[i]])
432+
if not config_yaml['Reranker_NIM']:
433+
print("\n\nReranking with Cross-encoder model: ", config_yaml['reranker_model'])
434+
cross_encoder = CrossEncoder(config_yaml['reranker_model'])
435+
pairs = [[prompt, doc] for doc in unique_documents]
436+
scores = cross_encoder.predict(pairs)
437+
#Sort the scores from highest to least
438+
order_ids = np.argsort(scores)[::-1]
439+
#Get the top 10 scores
440+
if len(order_ids)>=10:
441+
for i in range(10):
442+
new_updated_documents.append(unique_documents[order_ids[i]])
443+
new_updated_sources.append(unique_documents_metadata[order_ids[i]])
444+
else:
445+
for i in range(len(order_ids)):
446+
new_updated_documents.append(unique_documents[order_ids[i]])
447+
new_updated_sources.append(unique_documents_metadata[order_ids[i]])
439448
else:
440-
for i in range(len(order_ids)):
441-
new_updated_documents.append(unique_documents[order_ids[i]])
442-
new_updated_sources.append(unique_documents_metadata[order_ids[i]])
449+
print("\n\nReranking with Retriever Text Reranking NIM model: ", config_yaml["reranker_model_name"])
450+
# Initialize and connect to the running NeMo Retriever Text Reranking NIM
451+
reranker = NVIDIARerank(model=config_yaml["reranker_model_name"],
452+
base_url=config_yaml["reranker_api_endpoint_url"], top_n=10)
453+
reranked_chunks = reranker.compress_documents(query=transformed_query["text"], documents=unique_relevant_documents)
454+
for chunks in reranked_chunks:
455+
metadata = chunks.metadata
456+
page_content = chunks.page_content
457+
new_updated_documents.append(page_content)
458+
new_updated_sources.append(metadata['source'])
443459

444-
print(new_updated_sources)
445-
print(len(new_updated_documents))
460+
print("Reranking of completed for ", len(new_updated_documents), " chunks")
446461

447462
context = ""
448-
# sources = ""
449463
sources = {}
450464
for doc in new_updated_documents:
451465
context += doc + "\n\n"
@@ -455,7 +469,7 @@ def load_config(cfg_arg):
455469
sources[src] = {"doc_content": sources[src]["doc_content"]+"\n\n"+new_updated_documents[i], "doc_metadata": src}
456470
else:
457471
sources[src] = {"doc_content": new_updated_documents[i], "doc_metadata": src}
458-
print("length of source docs: ", len(sources))
472+
print("Length of unique source docs: ", len(sources))
459473
#Send the top 10 results along with the query to LLM
460474

461475
if rag_type == 2:
@@ -486,7 +500,7 @@ def load_config(cfg_arg):
486500

487501
print("length of unique docs: ", len(unique_documents))
488502
#Instantiate the cross-encoder model and get scores for each retrieved document
489-
cross_encoder = CrossEncoder('BAAI/bge-reranker-large') #('cross-encoder/ms-marco-MiniLM-L-6-v2')
503+
cross_encoder = CrossEncoder(config_yaml['reranker_model'])
490504
pairs = [[prompt, doc] for doc in unique_documents]
491505
scores = cross_encoder.predict(pairs)
492506
#Sort the scores from highest to least
@@ -544,7 +558,7 @@ def load_config(cfg_arg):
544558

545559
print("length of unique docs: ", len(unique_documents))
546560
#Instantiate the cross-encoder model and get scores for each retrieved document
547-
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') #('BAAI/bge-reranker-large')
561+
cross_encoder = CrossEncoder(config_yaml['reranker_model'])
548562
pairs = [[prompt, doc] for doc in unique_documents]
549563
scores = cross_encoder.predict(pairs)
550564
#Sort the scores from highest to least

community/oran-chatbot-multimodal/config.yaml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
nvidia_api_key: "nvapi--***"
33
## Set these to required models endpoints from NVIDIA NGC
44
llm_model: "mistralai/mixtral-8x7b-instruct-v0.1"
5-
# Augmentation_model:
65
embedding_model: "nvidia/nv-embedqa-e5-v5"
6+
reranker_model: "cross-encoder/ms-marco-MiniLM-L-6-v2"
77

88
NIM: false
99
nim_model_name: "meta/llama3-8b-instruct"
@@ -17,4 +17,8 @@ nrem_model_name: "nvidia/nv-embedqa-e5-v5"
1717
nrem_api_endpoint_url: "http://localhost:8001/v1"
1818
nrem_truncate: "END"
1919

20+
Reranker_NIM: false
21+
reranker_model_name: "nvidia/nv-rerankqa-mistral-4b-v3"
22+
reranker_api_endpoint_url: "http://localhost:8000/v1"
23+
2024
file_delete_password: "oranpwd"

0 commit comments

Comments
 (0)