From 3f434ae8fa5fec86135bed63eb33060fb6ede9f7 Mon Sep 17 00:00:00 2001 From: Austin Lee Date: Thu, 6 Feb 2025 13:44:08 -0800 Subject: [PATCH 1/8] Ensure parent docs are collected during doc reconstruct --- .../opensearch/opensearch_reader.py | 11 +- .../opensearch/test_opensearch_read.py | 146 ++++++++++++++++-- 2 files changed, 143 insertions(+), 14 deletions(-) diff --git a/lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py b/lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py index 0e5ebb431..b1ba3325c 100644 --- a/lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py +++ b/lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py @@ -333,8 +333,17 @@ def _to_parent_doc(self, slice_query: dict[str, Any]) -> List[dict[str, Any]]: and hit["_source"]["parent_id"] is not None and hit["_source"]["parent_id"] not in parent_ids ): - results.append(hit) + # Only add a child doc whose parent_id has not been found, yet. parent_ids.add(hit["_source"]["parent_id"]) + results.append(hit) + elif ("parent_id" not in hit["_source"] or hit["_source"]["parent_id"] is None) and hit[ + "_id" + ] not in parent_ids: + # Add a parent doc if it's a match. + parent_id = hit["_id"] + parent_ids.add(parent_id) + hit["_source"]["parent_id"] = parent_id + results.append(hit) page += 1 diff --git a/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py b/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py index 15845ba73..de44f0fb5 100644 --- a/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py +++ b/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py @@ -78,19 +78,6 @@ def get_doc_count(os_client, index_name: str, query: Optional[Dict[str, Any]] = return res["count"] -""" -class MockLLM(LLM): - def __init__(self): - super().__init__(model_name="mock_model") - - def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: - return str(uuid.uuid4()) - - def is_chat_mode(self): - return True -""" - - class TestOpenSearchRead: INDEX_SETTINGS = { "body": { @@ -575,6 +562,139 @@ def test_parallel_query_with_pit(self, setup_index_large, os_client): if doc.parent_id is not None: assert doc.parent_id == expected_docs[doc.doc_id]["parent_id"] + def test_parallel_query_on_property_with_pit(self, setup_index, os_client): + context = sycamore.init(exec_mode=ExecMode.RAY) + key = "property1" + hidden = str(uuid.uuid4()) + query = {"query": {"match": {f"properties.{key}": hidden}}} + # make sure we read from pickle files -- this part won't be written into opensearch. + dicts = [ + { + "doc_id": "1", + "properties": {key: hidden}, + "elements": [ + {"properties": {"_element_index": 1}, "text_representation": "here is an animal that meows"}, + ], + }, + { + "doc_id": "2", + "elements": [ + {"id": 7, "properties": {"_element_index": 7}, "text_representation": "this is a cat"}, + { + "id": 1, + "properties": {"_element_index": 1}, + "text_representation": "here is an animal that moos", + }, + ], + }, + { + "doc_id": "3", + "elements": [ + {"properties": {"_element_index": 1}, "text_representation": "here is an animal that moos"}, + ], + }, + { + "doc_id": "4", + "elements": [ + {"id": 1, "properties": {"_element_index": 1}}, + ], + }, + { + "doc_id": "5", + "elements": [ + { + "properties": {"_element_index": 1}, + "text_representation": "the number of pages in this document are 253", + } + ], + }, + { + "doc_id": "6", + "elements": [ + {"id": 1, "properties": {"_element_index": 1}}, + ], + }, + ] + docs = [Document(item) for item in dicts] + + original_docs = ( + context.read.document(docs) + # .materialize(path={"root": cache_dir, "name": doc_to_name}) + .explode() + .write.opensearch( + os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS, + index_name=setup_index, + index_settings=TestOpenSearchRead.INDEX_SETTINGS, + execute=False, + ) + .take_all() + ) + + os_client.indices.refresh(setup_index) + + expected_count = len(original_docs) + actual_count = get_doc_count(os_client, setup_index) + # refresh should have made all ingested docs immediately available for search + assert actual_count == expected_count, f"Expected {expected_count} documents, found {actual_count}" + + t0 = time.time() + retrieved_docs = context.read.opensearch( + os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS, + index_name=setup_index, + query=query, + reconstruct_document=True, + ).take_all() + t1 = time.time() + + print(f"Retrieved {len(retrieved_docs)} documents in {t1 - t0} seconds") + expected_docs = self.get_ids(os_client, setup_index, True, query) + assert len(retrieved_docs) == len(expected_docs) + assert "1" == retrieved_docs[0].doc_id + assert hidden == retrieved_docs[0].properties[key] + + def test_parallel_query_on_extracted_property_with_pit(self, setup_index, os_client): + + path = str(TEST_DIR / "resources/data/pdfs/Ray.pdf") + context = sycamore.init(exec_mode=ExecMode.RAY) + llm = OpenAI(OpenAIModels.GPT_4O_MINI) + extractor = OpenAIEntityExtractor("title", llm=llm) + original_docs = ( + context.read.binary(path, binary_format="pdf") + .partition(ArynPartitioner(aryn_api_key=ARYN_API_KEY)) + .extract_entity(extractor) + # .materialize(path={"root": materialized_dir, "name": doc_to_name}) + .explode() + .write.opensearch( + os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS, + index_name=setup_index, + index_settings=TestOpenSearchRead.INDEX_SETTINGS, + execute=False, + ) + .take_all() + ) + + os_client.indices.refresh(setup_index) + + expected_count = len(original_docs) + actual_count = get_doc_count(os_client, setup_index) + # refresh should have made all ingested docs immediately available for search + assert actual_count == expected_count, f"Expected {expected_count} documents, found {actual_count}" + + query = {"query": {"match": {"properties.title": "ray"}}} + + t0 = time.time() + retrieved_docs = context.read.opensearch( + os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS, + index_name=setup_index, + query=query, + reconstruct_document=True, + ).take_all() + t1 = time.time() + + print(f"Retrieved {len(retrieved_docs)} documents in {t1 - t0} seconds") + expected_docs = self.get_ids(os_client, setup_index, True, query) + assert len(retrieved_docs) == len(expected_docs) + @staticmethod def get_ids( os_client, index_name, parents_only: bool = False, query: Optional[Dict[str, Any]] = None From a80fc95afb1e3b5f47c811cd9543407f23032d92 Mon Sep 17 00:00:00 2001 From: Dhruv Kaliraman <112497058+dhruvkaliraman7@users.noreply.github.com> Date: Thu, 6 Feb 2025 16:14:43 -0800 Subject: [PATCH 2/8] mock using patch (#1160) --- .../sycamore/tests/unit/test_materialize.py | 73 ++++++++++--------- 1 file changed, 38 insertions(+), 35 deletions(-) diff --git a/lib/sycamore/sycamore/tests/unit/test_materialize.py b/lib/sycamore/sycamore/tests/unit/test_materialize.py index 179ba51a7..1b37632f1 100644 --- a/lib/sycamore/sycamore/tests/unit/test_materialize.py +++ b/lib/sycamore/sycamore/tests/unit/test_materialize.py @@ -703,47 +703,50 @@ def failing_map(doc): assert retry_counter.x == 23 # 2 successful, 21 unsuccessful def test_mrr_path_handling(self): + from unittest.mock import patch from pyarrow.fs import S3FileSystem, LocalFileSystem from sycamore.docset import DocSet - """Test MaterializeReadReliability path handling for both local and S3 paths""" ctx = sycamore.init(exec_mode=ExecMode.LOCAL) mrr = MaterializeReadReliability(max_batch=3) mrr._refresh_seen_files = lambda: None mrr.seen = set() ctx.rewrite_rules.append(mrr) - # Test various path formats - test_cases = [ - # Local paths - {"path": "/tmp/local/path", "expected_fs": "LocalFileSystem"}, - {"path": Path("/tmp/local/path2"), "expected_fs": "LocalFileSystem"}, - {"path": {"root": "/tmp/local/path3"}, "expected_fs": "LocalFileSystem"}, - {"path": {"root": Path("/tmp/local/path4")}, "expected_fs": "LocalFileSystem"}, - # S3 paths - {"path": "s3://test-example/path", "should_execute": True, "expected_fs": "S3FileSystem"}, - {"path": {"root": "s3://test-example/a/path"}, "should_execute": True, "expected_fs": "S3FileSystem"}, - ] - - MaterializeReadReliability.execute_reliably = lambda context, plan, mrr, **kwargs: None - for case in test_cases: - # Create a dummy materialize plan - plan = Materialize(None, ctx, path=case["path"]) - - # Test should_execute_reliably - - MaterializeReadReliability.maybe_execute_reliably(DocSet(context=ctx, plan=plan)) - - # Verify the path was properly initialized in mrr_instance - assert hasattr(mrr, "path"), f"mrr_instance missing path attribute for {case['path']}" - assert hasattr(mrr, "fs"), f"mrr_instance missing fs attribute for {case['path']}" - - # Verify correct filesystem type - if case["expected_fs"] == "S3FileSystem": - assert isinstance( - mrr.fs, S3FileSystem - ), f"Expected S3FileSystem for path {case['path']}, got {type(mrr.fs)}" - else: - assert isinstance( - mrr.fs, LocalFileSystem - ), f"Expected LocalFileSystem for path {case['path']}, got {type(mrr.fs)}" + # Use patch instead of modifying class + with patch.object(MaterializeReadReliability, "execute_reliably", return_value=None): + + # Test various path formats + test_cases = [ + # Local paths + {"path": "/tmp/local/path", "expected_fs": "LocalFileSystem"}, + {"path": Path("/tmp/local/path2"), "expected_fs": "LocalFileSystem"}, + {"path": {"root": "/tmp/local/path3"}, "expected_fs": "LocalFileSystem"}, + {"path": {"root": Path("/tmp/local/path4")}, "expected_fs": "LocalFileSystem"}, + # S3 paths + {"path": "s3://test-example/path", "should_execute": True, "expected_fs": "S3FileSystem"}, + {"path": {"root": "s3://test-example/a/path"}, "should_execute": True, "expected_fs": "S3FileSystem"}, + ] + + MaterializeReadReliability.execute_reliably = lambda context, plan, mrr, **kwargs: None + for case in test_cases: + # Create a dummy materialize plan + plan = Materialize(None, ctx, path=case["path"]) + + # Test should_execute_reliably + + MaterializeReadReliability.maybe_execute_reliably(DocSet(context=ctx, plan=plan)) + + # Verify the path was properly initialized in mrr_instance + assert hasattr(mrr, "path"), f"mrr_instance missing path attribute for {case['path']}" + assert hasattr(mrr, "fs"), f"mrr_instance missing fs attribute for {case['path']}" + + # Verify correct filesystem type + if case["expected_fs"] == "S3FileSystem": + assert isinstance( + mrr.fs, S3FileSystem + ), f"Expected S3FileSystem for path {case['path']}, got {type(mrr.fs)}" + else: + assert isinstance( + mrr.fs, LocalFileSystem + ), f"Expected LocalFileSystem for path {case['path']}, got {type(mrr.fs)}" From 9b31927cb81f26f16286888aa8d430755728917e Mon Sep 17 00:00:00 2001 From: Austin Lee Date: Sat, 8 Feb 2025 14:26:08 -0800 Subject: [PATCH 3/8] Add a client for Aryn, use the new client to read docs --- .../sycamore/connectors/aryn/ArynReader.py | 115 +++++++++++++++--- .../sycamore/connectors/aryn/client.py | 28 +++++ .../connectors/aryn/test_client.py | 22 ++++ lib/sycamore/sycamore/writer.py | 15 ++- 4 files changed, 156 insertions(+), 24 deletions(-) create mode 100644 lib/sycamore/sycamore/connectors/aryn/client.py create mode 100644 lib/sycamore/sycamore/tests/integration/connectors/aryn/test_client.py diff --git a/lib/sycamore/sycamore/connectors/aryn/ArynReader.py b/lib/sycamore/sycamore/connectors/aryn/ArynReader.py index 44e528a12..658ec0764 100644 --- a/lib/sycamore/sycamore/connectors/aryn/ArynReader.py +++ b/lib/sycamore/sycamore/connectors/aryn/ArynReader.py @@ -1,14 +1,23 @@ +import io import json +import struct from dataclasses import dataclass -from typing import Any +from time import time +from typing import Any, TYPE_CHECKING -import requests -from requests import Response +import httpx +import pandas + +from sycamore.connectors.aryn.client import ArynClient +# import requests +# from requests import Response from sycamore.connectors.base_reader import BaseDBReader from sycamore.data import Document from sycamore.data.element import create_element +if TYPE_CHECKING: + from ray.data import Dataset @dataclass class ArynClientParams(BaseDBReader.ClientParams): @@ -41,39 +50,109 @@ def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]: return docs -class ArynClient(BaseDBReader.Client): - def __init__(self, client_params: ArynClientParams, **kwargs): +class ArynReaderClient(BaseDBReader.Client): + def __init__(self, client: ArynClient, client_params: ArynClientParams, **kwargs): self.aryn_url = client_params.aryn_url self.api_key = client_params.api_key + self._client = client self.kwargs = kwargs def read_records(self, query_params: "BaseDBReader.QueryParams") -> "ArynQueryResponse": assert isinstance(query_params, ArynQueryParams) headers = {"Authorization": f"Bearer {self.api_key}"} - response: Response = requests.post( - f"{self.aryn_url}/docsets/{query_params.docset_id}/read", stream=True, headers=headers - ) - assert response.status_code == 200 - docs = [] - print(f"Reading from docset: {query_params.docset_id}") - for chunk in response.iter_lines(): - # print(f"\n{chunk}\n") - doc = json.loads(chunk) - docs.append(doc) + client = httpx.Client() + with client.stream("POST", f"{self.aryn_url}/docsets/{query_params.docset_id}/read", headers=headers) as response: + + docs = [] + print(f"Reading from docset: {query_params.docset_id}") + buffer = io.BytesIO() + to_read = 0 + start_new_doc = True + doc_size_buf = bytearray(4) + idx = 0 + chunk_count = 0 + t0 = time() + for chunk in response.iter_bytes(): + cur_pos = 0 + chunk_count += 1 + remaining = len(chunk) + print(f"Chunk {chunk_count} size: {len(chunk)}") + assert len(chunk) >= 4, f"Chunk too small: {len(chunk)} < 4" + while cur_pos < len(chunk): + if start_new_doc: + doc_size_buf[idx:] = chunk[cur_pos:cur_pos + 4 - idx] + to_read = struct.unpack('!i', doc_size_buf)[0] + print(f"Reading doc of size: {to_read}") + doc_size_buf = bytearray(4) + idx = 0 + cur_pos += 4 + remaining = len(chunk) - cur_pos + start_new_doc = False + if to_read > remaining: + buffer.write(chunk[cur_pos:]) + to_read -= remaining + print(f"Remaining to read: {to_read}") + # Read the next chunk + break + else: + print("Reading the rest of the doc from the chunk") + buffer.write(chunk[cur_pos:cur_pos + to_read]) + docs.append(json.loads(buffer.getvalue().decode())) + buffer.flush() + buffer.seek(0) + cur_pos += to_read + to_read = 0 + start_new_doc = True + if (cur_pos - len(chunk)) < 4: + idx = left_over = cur_pos - len(chunk) + doc_size_buf[:left_over] = chunk[cur_pos:] + # Need to get the rest of the next chunk + break + + t1 = time() + print(f"Reading took: {t1 - t0} seconds") return ArynQueryResponse(docs) def check_target_presence(self, query_params: "BaseDBReader.QueryParams") -> bool: return True @classmethod - def from_client_params(cls, params: "BaseDBReader.ClientParams") -> "ArynClient": + def from_client_params(cls, params: "BaseDBReader.ClientParams") -> "ArynReaderClient": assert isinstance(params, ArynClientParams) - return cls(params) + client = ArynClient(params.aryn_url, params.api_key) + return cls(client, params) class ArynReader(BaseDBReader): - Client = ArynClient + Client = ArynReaderClient Record = ArynQueryResponse ClientParams = ArynClientParams QueryParams = ArynQueryParams + + def __init__( + self, + client_params: ArynClientParams, + query_params: ArynQueryParams, + **kwargs, + ): + super().__init__(client_params=client_params, query_params=query_params, **kwargs) + + def _to_doc(self, doc: dict[str, Any]) -> dict[str, Any]: + elements = doc.get("elements", []) + doc = Document(**doc) + doc.data["elements"] = [create_element(**element) for element in elements] + return {"doc": Document.serialize(doc)} + + def execute(self, **kwargs) -> "Dataset": + + assert isinstance(self._client_params, ArynClientParams) + assert isinstance(self._query_params, ArynQueryParams) + + client = self.Client.from_client_params(self._client_params) + aryn_client = client._client + docs = aryn_client.list_docs(self._query_params.docset_id) + print(f"Found {len(docs)} docs in docset: {self._query_params.docset_id}") + from ray.data import from_items + ds = from_items([{"doc_id": doc_id} for doc_id in docs]) + return ds.map(self._to_doc) diff --git a/lib/sycamore/sycamore/connectors/aryn/client.py b/lib/sycamore/sycamore/connectors/aryn/client.py new file mode 100644 index 000000000..1b9dd95d4 --- /dev/null +++ b/lib/sycamore/sycamore/connectors/aryn/client.py @@ -0,0 +1,28 @@ +from typing import Any + +import requests + + +class ArynClient: + def __init__(self, aryn_url: str, api_key: str): + self.aryn_url = aryn_url + self.api_key = api_key + + def list_docs(self, docset_id: str) -> list[str]: + try: + response = requests.get(f"{self.aryn_url}/docsets/{docset_id}/docs", headers={"Authorization": f"Bearer {self.api_key}"}) + items = response.json()["items"] + return [item["doc_id"] for item in items] + except Exception as e: + raise ValueError(f"Error listing docs: {e}") + + def get_doc(self, docset_id: str, doc_id: str) -> dict[str, Any]: + response = requests.get(f"{self.aryn_url}/docsets/{docset_id}/docs/{doc_id}", headers={"Authorization": f"Bearer {self.api_key}"}) + return response.json() + + def create_docset(self, name: str) -> str: + try: + response = requests.post(f"{self.aryn_url}/docsets", json={"name": name}, headers={"Authorization": f"Bearer {self.api_key}"}) + return response.json()["docset_id"] + except Exception as e: + raise ValueError(f"Error creating docset: {e}") \ No newline at end of file diff --git a/lib/sycamore/sycamore/tests/integration/connectors/aryn/test_client.py b/lib/sycamore/sycamore/tests/integration/connectors/aryn/test_client.py new file mode 100644 index 000000000..68ca952b1 --- /dev/null +++ b/lib/sycamore/sycamore/tests/integration/connectors/aryn/test_client.py @@ -0,0 +1,22 @@ +import os + +from sycamore.connectors.aryn.client import ArynClient + + +def test_list_docs(): + aryn_api_key = os.getenv("ARYN_TEST_API_KEY") + client = ArynClient(aryn_url="http://localhost:8002/v1/docstore", api_key=aryn_api_key) + docset_id = "aryn:ds-fwaagauoj6yqcia2n4c3zfd" + docs = client.list_docs(docset_id) + for doc in docs: + print(doc) + +def test_get_doc(): + aryn_api_key = os.getenv("ARYN_TEST_API_KEY") + client = ArynClient(aryn_url="http://localhost:8002/v1/docstore", api_key=aryn_api_key) + docset_id = "aryn:ds-fwaagauoj6yqcia2n4c3zfd" + docs = client.list_docs(docset_id) + for doc in docs: + print(doc) + doc = client.get_doc(docset_id, doc) + print(doc) diff --git a/lib/sycamore/sycamore/writer.py b/lib/sycamore/sycamore/writer.py index 39a7a2e6d..f53c39997 100644 --- a/lib/sycamore/sycamore/writer.py +++ b/lib/sycamore/sycamore/writer.py @@ -824,8 +824,6 @@ def aryn( Args: docset_id: The id of the docset to write to. If not provided, a new docset will be created. - create_new_docset: If true, a new docset will be created. If false, the docset with the provided - id will be used. name: The name of the new docset to create. Required if create_new_docset is true. aryn_api_key: The api key to use for authentication. If not provided, the api key from the config file will be used. @@ -848,10 +846,15 @@ def aryn( raise ValueError("Either docset_id or name must be provided") if docset_id is None and name is not None: - headers = {"Authorization": f"Bearer {aryn_api_key}"} - res = requests.post(url=f"{aryn_url}/docsets", data={"name": name}, headers=headers) - docset_id = res.json()["docset_id"] - + try: + headers = {"Authorization": f"Bearer {aryn_api_key}"} + res = requests.post(url=f"{aryn_url}/docsets", json={"name": name}, headers=headers) + print(res) + docset_id = res.json()["docset_id"] + logger.info(f"Created new docset with id {docset_id} and name {name}") + except Exception as e: + logger.error(f"Error creating new docset: {e}") + raise e client_params = ArynWriterClientParams(aryn_url, aryn_api_key) target_params = ArynWriterTargetParams(docset_id) ds = ArynWriter(self.plan, client_params=client_params, target_params=target_params, **kwargs) From 5cac2d5ea44bd391ee38bc19732f32066e364c98 Mon Sep 17 00:00:00 2001 From: Austin Lee Date: Mon, 10 Feb 2025 17:56:09 -0800 Subject: [PATCH 4/8] Use list_docs and get_doc for reading from Aryn --- .../sycamore/connectors/aryn/ArynReader.py | 32 ++++++++++++----- .../sycamore/connectors/aryn/client.py | 35 ++++++++++++++++--- .../connectors/aryn/test_client.py | 10 ++++-- 3 files changed, 61 insertions(+), 16 deletions(-) diff --git a/lib/sycamore/sycamore/connectors/aryn/ArynReader.py b/lib/sycamore/sycamore/connectors/aryn/ArynReader.py index 658ec0764..aee0bfc36 100644 --- a/lib/sycamore/sycamore/connectors/aryn/ArynReader.py +++ b/lib/sycamore/sycamore/connectors/aryn/ArynReader.py @@ -1,16 +1,14 @@ import io import json +import logging import struct from dataclasses import dataclass from time import time from typing import Any, TYPE_CHECKING import httpx -import pandas from sycamore.connectors.aryn.client import ArynClient -# import requests -# from requests import Response from sycamore.connectors.base_reader import BaseDBReader from sycamore.data import Document @@ -19,6 +17,9 @@ if TYPE_CHECKING: from ray.data import Dataset +logger = logging.getLogger(__name__) + + @dataclass class ArynClientParams(BaseDBReader.ClientParams): def __init__(self, aryn_url: str, api_key: str, **kwargs): @@ -62,7 +63,9 @@ def read_records(self, query_params: "BaseDBReader.QueryParams") -> "ArynQueryRe headers = {"Authorization": f"Bearer {self.api_key}"} client = httpx.Client() - with client.stream("POST", f"{self.aryn_url}/docsets/{query_params.docset_id}/read", headers=headers) as response: + with client.stream( + "POST", f"{self.aryn_url}/docsets/{query_params.docset_id}/read", headers=headers + ) as response: docs = [] print(f"Reading from docset: {query_params.docset_id}") @@ -81,8 +84,8 @@ def read_records(self, query_params: "BaseDBReader.QueryParams") -> "ArynQueryRe assert len(chunk) >= 4, f"Chunk too small: {len(chunk)} < 4" while cur_pos < len(chunk): if start_new_doc: - doc_size_buf[idx:] = chunk[cur_pos:cur_pos + 4 - idx] - to_read = struct.unpack('!i', doc_size_buf)[0] + doc_size_buf[idx:] = chunk[cur_pos : cur_pos + 4 - idx] + to_read = struct.unpack("!i", doc_size_buf)[0] print(f"Reading doc of size: {to_read}") doc_size_buf = bytearray(4) idx = 0 @@ -97,7 +100,7 @@ def read_records(self, query_params: "BaseDBReader.QueryParams") -> "ArynQueryRe break else: print("Reading the rest of the doc from the chunk") - buffer.write(chunk[cur_pos:cur_pos + to_read]) + buffer.write(chunk[cur_pos : cur_pos + to_read]) docs.append(json.loads(buffer.getvalue().decode())) buffer.flush() buffer.seek(0) @@ -139,6 +142,15 @@ def __init__( super().__init__(client_params=client_params, query_params=query_params, **kwargs) def _to_doc(self, doc: dict[str, Any]) -> dict[str, Any]: + assert isinstance(self._client_params, ArynClientParams) + assert isinstance(self._query_params, ArynQueryParams) + + client = self.Client.from_client_params(self._client_params) + aryn_client = client._client + + doc = aryn_client.get_doc(self._query_params.docset_id, doc["doc_id"]) + if 0 == len(doc.keys()): + return {"doc": Document.serialize(Document())} elements = doc.get("elements", []) doc = Document(**doc) doc.data["elements"] = [create_element(**element) for element in elements] @@ -151,8 +163,12 @@ def execute(self, **kwargs) -> "Dataset": client = self.Client.from_client_params(self._client_params) aryn_client = client._client + + # TODO paginate docs = aryn_client.list_docs(self._query_params.docset_id) - print(f"Found {len(docs)} docs in docset: {self._query_params.docset_id}") + logger.debug(f"Found {len(docs)} docs in docset: {self._query_params.docset_id}") + from ray.data import from_items + ds = from_items([{"doc_id": doc_id} for doc_id in docs]) return ds.map(self._to_doc) diff --git a/lib/sycamore/sycamore/connectors/aryn/client.py b/lib/sycamore/sycamore/connectors/aryn/client.py index 1b9dd95d4..649564791 100644 --- a/lib/sycamore/sycamore/connectors/aryn/client.py +++ b/lib/sycamore/sycamore/connectors/aryn/client.py @@ -1,7 +1,10 @@ +import logging from typing import Any import requests +logger = logging.getLogger(__name__) + class ArynClient: def __init__(self, aryn_url: str, api_key: str): @@ -10,19 +13,41 @@ def __init__(self, aryn_url: str, api_key: str): def list_docs(self, docset_id: str) -> list[str]: try: - response = requests.get(f"{self.aryn_url}/docsets/{docset_id}/docs", headers={"Authorization": f"Bearer {self.api_key}"}) + response = requests.get( + f"{self.aryn_url}/docsets/{docset_id}/docs", headers={"Authorization": f"Bearer {self.api_key}"} + ) items = response.json()["items"] return [item["doc_id"] for item in items] except Exception as e: raise ValueError(f"Error listing docs: {e}") def get_doc(self, docset_id: str, doc_id: str) -> dict[str, Any]: - response = requests.get(f"{self.aryn_url}/docsets/{docset_id}/docs/{doc_id}", headers={"Authorization": f"Bearer {self.api_key}"}) - return response.json() + try: + response = requests.get( + f"{self.aryn_url}/docsets/{docset_id}/docs/{doc_id}", + headers={"Authorization": f"Bearer {self.api_key}"}, + ) + if response.status_code != 200: + raise ValueError( + f"Error getting doc {doc_id}, received {response.status_code} {response.text} {response.reason}" + ) + doc = response.json() + if doc is None: + print(f"Received None for doc {doc_id}") + return {} + print(f">>> DOC {doc}") + logger.info(f"Got doc {doc}") + return doc + except Exception as e: + # raise ValueError(f"Error getting doc {doc_id}: {e}") + print(f"Error getting doc {doc_id}: {e}") + return {} def create_docset(self, name: str) -> str: try: - response = requests.post(f"{self.aryn_url}/docsets", json={"name": name}, headers={"Authorization": f"Bearer {self.api_key}"}) + response = requests.post( + f"{self.aryn_url}/docsets", json={"name": name}, headers={"Authorization": f"Bearer {self.api_key}"} + ) return response.json()["docset_id"] except Exception as e: - raise ValueError(f"Error creating docset: {e}") \ No newline at end of file + raise ValueError(f"Error creating docset: {e}") diff --git a/lib/sycamore/sycamore/tests/integration/connectors/aryn/test_client.py b/lib/sycamore/sycamore/tests/integration/connectors/aryn/test_client.py index 68ca952b1..f74087ecc 100644 --- a/lib/sycamore/sycamore/tests/integration/connectors/aryn/test_client.py +++ b/lib/sycamore/sycamore/tests/integration/connectors/aryn/test_client.py @@ -1,20 +1,24 @@ import os -from sycamore.connectors.aryn.client import ArynClient +import pytest +from sycamore.connectors.aryn.client import ArynClient +@pytest.mark.skip(reason="For manual testing only") def test_list_docs(): aryn_api_key = os.getenv("ARYN_TEST_API_KEY") client = ArynClient(aryn_url="http://localhost:8002/v1/docstore", api_key=aryn_api_key) - docset_id = "aryn:ds-fwaagauoj6yqcia2n4c3zfd" + docset_id = "" docs = client.list_docs(docset_id) for doc in docs: print(doc) + +@pytest.mark.skip(reason="For manual testing only") def test_get_doc(): aryn_api_key = os.getenv("ARYN_TEST_API_KEY") client = ArynClient(aryn_url="http://localhost:8002/v1/docstore", api_key=aryn_api_key) - docset_id = "aryn:ds-fwaagauoj6yqcia2n4c3zfd" + docset_id = "" docs = client.list_docs(docset_id) for doc in docs: print(doc) From 5c73af3594620eb7d29ddb6d56bdc4e6b52dc73c Mon Sep 17 00:00:00 2001 From: Austin Lee Date: Mon, 10 Feb 2025 18:27:19 -0800 Subject: [PATCH 5/8] Mark Aryn classes experimental --- lib/sycamore/sycamore/connectors/aryn/ArynReader.py | 4 ++-- lib/sycamore/sycamore/connectors/aryn/ArynWriter.py | 2 ++ lib/sycamore/sycamore/connectors/aryn/client.py | 13 ++++++------- lib/sycamore/sycamore/decorators.py | 10 ++++++++++ lib/sycamore/sycamore/reader.py | 2 ++ lib/sycamore/sycamore/writer.py | 10 ++++++---- 6 files changed, 28 insertions(+), 13 deletions(-) create mode 100644 lib/sycamore/sycamore/decorators.py diff --git a/lib/sycamore/sycamore/connectors/aryn/ArynReader.py b/lib/sycamore/sycamore/connectors/aryn/ArynReader.py index aee0bfc36..82b389834 100644 --- a/lib/sycamore/sycamore/connectors/aryn/ArynReader.py +++ b/lib/sycamore/sycamore/connectors/aryn/ArynReader.py @@ -13,6 +13,7 @@ from sycamore.connectors.base_reader import BaseDBReader from sycamore.data import Document from sycamore.data.element import create_element +from sycamore.decorators import experimental if TYPE_CHECKING: from ray.data import Dataset @@ -127,6 +128,7 @@ def from_client_params(cls, params: "BaseDBReader.ClientParams") -> "ArynReaderC return cls(client, params) +@experimental class ArynReader(BaseDBReader): Client = ArynReaderClient Record = ArynQueryResponse @@ -149,8 +151,6 @@ def _to_doc(self, doc: dict[str, Any]) -> dict[str, Any]: aryn_client = client._client doc = aryn_client.get_doc(self._query_params.docset_id, doc["doc_id"]) - if 0 == len(doc.keys()): - return {"doc": Document.serialize(Document())} elements = doc.get("elements", []) doc = Document(**doc) doc.data["elements"] = [create_element(**element) for element in elements] diff --git a/lib/sycamore/sycamore/connectors/aryn/ArynWriter.py b/lib/sycamore/sycamore/connectors/aryn/ArynWriter.py index 60a3fdce7..5456b53a1 100644 --- a/lib/sycamore/sycamore/connectors/aryn/ArynWriter.py +++ b/lib/sycamore/sycamore/connectors/aryn/ArynWriter.py @@ -5,6 +5,7 @@ from sycamore.connectors.base_writer import BaseDBWriter from sycamore.data import Document +from sycamore.decorators import experimental @dataclass @@ -67,6 +68,7 @@ def get_existing_target_params(self, target_params: "BaseDBWriter.TargetParams") pass +@experimental class ArynWriter(BaseDBWriter): Client = ArynWriterClient Record = ArynWriterRecord diff --git a/lib/sycamore/sycamore/connectors/aryn/client.py b/lib/sycamore/sycamore/connectors/aryn/client.py index 649564791..6249f835e 100644 --- a/lib/sycamore/sycamore/connectors/aryn/client.py +++ b/lib/sycamore/sycamore/connectors/aryn/client.py @@ -3,9 +3,12 @@ import requests +from sycamore.decorators import experimental + logger = logging.getLogger(__name__) +@experimental class ArynClient: def __init__(self, aryn_url: str, api_key: str): self.aryn_url = aryn_url @@ -33,15 +36,11 @@ def get_doc(self, docset_id: str, doc_id: str) -> dict[str, Any]: ) doc = response.json() if doc is None: - print(f"Received None for doc {doc_id}") - return {} - print(f">>> DOC {doc}") - logger.info(f"Got doc {doc}") + raise ValueError(f"Received None for doc {doc_id}") + logger.debug(f"Got doc {doc}") return doc except Exception as e: - # raise ValueError(f"Error getting doc {doc_id}: {e}") - print(f"Error getting doc {doc_id}: {e}") - return {} + raise ValueError(f"Error getting doc {doc_id}: {e}") def create_docset(self, name: str) -> str: try: diff --git a/lib/sycamore/sycamore/decorators.py b/lib/sycamore/sycamore/decorators.py new file mode 100644 index 000000000..3a9860d16 --- /dev/null +++ b/lib/sycamore/sycamore/decorators.py @@ -0,0 +1,10 @@ +import warnings + +def experimental(cls): + """ + Decorator to mark a class as experimental. + """ + def wrapper(*args, **kwargs): + warnings.warn(f"Class {cls.__name__} is experimental and may change in the future.", FutureWarning, stacklevel=2) + return cls(*args, **kwargs) + return wrapper \ No newline at end of file diff --git a/lib/sycamore/sycamore/reader.py b/lib/sycamore/sycamore/reader.py index f29713af4..bafd9dae3 100644 --- a/lib/sycamore/sycamore/reader.py +++ b/lib/sycamore/sycamore/reader.py @@ -7,6 +7,7 @@ from sycamore.connectors.doc_reconstruct import DocumentReconstructor from sycamore.context import context_params +from sycamore.decorators import experimental from sycamore.plan_nodes import Node from sycamore import Context, DocSet from sycamore.data import Document @@ -634,6 +635,7 @@ def qdrant(self, client_params: dict, query_params: dict, **kwargs) -> DocSet: ) return DocSet(self._context, wr) + @experimental def aryn( self, docset_id: str, aryn_api_key: Optional[str] = None, aryn_url: Optional[str] = None, **kwargs ) -> DocSet: diff --git a/lib/sycamore/sycamore/writer.py b/lib/sycamore/sycamore/writer.py index f53c39997..2bd9d60f0 100644 --- a/lib/sycamore/sycamore/writer.py +++ b/lib/sycamore/sycamore/writer.py @@ -4,10 +4,12 @@ import requests from pyarrow.fs import FileSystem +from sycamore.connectors.aryn.client import ArynClient from sycamore.context import Context, ExecMode, context_params from sycamore.connectors.common import HostAndPort from sycamore.connectors.file.file_writer import default_doc_to_bytes, default_filename, FileWriter, JsonWriter from sycamore.data import Document +from sycamore.decorators import experimental from sycamore.executor import Execution from sycamore.plan_nodes import Node from sycamore.docset import DocSet @@ -543,6 +545,7 @@ def elasticsearch( ) return self._maybe_execute(es_docs, execute) + @experimental @requires_modules("neo4j", extra="neo4j") def neo4j( self, @@ -811,6 +814,7 @@ def json( self._maybe_execute(node, True) + @experimental def aryn( self, docset_id: Optional[str] = None, @@ -847,10 +851,8 @@ def aryn( if docset_id is None and name is not None: try: - headers = {"Authorization": f"Bearer {aryn_api_key}"} - res = requests.post(url=f"{aryn_url}/docsets", json={"name": name}, headers=headers) - print(res) - docset_id = res.json()["docset_id"] + aryn_client = ArynClient(aryn_url, aryn_api_key) + docset_id = aryn_client.create_docset(name) logger.info(f"Created new docset with id {docset_id} and name {name}") except Exception as e: logger.error(f"Error creating new docset: {e}") From bbb7c1e249225f3b471d1191cf6d70c921965259 Mon Sep 17 00:00:00 2001 From: Austin Lee Date: Thu, 13 Feb 2025 18:57:40 -0800 Subject: [PATCH 6/8] Fix lint --- lib/sycamore/sycamore/decorators.py | 3 ++- .../tests/integration/connectors/aryn/test_client.py | 7 +++++-- lib/sycamore/sycamore/writer.py | 1 - 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/lib/sycamore/sycamore/decorators.py b/lib/sycamore/sycamore/decorators.py index 3a9860d16..06399d372 100644 --- a/lib/sycamore/sycamore/decorators.py +++ b/lib/sycamore/sycamore/decorators.py @@ -5,6 +5,7 @@ def experimental(cls): Decorator to mark a class as experimental. """ def wrapper(*args, **kwargs): - warnings.warn(f"Class {cls.__name__} is experimental and may change in the future.", FutureWarning, stacklevel=2) + warnings.warn(f"Class {cls.__name__} is experimental and may change in the future.", + FutureWarning, stacklevel=2) return cls(*args, **kwargs) return wrapper \ No newline at end of file diff --git a/lib/sycamore/sycamore/tests/integration/connectors/aryn/test_client.py b/lib/sycamore/sycamore/tests/integration/connectors/aryn/test_client.py index f74087ecc..49609746a 100644 --- a/lib/sycamore/sycamore/tests/integration/connectors/aryn/test_client.py +++ b/lib/sycamore/sycamore/tests/integration/connectors/aryn/test_client.py @@ -4,10 +4,13 @@ from sycamore.connectors.aryn.client import ArynClient + +aryn_endpoint = os.getenv("ARYN_ENDPOINT") + @pytest.mark.skip(reason="For manual testing only") def test_list_docs(): aryn_api_key = os.getenv("ARYN_TEST_API_KEY") - client = ArynClient(aryn_url="http://localhost:8002/v1/docstore", api_key=aryn_api_key) + client = ArynClient(aryn_url=f"{aryn_endpoint}", api_key=aryn_api_key) docset_id = "" docs = client.list_docs(docset_id) for doc in docs: @@ -17,7 +20,7 @@ def test_list_docs(): @pytest.mark.skip(reason="For manual testing only") def test_get_doc(): aryn_api_key = os.getenv("ARYN_TEST_API_KEY") - client = ArynClient(aryn_url="http://localhost:8002/v1/docstore", api_key=aryn_api_key) + client = ArynClient(aryn_url=f"{aryn_endpoint}", api_key=aryn_api_key) docset_id = "" docs = client.list_docs(docset_id) for doc in docs: diff --git a/lib/sycamore/sycamore/writer.py b/lib/sycamore/sycamore/writer.py index 2bd9d60f0..5e5088bbc 100644 --- a/lib/sycamore/sycamore/writer.py +++ b/lib/sycamore/sycamore/writer.py @@ -1,7 +1,6 @@ import logging from typing import Any, Callable, Optional, Union, TYPE_CHECKING -import requests from pyarrow.fs import FileSystem from sycamore.connectors.aryn.client import ArynClient From e8bd9d0e5b0fd00f548bdb68bd68e31d31e7f450 Mon Sep 17 00:00:00 2001 From: Austin Lee Date: Thu, 13 Feb 2025 19:15:34 -0800 Subject: [PATCH 7/8] Fix lint --- lib/sycamore/sycamore/decorators.py | 10 +++++++--- .../tests/integration/connectors/aryn/test_client.py | 1 + 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/lib/sycamore/sycamore/decorators.py b/lib/sycamore/sycamore/decorators.py index 06399d372..e4579cdea 100644 --- a/lib/sycamore/sycamore/decorators.py +++ b/lib/sycamore/sycamore/decorators.py @@ -1,11 +1,15 @@ import warnings + def experimental(cls): """ Decorator to mark a class as experimental. """ + def wrapper(*args, **kwargs): - warnings.warn(f"Class {cls.__name__} is experimental and may change in the future.", - FutureWarning, stacklevel=2) + warnings.warn( + f"Class {cls.__name__} is experimental and may change in the future.", FutureWarning, stacklevel=2 + ) return cls(*args, **kwargs) - return wrapper \ No newline at end of file + + return wrapper diff --git a/lib/sycamore/sycamore/tests/integration/connectors/aryn/test_client.py b/lib/sycamore/sycamore/tests/integration/connectors/aryn/test_client.py index 49609746a..22d269734 100644 --- a/lib/sycamore/sycamore/tests/integration/connectors/aryn/test_client.py +++ b/lib/sycamore/sycamore/tests/integration/connectors/aryn/test_client.py @@ -7,6 +7,7 @@ aryn_endpoint = os.getenv("ARYN_ENDPOINT") + @pytest.mark.skip(reason="For manual testing only") def test_list_docs(): aryn_api_key = os.getenv("ARYN_TEST_API_KEY") From e608ab5e94826c0f97b0eef382d02e8516c8bbdf Mon Sep 17 00:00:00 2001 From: Austin Lee Date: Thu, 13 Feb 2025 19:27:15 -0800 Subject: [PATCH 8/8] Fix mypy --- lib/sycamore/sycamore/connectors/aryn/ArynReader.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/sycamore/sycamore/connectors/aryn/ArynReader.py b/lib/sycamore/sycamore/connectors/aryn/ArynReader.py index 82b389834..e2d602e59 100644 --- a/lib/sycamore/sycamore/connectors/aryn/ArynReader.py +++ b/lib/sycamore/sycamore/connectors/aryn/ArynReader.py @@ -152,9 +152,9 @@ def _to_doc(self, doc: dict[str, Any]) -> dict[str, Any]: doc = aryn_client.get_doc(self._query_params.docset_id, doc["doc_id"]) elements = doc.get("elements", []) - doc = Document(**doc) - doc.data["elements"] = [create_element(**element) for element in elements] - return {"doc": Document.serialize(doc)} + document = Document(**doc) + document.data["elements"] = [create_element(**element) for element in elements] + return {"doc": Document.serialize(document)} def execute(self, **kwargs) -> "Dataset":