diff --git a/nesis/api/core/document_loaders/loader_helper.py b/nesis/api/core/document_loaders/loader_helper.py
index bfc3127..805d18c 100644
--- a/nesis/api/core/document_loaders/loader_helper.py
+++ b/nesis/api/core/document_loaders/loader_helper.py
@@ -1,20 +1,31 @@
-import uuid
+import datetime
 import json
-from nesis.api.core.models.entities import Document
-from nesis.api.core.util.dateutil import strptime
 import logging
-from nesis.api.core.services.util import get_document, delete_document
+import uuid
+from typing import Optional, Dict, Any, Callable
+
+import nesis.api.core.util.http as http
+from nesis.api.core.document_loaders.runners import (
+    IngestRunner,
+    ExtractRunner,
+    RagRunner,
+)
+from nesis.api.core.models.entities import Document, Datasource
+from nesis.api.core.services.util import delete_document
+from nesis.api.core.services.util import (
+    get_document,
+)
+from nesis.api.core.util.constants import DEFAULT_DATETIME_FORMAT
+from nesis.api.core.util.dateutil import strptime
 
 _LOG = logging.getLogger(__name__)
 
 
 def upload_document_to_llm(upload_document, file_metadata, rag_endpoint, http_client):
-    return _upload_document_to_pgpt(
-        upload_document, file_metadata, rag_endpoint, http_client
-    )
+    return _upload_document(upload_document, file_metadata, rag_endpoint, http_client)
 
 
-def _upload_document_to_pgpt(upload_document, file_metadata, rag_endpoint, http_client):
+def _upload_document(upload_document, file_metadata, rag_endpoint, http_client):
     document_id = file_metadata["unique_id"]
     file_name = file_metadata["name"]
 
@@ -51,3 +62,101 @@ def _upload_document_to_pgpt(upload_document, file_metadata, rag_endpoint, http_
     request_object = {"file_name": file_name, "text": upload_document.page_content}
     response = http_client.post(url=f"{rag_endpoint}/v1/ingest", payload=request_object)
     return json.loads(response)
+
+
+class DocumentProcessor(object):
+    def __init__(
+        self,
+        config,
+        http_client: http.HttpClient,
+        datasource: Datasource,
+    ):
+        self._datasource = datasource
+
+        # This is left package public for testing
+        self._extract_runner: Optional[ExtractRunner] = None
+        _ingest_runner = IngestRunner(config=config, http_client=http_client)
+
+        if self._datasource.connection.get("destination") is not None:
+            self._extract_runner = ExtractRunner(
+                config=config,
+                http_client=http_client,
+                destination=self._datasource.connection.get("destination"),
+            )
+
+        self._mode = self._datasource.connection.get("mode") or "ingest"
+
+        match self._mode:
+            case "ingest":
+                self._ingest_runners: list[RagRunner] = [_ingest_runner]
+            case "extract":
+                self._ingest_runners: list[RagRunner] = [self._extract_runner]
+            case _:
+                raise ValueError(
+                    f"Invalid mode {self._mode}. Expected 'ingest' or 'extract'"
+                )
+
+    def sync(
+        self,
+        endpoint: str,
+        file_path: str,
+        last_modified: datetime.datetime,
+        metadata: Dict[str, Any],
+        store_metadata: Dict[str, Any],
+    ) -> None:
+        """
+        Here we check if this file has been updated.
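+        A document id is derived deterministically (uuid5) from the datasource uuid and the
+        file's self_link, so repeated syncs of the same object resolve to the same record.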
+ If the file has been updated, we delete it from the vector store and re-ingest the new updated file + """ + document_id = str( + uuid.uuid5( + uuid.NAMESPACE_DNS, f"{self._datasource.uuid}/{metadata['self_link']}" + ) + ) + document: Document = get_document(document_id=document_id) + for _ingest_runner in self._ingest_runners: + try: + response_json = _ingest_runner.run( + file_path=file_path, + metadata=metadata, + document_id=None if document is None else document.uuid, + last_modified=last_modified.replace(tzinfo=None).replace( + microsecond=0 + ), + datasource=self._datasource, + ) + except ValueError: + _LOG.warning(f"File {file_path} ingestion failed", exc_info=True) + response_json = None + except UserWarning: + _LOG.warning(f"File {file_path} is already processing") + continue + + if response_json is None: + _LOG.warning("No response from ingest runner received") + continue + + _ingest_runner.save( + document_id=document_id, + datasource_id=self._datasource.uuid, + filename=store_metadata["filename"], + base_uri=endpoint, + rag_metadata=response_json, + store_metadata=store_metadata, + last_modified=last_modified, + ) + + def unsync(self, clean: Callable) -> None: + endpoint = self._datasource.connection.get("endpoint") + + for _ingest_runner in self._ingest_runners: + documents = _ingest_runner.get(base_uri=endpoint) + for document in documents: + store_metadata = document.store_metadata + try: + rag_metadata = document.rag_metadata + except AttributeError: + rag_metadata = document.extract_metadata + + if clean(store_metadata=store_metadata): + _ingest_runner.delete(document=document, rag_metadata=rag_metadata) diff --git a/nesis/api/core/document_loaders/minio.py b/nesis/api/core/document_loaders/minio.py index e391e58..aee6318 100644 --- a/nesis/api/core/document_loaders/minio.py +++ b/nesis/api/core/document_loaders/minio.py @@ -1,136 +1,25 @@ -import concurrent -import concurrent.futures import logging -import multiprocessing +import logging import os -import queue import tempfile -import uuid -from typing import Dict, Any, Optional +from typing import Dict, Any import memcache -import minio from minio import Minio import nesis.api.core.util.http as http -from nesis.api.core.document_loaders.runners import ( - IngestRunner, - ExtractRunner, - RagRunner, -) -from nesis.api.core.models.entities import Document, Datasource -from nesis.api.core.services.util import ( - get_document, - get_documents, -) +from nesis.api.core.document_loaders.loader_helper import DocumentProcessor +from nesis.api.core.models.entities import Datasource from nesis.api.core.util import clean_control, isblank from nesis.api.core.util.concurrency import ( IOBoundPool, as_completed, - BlockingThreadPoolExecutor, ) from nesis.api.core.util.constants import DEFAULT_DATETIME_FORMAT _LOG = logging.getLogger(__name__) -class DocumentProcessor(object): - def __init__( - self, - config, - http_client: http.HttpClient, - datasource: Datasource, - ): - self._datasource = datasource - - # This is left package public for testing - self._extract_runner: ExtractRunner = Optional[None] - _ingest_runner = IngestRunner(config=config, http_client=http_client) - - if self._datasource.connection.get("destination") is not None: - self._extract_runner = ExtractRunner( - config=config, - http_client=http_client, - destination=self._datasource.connection.get("destination"), - ) - - self._mode = self._datasource.connection.get("mode") or "ingest" - - match self._mode: - case "ingest": - self._ingest_runners: list[RagRunner] = 
[_ingest_runner] - case "extract": - self._ingest_runners: list[RagRunner] = [self._extract_runner] - case _: - raise ValueError( - f"Invalid mode {self._mode}. Expected 'ingest' or 'extract'" - ) - - def sync(self, endpoint, file_path, item, metadata): - """ - Here we check if this file has been updated. - If the file has been updated, we delete it from the vector store and re-ingest the new updated file - """ - document_id = str( - uuid.uuid5(uuid.NAMESPACE_DNS, f"{self._datasource.uuid}/{item.etag}") - ) - document: Document = get_document(document_id=document_id) - for _ingest_runner in self._ingest_runners: - try: - response_json = _ingest_runner.run( - file_path=file_path, - metadata=metadata, - document_id=None if document is None else document.uuid, - last_modified=item.last_modified.replace(tzinfo=None).replace( - microsecond=0 - ), - datasource=self._datasource, - ) - except ValueError: - _LOG.warning(f"File {file_path} ingestion failed", exc_info=True) - response_json = None - except UserWarning: - _LOG.warning(f"File {file_path} is already processing") - continue - - if response_json is None: - _LOG.warning("No response from ingest runner received") - continue - - _ingest_runner.save( - document_id=document_id, - datasource_id=self._datasource.uuid, - filename=item.object_name, - base_uri=endpoint, - rag_metadata=response_json, - store_metadata={ - "bucket_name": item.bucket_name, - "object_name": item.object_name, - "size": item.size, - "last_modified": item.last_modified.strftime( - DEFAULT_DATETIME_FORMAT - ), - "version_id": item.version_id, - }, - last_modified=item.last_modified, - ) - - def unsync(self, clean): - endpoint = self._datasource.connection.get("endpoint") - - for _ingest_runner in self._ingest_runners: - documents = _ingest_runner.get(base_uri=endpoint) - for document in documents: - store_metadata = document.store_metadata - try: - rag_metadata = document.rag_metadata - except AttributeError: - rag_metadata = document.extract_metadata - - if clean(store_metadata=store_metadata): - _ingest_runner.delete(document=document, rag_metadata=rag_metadata) - - class MinioProcessor(DocumentProcessor): def __init__( self, @@ -167,7 +56,6 @@ def run(self, metadata: Dict[str, Any]): ) self._unsync_documents( client=_minio_client, - connection=connection, ) except: _LOG.exception("Error fetching sharepoint documents") @@ -266,7 +154,22 @@ def _sync_document( file_path=file_path, ) - self.sync(endpoint, file_path, item, metadata) + self.sync( + endpoint, + file_path, + item.last_modified, + metadata, + store_metadata={ + "bucket_name": item.bucket_name, + "object_name": item.object_name, + "filename": item.object_name, + "size": item.size, + "last_modified": item.last_modified.strftime( + DEFAULT_DATETIME_FORMAT + ), + "version_id": item.version_id, + }, + ) _LOG.info( f"Done {self._mode}ing object {item.object_name} in bucket {bucket_name}" @@ -283,52 +186,25 @@ def _sync_document( def _unsync_documents( self, client: Minio, - connection: dict, ) -> None: - try: - # endpoint = connection.get("endpoint") - # - # for _ingest_runner in self._ingest_runners: - # documents = _ingest_runner.get(base_uri=endpoint) - # for document in documents: - # store_metadata = document.store_metadata - # try: - # rag_metadata = document.rag_metadata - # except AttributeError: - # rag_metadata = document.extract_metadata - # bucket_name = store_metadata["bucket_name"] - # object_name = store_metadata["object_name"] - # try: - # client.stat_object( - # bucket_name=bucket_name, 
object_name=object_name - # ) - # except Exception as ex: - # str_ex = str(ex) - # if "NoSuchKey" in str_ex and "does not exist" in str_ex: - # _ingest_runner.delete( - # document=document, rag_metadata=rag_metadata - # ) - # else: - # raise - - def clean(**kwargs): - store_metadata = kwargs["store_metadata"] - try: - client.stat_object( - bucket_name=store_metadata["bucket_name"], - object_name=store_metadata["object_name"], - ) - return False - except Exception as ex: - str_ex = str(ex) - if "NoSuchKey" in str_ex and "does not exist" in str_ex: - return True - else: - raise + def clean(**kwargs): + store_metadata = kwargs["store_metadata"] + try: + client.stat_object( + bucket_name=store_metadata["bucket_name"], + object_name=store_metadata["object_name"], + ) + return False + except Exception as ex: + str_ex = str(ex) + if "NoSuchKey" in str_ex and "does not exist" in str_ex: + return True + else: + raise + try: self.unsync(clean=clean) - except: _LOG.warning("Error fetching and updating documents", exc_info=True) diff --git a/nesis/api/core/document_loaders/s3.py b/nesis/api/core/document_loaders/s3.py index 7d6f274..df7f3e3 100644 --- a/nesis/api/core/document_loaders/s3.py +++ b/nesis/api/core/document_loaders/s3.py @@ -8,6 +8,7 @@ import memcache import nesis.api.core.util.http as http +from nesis.api.core.document_loaders.loader_helper import DocumentProcessor from nesis.api.core.models.entities import Document, Datasource from nesis.api.core.services import util from nesis.api.core.services.util import ( @@ -18,282 +19,208 @@ ingest_file, ) from nesis.api.core.util import clean_control, isblank +from nesis.api.core.util.constants import DEFAULT_DATETIME_FORMAT from nesis.api.core.util.dateutil import strptime _LOG = logging.getLogger(__name__) -def fetch_documents( - datasource: Datasource, - rag_endpoint: str, - http_client: http.HttpClient, - cache_client: memcache.Client, - metadata: Dict[str, Any], -) -> None: - try: - connection = datasource.connection - endpoint = connection.get("endpoint") - access_key = connection.get("user") - secret_key = connection.get("password") - region = connection.get("region") - if all([access_key, secret_key]): - if endpoint: - s3_client = boto3.client( - "s3", - aws_access_key_id=access_key, - aws_secret_access_key=secret_key, - region_name=region, - endpoint_url=endpoint, - ) - else: - s3_client = boto3.client( - "s3", - aws_access_key_id=access_key, - aws_secret_access_key=secret_key, - region_name=region, - ) - else: - if endpoint: - s3_client = boto3.client( - "s3", region_name=region, endpoint_url=endpoint - ) +class Processor(DocumentProcessor): + def __init__( + self, + config, + http_client: http.HttpClient, + cache_client: memcache.Client, + datasource: Datasource, + ): + super().__init__(config, http_client, datasource) + self._config = config + self._http_client = http_client + self._cache_client = cache_client + self._datasource = datasource + + def run(self, metadata: Dict[str, Any]): + connection: Dict[str, str] = self._datasource.connection + try: + endpoint = connection.get("endpoint") + access_key = connection.get("user") + secret_key = connection.get("password") + region = connection.get("region") + if all([access_key, secret_key]): + if endpoint: + s3_client = boto3.client( + "s3", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + region_name=region, + endpoint_url=endpoint, + ) + else: + s3_client = boto3.client( + "s3", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + 
region_name=region, + ) else: - s3_client = boto3.client("s3", region_name=region) - - _sync_documents( - client=s3_client, - datasource=datasource, - rag_endpoint=rag_endpoint, - http_client=http_client, - cache_client=cache_client, - metadata=metadata, - ) - _unsync_documents( - client=s3_client, - connection=connection, - rag_endpoint=rag_endpoint, - http_client=http_client, - ) - except Exception as ex: - _LOG.exception(f"Error fetching s3 documents - {ex}") - - -def _sync_documents( - client, - datasource: Datasource, - rag_endpoint: str, - http_client: http.HttpClient, - cache_client: memcache.Client, - metadata: dict, -) -> None: - - try: - - # Data objects allow us to specify bucket names - connection = datasource.connection - bucket_paths = connection.get("dataobjects") - if bucket_paths is None: - _LOG.warning("No bucket names supplied, so I can't do much") - - bucket_paths_parts = bucket_paths.split(",") - - _LOG.info(f"Initializing syncing to endpoint {rag_endpoint}") - - for bucket_path in bucket_paths_parts: - - # a/b/c/// should only give [a,b,c] - bucket_path_parts = [ - part for part in bucket_path.split("/") if len(part) != 0 - ] - - path = "/".join(bucket_path_parts[1:]) - bucket_name = bucket_path_parts[0] + if endpoint: + s3_client = boto3.client( + "s3", region_name=region, endpoint_url=endpoint + ) + else: + s3_client = boto3.client("s3", region_name=region) - paginator = client.get_paginator("list_objects_v2") - page_iterator = paginator.paginate( - Bucket=bucket_name, - Prefix="" if path == "" else f"{path}/", + self._sync_documents( + client=s3_client, + datasource=self._datasource, + metadata=metadata, ) - for result in page_iterator: - if result["KeyCount"] == 0: - continue - # iterate through files - for item in result["Contents"]: - # Paths ending in / are folders so we skip them - if item["Key"].endswith("/"): - continue - - endpoint = connection["endpoint"] - self_link = f"{endpoint}/{bucket_name}/{item['Key']}" - _metadata = { - **(metadata or {}), - "file_name": f"{bucket_name}/{item['Key']}", - "self_link": self_link, - } - - """ - We use memcache's add functionality to implement a shared lock to allow for multiple instances - operating - """ - _lock_key = clean_control(f"{__name__}/locks/{self_link}") - if cache_client.add(key=_lock_key, val=_lock_key, time=30 * 60): - try: - _sync_document( - client=client, - datasource=datasource, - rag_endpoint=rag_endpoint, - http_client=http_client, - metadata=_metadata, - bucket_name=bucket_name, - item=item, - ) - finally: - cache_client.delete(_lock_key) - else: - _LOG.info(f"Document {self_link} is already processing") + self._unsync_documents( + client=s3_client, + ) + except: + _LOG.exception("Error fetching sharepoint documents") - _LOG.info(f"Completed syncing to endpoint {rag_endpoint}") + def _sync_documents( + self, + client, + datasource: Datasource, + metadata: dict, + ) -> None: - except: - _LOG.warning("Error fetching and updating documents", exc_info=True) + try: + # Data objects allow us to specify bucket names + connection = datasource.connection + bucket_paths = connection.get("dataobjects") + if bucket_paths is None: + _LOG.warning("No bucket names supplied, so I can't do much") -def _sync_document( - client, - datasource: Datasource, - rag_endpoint: str, - http_client: http.HttpClient, - metadata: dict, - bucket_name: str, - item, -): - connection = datasource.connection - endpoint = connection["endpoint"] - _metadata = metadata + bucket_paths_parts = bucket_paths.split(",") - with 
tempfile.NamedTemporaryFile( - dir=tempfile.gettempdir(), - ) as tmp: - key_parts = item["Key"].split("/") + for bucket_path in bucket_paths_parts: - path_to_tmp = f"{str(pathlib.Path(tmp.name).absolute())}-{key_parts[-1]}" + # a/b/c/// should only give [a,b,c] + bucket_path_parts = [ + part for part in bucket_path.split("/") if len(part) != 0 + ] - try: - _LOG.info(f"Starting syncing object {item['Key']} in bucket {bucket_name}") - # Write item to file - client.download_file(bucket_name, item["Key"], path_to_tmp) + path = "/".join(bucket_path_parts[1:]) + bucket_name = bucket_path_parts[0] - document: Document = get_document(document_id=item["ETag"]) - if document and document.base_uri == endpoint: - store_metadata = document.store_metadata - if store_metadata and store_metadata.get("last_modified"): - last_modified = store_metadata["last_modified"] - if not strptime(date_string=last_modified).replace( - tzinfo=None - ) < item["LastModified"].replace(tzinfo=None).replace( - microsecond=0 - ): - _LOG.debug( - f"Skipping document {item['Key']} already up to date" - ) - return - rag_metadata: dict = document.rag_metadata - if rag_metadata is None: - return - for document_data in rag_metadata.get("data") or []: - try: - util.un_ingest_file( - http_client=http_client, - endpoint=rag_endpoint, - doc_id=document_data["doc_id"], - ) - except: - _LOG.warning( - f"Failed to delete document {document_data['doc_id']}" - ) - - try: - delete_document(document_id=document.id) - except: - _LOG.warning( - f"Failed to delete document {item.object_name}'s record. Continuing anyway..." - ) + paginator = client.get_paginator("list_objects_v2") + page_iterator = paginator.paginate( + Bucket=bucket_name, + Prefix="" if path == "" else f"{path}/", + ) + for result in page_iterator: + if result["KeyCount"] == 0: + continue + # iterate through files + for item in result["Contents"]: + # Paths ending in / are folders so we skip them + if item["Key"].endswith("/"): + continue + + endpoint = connection["endpoint"] + self_link = f"{endpoint}/{bucket_name}/{item['Key']}" + _metadata = { + **(metadata or {}), + "file_name": f"{bucket_name}/{item['Key']}", + "self_link": self_link, + } + + """ + We use memcache's add functionality to implement a shared lock to allow for multiple instances + operating + """ + _lock_key = clean_control(f"{__name__}/locks/{self_link}") + if self._cache_client.add( + key=_lock_key, val=_lock_key, time=30 * 60 + ): + try: + self._sync_document( + client=client, + datasource=datasource, + metadata=_metadata, + bucket_name=bucket_name, + item=item, + ) + finally: + self._cache_client.delete(_lock_key) + else: + _LOG.info(f"Document {self_link} is already processing") + + except: + _LOG.warning("Error fetching and updating documents", exc_info=True) + + def _sync_document( + self, + client, + datasource: Datasource, + metadata: dict, + bucket_name: str, + item, + ): + endpoint = datasource.connection["endpoint"] + _metadata = metadata + + with tempfile.NamedTemporaryFile( + dir=tempfile.gettempdir(), + ) as tmp: + key_parts = item["Key"].split("/") + + path_to_tmp = f"{str(pathlib.Path(tmp.name).absolute())}-{key_parts[-1]}" try: - response = ingest_file( - http_client=http_client, - endpoint=rag_endpoint, - metadata=_metadata, - file_path=path_to_tmp, + _LOG.info( + f"Starting syncing object {item['Key']} in bucket {bucket_name}" + ) + # Write item to file + client.download_file(bucket_name, item["Key"], path_to_tmp) + self.sync( + endpoint, + path_to_tmp, + last_modified=item["LastModified"], 
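+                    # DocumentProcessor.sync strips the timezone and microseconds from this
+                    # value before handing it to the ingest runner.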
+ metadata=metadata, + store_metadata={ + "bucket_name": bucket_name, + "object_name": item["Key"], + "filename": item["Key"], + "size": item["Size"], + "last_modified": item["LastModified"].strftime( + DEFAULT_DATETIME_FORMAT + ), + }, ) - response_json = json.loads(response) - - except ValueError: - _LOG.warning(f"File {path_to_tmp} ingestion failed", exc_info=True) - response_json = {} - except UserWarning: - _LOG.debug(f"File {path_to_tmp} is already processing") - return - - save_document( - document_id=item["ETag"], - filename=item["Key"], - datasource_id=datasource.uuid, - base_uri=endpoint, - rag_metadata=response_json, - store_metadata={ - "bucket_name": bucket_name, - "object_name": item["Key"], - "etag": item["ETag"], - "size": item["Size"], - "last_modified": str(item["LastModified"]), - }, - last_modified=item["LastModified"], - ) - - _LOG.info(f"Done syncing object {item['Key']} in bucket {bucket_name}") - except Exception as ex: - _LOG.warning( - f"Error when getting and ingesting document {item['Key']} - {ex}" - ) - - -def _unsync_documents( - client, connection: dict, rag_endpoint: str, http_client: http.HttpClient -) -> None: - try: - endpoint = connection.get("endpoint") + _LOG.info(f"Done syncing object {item['Key']} in bucket {bucket_name}") + except Exception as ex: + _LOG.warning( + f"Error when getting and ingesting document {item['Key']} - {ex}", + exc_info=True, + ) - documents = get_documents(base_uri=endpoint) - for document in documents: - store_metadata = document.store_metadata - rag_metadata = document.rag_metadata - bucket_name = store_metadata["bucket_name"] - object_name = store_metadata["object_name"] + def _unsync_documents(self, client) -> None: + def clean(**kwargs): + store_metadata = kwargs["store_metadata"] try: - client.head_object(Bucket=bucket_name, Key=object_name) + client.head_object( + Bucket=store_metadata["bucket_name"], + Key=store_metadata["object_name"], + ) + return False except Exception as ex: - str_ex = str(ex).lower() + str_ex = str(ex) if not ("object" in str_ex and "not found" in str_ex): + return True + else: raise - try: - http_client.deletes( - urls=[ - f"{rag_endpoint}/v1/ingest/documents/{document_data['doc_id']}" - for document_data in rag_metadata.get("data") or [] - ] - ) - _LOG.info(f"Deleting document {document.filename}") - delete_document(document_id=document.id) - except: - _LOG.warning( - f"Failed to delete document {document.filename}", - exc_info=True, - ) - except: - _LOG.warn("Error fetching and updating documents", exc_info=True) + try: + self.unsync(clean=clean) + except: + _LOG.warning("Error fetching and updating documents", exc_info=True) def validate_connection_info(connection: Dict[str, Any]) -> Dict[str, Any]: diff --git a/nesis/api/tests/core/document_loaders/test_minio.py b/nesis/api/tests/core/document_loaders/test_minio.py index b84ee43..1a074bc 100644 --- a/nesis/api/tests/core/document_loaders/test_minio.py +++ b/nesis/api/tests/core/document_loaders/test_minio.py @@ -400,13 +400,15 @@ def test_update_ingest_documents( session.add(datasource) session.commit() + self_link = "http://localhost:4566/my-test-bucket/SomeName" + # The document record document = Document( base_uri="http://localhost:4566", document_id=str( uuid.uuid5( uuid.NAMESPACE_DNS, - f"{datasource.uuid}/d41d8cd98f00b204e9800998ecf8427e", + f"{datasource.uuid}/{self_link}", ) ), filename="invalid.pdf", @@ -453,8 +455,8 @@ def test_update_ingest_documents( ) # The document would be deleted from the rag engine - _, upload_kwargs = 
http_client.deletes.call_args_list[0] - urls = upload_kwargs["urls"] + _, deletes_kwargs = http_client.deletes.call_args_list[0] + urls = deletes_kwargs["urls"] assert ( urls[0] @@ -478,7 +480,7 @@ def test_update_ingest_documents( { "datasource": "documents", "file_name": "my-test-bucket/SomeName", - "self_link": "http://localhost:4566/my-test-bucket/SomeName", + "self_link": self_link, }, ) diff --git a/nesis/api/tests/core/document_loaders/test_s3.py b/nesis/api/tests/core/document_loaders/test_s3.py index d95f50d..f4cfe73 100644 --- a/nesis/api/tests/core/document_loaders/test_s3.py +++ b/nesis/api/tests/core/document_loaders/test_s3.py @@ -107,12 +107,14 @@ def test_sync_documents( ] s3_client.get_paginator.return_value = paginator - s3.fetch_documents( - datasource=datasource, + ingestor = s3.Processor( + config=tests.config, http_client=http_client, - metadata={"datasource": "documents"}, - rag_endpoint="http://localhost:8080", cache_client=cache, + datasource=datasource, + ) + ingestor.run( + metadata={"datasource": "documents"}, ) _, upload_kwargs = http_client.upload.call_args_list[0] @@ -150,7 +152,7 @@ def test_update_sync_documents( "connection": { "endpoint": "http://localhost:4566", "region": "us-east-1", - "dataobjects": "my-test-bucket", + "dataobjects": "some-bucket", }, } @@ -164,17 +166,26 @@ def test_update_sync_documents( session.add(datasource) session.commit() + self_link = "http://localhost:4566/some-bucket/invalid.pdf" + + # The document record document = Document( base_uri="http://localhost:4566", - document_id="d41d8cd98f00b204e9800998ecf8427e", + document_id=str( + uuid.uuid5( + uuid.NAMESPACE_DNS, + f"{datasource.uuid}/{self_link}", + ) + ), filename="invalid.pdf", rag_metadata={"data": [{"doc_id": str(uuid.uuid4())}]}, store_metadata={ "bucket_name": "some-bucket", - "object_name": "file/path.pdf", + "object_name": "invalid.pdf", "last_modified": "2023-07-18 06:40:07", }, - last_modified=datetime.datetime.utcnow(), + last_modified=strptime("2023-07-19 06:40:07"), + datasource_id=datasource.uuid, ) session.add(document) @@ -191,8 +202,8 @@ def test_update_sync_documents( "KeyCount": 1, "Contents": [ { - "Key": "image.jpg", - "LastModified": strptime("2023-07-19 06:40:07"), + "Key": "invalid.pdf", + "LastModified": strptime("2023-07-20 06:40:07"), "ETag": "d41d8cd98f00b204e9800998ecf8427e", "Size": 0, "StorageClass": "STANDARD", @@ -206,22 +217,24 @@ def test_update_sync_documents( ] s3_client.get_paginator.return_value = paginator - s3.fetch_documents( - datasource=datasource, + ingestor = s3.Processor( + config=tests.config, http_client=http_client, - metadata={"datasource": "documents"}, - rag_endpoint="http://localhost:8080", cache_client=cache, + datasource=datasource, + ) + + ingestor.run( + metadata={"datasource": "documents"}, ) # The document would be deleted from the rag engine - _, upload_kwargs = http_client.delete.call_args_list[0] - url = upload_kwargs["url"] + _, deletes_kwargs = http_client.deletes.call_args_list[0] + url = deletes_kwargs["urls"] - assert ( - url - == f"http://localhost:8080/v1/ingest/documents/{document.rag_metadata['data'][0]['doc_id']}" - ) + assert url == [ + f"http://localhost:8080/v1/ingest/documents/{document.rag_metadata['data'][0]['doc_id']}" + ] # And then re-ingested _, upload_kwargs = http_client.upload.call_args_list[0] @@ -231,21 +244,22 @@ def test_update_sync_documents( field = upload_kwargs["field"] assert url == f"http://localhost:8080/v1/ingest/files" - assert file_path.endswith("image.jpg") + assert 
file_path.endswith("invalid.pdf") assert field == "file" ut.TestCase().assertDictEqual( metadata, { "datasource": "documents", - "file_name": "my-test-bucket/image.jpg", - "self_link": "http://localhost:4566/my-test-bucket/image.jpg", + "file_name": "some-bucket/invalid.pdf", + "self_link": self_link, }, ) # The document has now been updated documents = session.query(Document).all() assert len(documents) == 1 - assert documents[0].store_metadata["last_modified"] == "2023-07-19 06:40:07" + assert documents[0].store_metadata["last_modified"] == "2023-07-20 06:40:07" + assert str(documents[0].last_modified) == "2023-07-20 06:40:07" @mock.patch("nesis.api.core.document_loaders.s3.boto3.client") @@ -260,8 +274,6 @@ def test_unsync_s3_documents( "engine": "s3", "connection": { "endpoint": "http://localhost:4566", - # "user": "test", - # "password": "test", "region": "us-east-1", "dataobjects": "some-non-existing-bucket", }, @@ -299,12 +311,15 @@ def test_unsync_s3_documents( documents = session.query(Document).all() assert len(documents) == 1 - s3.fetch_documents( - datasource=datasource, + ingestor = s3.Processor( + config=tests.config, http_client=http_client, - metadata={"datasource": "documents"}, - rag_endpoint="http://localhost:8080", cache_client=cache, + datasource=datasource, + ) + + ingestor.run( + metadata={"datasource": "documents"}, ) _, upload_kwargs = http_client.deletes.call_args_list[0]
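Note on the shared unsync path: `DocumentProcessor.unsync` only asks the loader-specific `clean` callback whether the backing object still exists; when the callback returns True, the runner deletes the document from the RAG engine and the database. The MinIO and S3 loaders answer that question with `stat_object` and `head_object` respectively. A minimal sketch of the same contract for a hypothetical filesystem-backed loader (the `path` key and the `os.path.exists` check are illustrative only, not part of this change):

import os


def clean(**kwargs) -> bool:
    # unsync() calls this with the store_metadata that was saved when the document was synced;
    # returning True tells the runner to delete the document from the RAG engine and the database.
    store_metadata = kwargs["store_metadata"]
    # "path" is an assumed key for this sketch; the real loaders store bucket/object names instead.
    return not os.path.exists(store_metadata["path"])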