diff --git a/llm-service/app/ai/indexing/summary_indexer.py b/llm-service/app/ai/indexing/summary_indexer.py
index 8876f42d..e70ae7da 100644
--- a/llm-service/app/ai/indexing/summary_indexer.py
+++ b/llm-service/app/ai/indexing/summary_indexer.py
@@ -49,17 +49,20 @@
     get_response_synthesizer,
     load_index_from_storage,
 )
+from llama_index.core.base.embeddings.base import BaseEmbedding
 from llama_index.core.llms import LLM
 from llama_index.core.node_parser import SentenceSplitter
 from llama_index.core.response_synthesizers import ResponseMode
 from llama_index.core.schema import (
     Document,
-    NodeRelationship,
+    NodeRelationship, TextNode,
 )
+from qdrant_client.http.exceptions import UnexpectedResponse
 
-from app.services.models import get_noop_embedding_model, get_noop_llm_model
+from app.services.models import get_noop_llm_model, get_noop_embedding_model
 from .base import BaseTextIndexer
 from .readers.base_reader import ReaderConfig, ChunksResult
+from ..vector_stores.qdrant import QdrantVectorStore
 from ...config import Settings
 
 logger = logging.getLogger(__name__)
@@ -79,11 +82,13 @@ def __init__(
         data_source_id: int,
         splitter: SentenceSplitter,
         llm: LLM,
+        embedding_model: BaseEmbedding,
         reader_config: Optional[ReaderConfig] = None,
     ):
         super().__init__(data_source_id, reader_config=reader_config)
         self.splitter = splitter
         self.llm = llm
+        self.embedding_model = embedding_model
 
     @staticmethod
     def __database_dir(data_source_id: int) -> str:
@@ -97,10 +102,10 @@ def __persist_root_dir() -> str:
         return os.path.join(Settings().rag_databases_dir, "doc_summary_index_global")
 
     def __index_kwargs(self) -> Dict[str, Any]:
-        return SummaryIndexer.__index_configuration(self.llm)
+        return SummaryIndexer.__index_configuration(self.llm, self.embedding_model, self.data_source_id)
 
     @staticmethod
-    def __index_configuration(llm: LLM) -> Dict[str, Any]:
+    def __index_configuration(llm: LLM, embedding_model: BaseEmbedding, data_source_id: int) -> Dict[str, Any]:
         return {
             "llm": llm,
             "response_synthesizer": get_response_synthesizer(
@@ -110,9 +115,10 @@
                 verbose=True,
             ),
             "show_progress": True,
-            "embed_model": get_noop_embedding_model(),
+            "embed_model": embedding_model,
             "embed_summaries": False,
             "summary_query": SUMMARY_PROMPT,
+            "data_source_id": data_source_id,
         }
 
     def __init_summary_store(self, persist_dir: str) -> DocumentSummaryIndex:
@@ -135,8 +141,10 @@ def __summary_indexer(self, persist_dir: str) -> DocumentSummaryIndex:
 
     @staticmethod
     def __summary_indexer_with_config(persist_dir: str, index_configuration: Dict[str, Any]) -> DocumentSummaryIndex:
+        data_source_id: int = index_configuration.get("data_source_id")
        storage_context = StorageContext.from_defaults(
             persist_dir=persist_dir,
+            vector_store=QdrantVectorStore.for_summaries(data_source_id).llama_vector_store()
         )
         doc_summary_index: DocumentSummaryIndex = cast(
             DocumentSummaryIndex,
@@ -171,6 +179,14 @@ def index_file(self, file_path: Path, document_id: str) -> None:
             persist_dir = self.__persist_dir()
             summary_store = self.__summary_indexer(persist_dir)
             summary_store.insert_nodes(chunks.chunks)
+            summary = summary_store.get_document_summary(document_id)
+
+            summary_node = TextNode()
+            summary_node.embedding = self.embedding_model.get_text_embedding(summary)
+            summary_node.text = summary
+            summary_node.relationships[NodeRelationship.SOURCE] = Document(doc_id=document_id).as_related_node_info()
+            summary_node.metadata["document_id"] = document_id
+            summary_store.vector_store.add(nodes=[summary_node])
             summary_store.storage_context.persist(persist_dir=persist_dir)
 
             self.__update_global_summary_store(summary_store, added_node_id=document_id)
@@ -234,7 +250,8 @@ def __update_global_summary_store(
         # Delete first so that we don't accumulate trash in the summary store.
         try:
             global_summary_store.delete_ref_doc(str(self.data_source_id), delete_from_docstore=True)
-        except KeyError:
+        except (KeyError, UnexpectedResponse):
+            # UnexpectedResponse is raised when the collection doesn't exist, which is fine, since it might be a new index.
             pass
         global_summary_store.insert_nodes(new_nodes)
         global_summary_store.storage_context.persist(persist_dir=global_persist_dir)
@@ -253,7 +270,7 @@ def get_full_summary(self) -> Optional[str]:
         global_summary_store = self.__summary_indexer(global_persist_dir)
         document_id = str(self.data_source_id)
         if (
-            document_id not in global_summary_store.index_struct.doc_id_to_summary_id
+                document_id not in global_summary_store.index_struct.doc_id_to_summary_id
         ):
             return None
         return global_summary_store.get_document_summary(document_id)
@@ -269,6 +286,7 @@ def delete_document(self, document_id: str) -> None:
             summary_store.delete_ref_doc(document_id, delete_from_docstore=True)
 
             summary_store.storage_context.persist(persist_dir=persist_dir)
+            summary_store.vector_store.delete(document_id)
 
     def delete_data_source(self) -> None:
         with _write_lock:
@@ -277,13 +295,16 @@
     @staticmethod
     def delete_data_source_by_id(data_source_id: int) -> None:
         with _write_lock:
+            vector_store = QdrantVectorStore.for_summaries(data_source_id)
+            vector_store.delete()
             # TODO: figure out a less explosive way to do this.
             shutil.rmtree(SummaryIndexer.__database_dir(data_source_id), ignore_errors=True)
-            global_persist_dir = SummaryIndexer.__persist_root_dir()
+            global_persist_dir: str = SummaryIndexer.__persist_root_dir()
             try:
-                global_summary_store = SummaryIndexer.__summary_indexer_with_config(global_persist_dir,
-                                                                                    SummaryIndexer.__index_configuration(
-                                                                                        get_noop_llm_model()))
+                configuration: Dict[str, Any] = SummaryIndexer.__index_configuration(get_noop_llm_model(),
+                                                                                     get_noop_embedding_model(),
+                                                                                     data_source_id=data_source_id)
+                global_summary_store = SummaryIndexer.__summary_indexer_with_config(global_persist_dir, configuration)
             except FileNotFoundError:
                 ## global summary store doesn't exist, nothing to do
                 return
diff --git a/llm-service/app/routers/index/data_source/__init__.py b/llm-service/app/routers/index/data_source/__init__.py
index ce928b68..85c45520 100644
--- a/llm-service/app/routers/index/data_source/__init__.py
+++ b/llm-service/app/routers/index/data_source/__init__.py
@@ -91,6 +91,7 @@ def _get_summary_indexer(data_source_id: int) -> Optional[SummaryIndexer]:
     return SummaryIndexer(
         data_source_id=data_source_id,
         splitter=SentenceSplitter(chunk_size=2048),
+        embedding_model=models.get_embedding_model(datasource.embedding_model),
         llm=models.get_llm(datasource.summarization_model),
     )
 
diff --git a/llm-service/app/tests/conftest.py b/llm-service/app/tests/conftest.py
index 4b5992fd..4fd02592 100644
--- a/llm-service/app/tests/conftest.py
+++ b/llm-service/app/tests/conftest.py
@@ -139,7 +139,7 @@ def summary_vector_store(
     monkeypatch.setattr(
         QdrantVectorStore,
         "for_summaries",
-        lambda ds_id: original(ds_id, qdrant_client),
+        lambda data_source_id: original(data_source_id, qdrant_client),
     )
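The core write path this diff adds in SummaryIndexer.index_file: once the DocumentSummaryIndex has produced a summary for the document, the summary is embedded with the data source's embedding model and stored as a TextNode, linked back to its source document, in the per-data-source summary vector store. Below is a minimal, self-contained sketch of that path; MockEmbedding and SimpleVectorStore are stand-ins (not part of this change) for the configured embedding model and for QdrantVectorStore.for_summaries(data_source_id).llama_vector_store().

from llama_index.core.embeddings import MockEmbedding
from llama_index.core.schema import Document, NodeRelationship, TextNode
from llama_index.core.vector_stores import SimpleVectorStore

# Stand-ins for the real components wired up in the diff.
embedding_model = MockEmbedding(embed_dim=8)
vector_store = SimpleVectorStore()

document_id = "doc-123"
summary = "A short summary produced by the DocumentSummaryIndex."

# Mirrors the new block in index_file: embed the summary and store it as a
# node whose SOURCE relationship points at the originating document.
summary_node = TextNode()
summary_node.embedding = embedding_model.get_text_embedding(summary)
summary_node.text = summary
summary_node.relationships[NodeRelationship.SOURCE] = Document(
    doc_id=document_id
).as_related_node_info()
summary_node.metadata["document_id"] = document_id
vector_store.add(nodes=[summary_node])

delete_document undoes this with summary_store.vector_store.delete(document_id), which removes the stored summary node for that source document.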