Skip to content

Commit

Permalink
Merge pull request #100 from cloudera/mob/main
Browse files Browse the repository at this point in the history
store document summary embeddings in qdrant
  • Loading branch information
ewilliams-cloudera authored Jan 13, 2025
2 parents 8e3f9bc + dc20876 commit d9e8dc8
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 12 deletions.
43 changes: 32 additions & 11 deletions llm-service/app/ai/indexing/summary_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,20 @@
get_response_synthesizer,
load_index_from_storage,
)
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.llms import LLM
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.response_synthesizers import ResponseMode
from llama_index.core.schema import (
Document,
NodeRelationship,
NodeRelationship, TextNode,
)
from qdrant_client.http.exceptions import UnexpectedResponse

from app.services.models import get_noop_embedding_model, get_noop_llm_model
from app.services.models import get_noop_llm_model, get_noop_embedding_model
from .base import BaseTextIndexer
from .readers.base_reader import ReaderConfig, ChunksResult
from ..vector_stores.qdrant import QdrantVectorStore
from ...config import Settings

logger = logging.getLogger(__name__)
Expand All @@ -79,11 +82,13 @@ def __init__(
data_source_id: int,
splitter: SentenceSplitter,
llm: LLM,
embedding_model: BaseEmbedding,
reader_config: Optional[ReaderConfig] = None,
):
super().__init__(data_source_id, reader_config=reader_config)
self.splitter = splitter
self.llm = llm
self.embedding_model = embedding_model

@staticmethod
def __database_dir(data_source_id: int) -> str:
Expand All @@ -97,10 +102,10 @@ def __persist_root_dir() -> str:
return os.path.join(Settings().rag_databases_dir, "doc_summary_index_global")

def __index_kwargs(self) -> Dict[str, Any]:
return SummaryIndexer.__index_configuration(self.llm)
return SummaryIndexer.__index_configuration(self.llm, self.embedding_model, self.data_source_id)

@staticmethod
def __index_configuration(llm: LLM) -> Dict[str, Any]:
def __index_configuration(llm: LLM, embedding_model: BaseEmbedding, data_source_id: int) -> Dict[str, Any]:
return {
"llm": llm,
"response_synthesizer": get_response_synthesizer(
Expand All @@ -110,9 +115,10 @@ def __index_configuration(llm: LLM) -> Dict[str, Any]:
verbose=True,
),
"show_progress": True,
"embed_model": get_noop_embedding_model(),
"embed_model": embedding_model,
"embed_summaries": False,
"summary_query": SUMMARY_PROMPT,
"data_source_id": data_source_id,
}

def __init_summary_store(self, persist_dir: str) -> DocumentSummaryIndex:
Expand All @@ -135,8 +141,10 @@ def __summary_indexer(self, persist_dir: str) -> DocumentSummaryIndex:

@staticmethod
def __summary_indexer_with_config(persist_dir: str, index_configuration: Dict[str, Any]) -> DocumentSummaryIndex:
data_source_id: int = index_configuration.get("data_source_id")
storage_context = StorageContext.from_defaults(
persist_dir=persist_dir,
vector_store=QdrantVectorStore.for_summaries(data_source_id).llama_vector_store()
)
doc_summary_index: DocumentSummaryIndex = cast(
DocumentSummaryIndex,
Expand Down Expand Up @@ -171,6 +179,14 @@ def index_file(self, file_path: Path, document_id: str) -> None:
persist_dir = self.__persist_dir()
summary_store = self.__summary_indexer(persist_dir)
summary_store.insert_nodes(chunks.chunks)
summary = summary_store.get_document_summary(document_id)

summary_node = TextNode()
summary_node.embedding = self.embedding_model.get_text_embedding(summary)
summary_node.text = summary
summary_node.relationships[NodeRelationship.SOURCE] = Document(doc_id=document_id).as_related_node_info()
summary_node.metadata["document_id"] = document_id
summary_store.vector_store.add(nodes=[summary_node])
summary_store.storage_context.persist(persist_dir=persist_dir)

self.__update_global_summary_store(summary_store, added_node_id=document_id)
Expand Down Expand Up @@ -234,7 +250,8 @@ def __update_global_summary_store(
# Delete first so that we don't accumulate trash in the summary store.
try:
global_summary_store.delete_ref_doc(str(self.data_source_id), delete_from_docstore=True)
except KeyError:
except (KeyError, UnexpectedResponse):
# UnexpectedResponse is raised when the collection doesn't exist, which is fine, since it might be a new index.
pass
global_summary_store.insert_nodes(new_nodes)
global_summary_store.storage_context.persist(persist_dir=global_persist_dir)
Expand All @@ -253,7 +270,7 @@ def get_full_summary(self) -> Optional[str]:
global_summary_store = self.__summary_indexer(global_persist_dir)
document_id = str(self.data_source_id)
if (
document_id not in global_summary_store.index_struct.doc_id_to_summary_id
document_id not in global_summary_store.index_struct.doc_id_to_summary_id
):
return None
return global_summary_store.get_document_summary(document_id)
Expand All @@ -269,6 +286,7 @@ def delete_document(self, document_id: str) -> None:

summary_store.delete_ref_doc(document_id, delete_from_docstore=True)
summary_store.storage_context.persist(persist_dir=persist_dir)
summary_store.vector_store.delete(document_id)

def delete_data_source(self) -> None:
with _write_lock:
Expand All @@ -277,13 +295,16 @@ def delete_data_source(self) -> None:
@staticmethod
def delete_data_source_by_id(data_source_id: int) -> None:
with _write_lock:
vector_store = QdrantVectorStore.for_summaries(data_source_id)
vector_store.delete()
# TODO: figure out a less explosive way to do this.
shutil.rmtree(SummaryIndexer.__database_dir(data_source_id), ignore_errors=True)
global_persist_dir = SummaryIndexer.__persist_root_dir()
global_persist_dir: str = SummaryIndexer.__persist_root_dir()
try:
global_summary_store = SummaryIndexer.__summary_indexer_with_config(global_persist_dir,
SummaryIndexer.__index_configuration(
get_noop_llm_model()))
configuration: Dict[str, Any] = SummaryIndexer.__index_configuration(get_noop_llm_model(),
get_noop_embedding_model(),
data_source_id=data_source_id)
global_summary_store = SummaryIndexer.__summary_indexer_with_config(global_persist_dir, configuration)
except FileNotFoundError:
## global summary store doesn't exist, nothing to do
return
Expand Down
1 change: 1 addition & 0 deletions llm-service/app/routers/index/data_source/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def _get_summary_indexer(data_source_id: int) -> Optional[SummaryIndexer]:
return SummaryIndexer(
data_source_id=data_source_id,
splitter=SentenceSplitter(chunk_size=2048),
embedding_model=models.get_embedding_model(datasource.embedding_model),
llm=models.get_llm(datasource.summarization_model),
)

Expand Down
2 changes: 1 addition & 1 deletion llm-service/app/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def summary_vector_store(
monkeypatch.setattr(
QdrantVectorStore,
"for_summaries",
lambda ds_id: original(ds_id, qdrant_client),
lambda data_source_id: original(data_source_id, qdrant_client),
)


Expand Down

0 comments on commit d9e8dc8

Please sign in to comment.