Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Aryn reader #1172

Merged
merged 9 commits into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 113 additions & 18 deletions lib/sycamore/sycamore/connectors/aryn/ArynReader.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
import io
import json
import logging
import struct
from dataclasses import dataclass
from typing import Any
from time import time
from typing import Any, TYPE_CHECKING

import requests
from requests import Response
import httpx

from sycamore.connectors.aryn.client import ArynClient

from sycamore.connectors.base_reader import BaseDBReader
from sycamore.data import Document
from sycamore.data.element import create_element
from sycamore.decorators import experimental

if TYPE_CHECKING:
from ray.data import Dataset

logger = logging.getLogger(__name__)


@dataclass
Expand Down Expand Up @@ -41,39 +52,123 @@ def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]:
return docs


class ArynReaderClient(BaseDBReader.Client):
    """
    Reader-side client that streams the documents of an Aryn docset.

    Wraps an ``ArynClient`` (exposed as ``_client`` for per-document REST calls)
    and additionally knows how to consume the ``/docsets/{id}/read`` streaming
    endpoint.
    """

    # Each record on the read stream is preceded by a 4-byte, network-order
    # (big-endian) signed integer giving the byte length of the JSON payload.
    _LENGTH_PREFIX = struct.Struct("!i")

    def __init__(self, client: ArynClient, client_params: ArynClientParams, **kwargs):
        """
        Args:
            client: REST client used for list/get operations on the docset.
            client_params: Connection parameters (``aryn_url`` and ``api_key``).
            **kwargs: Extra options, stored unchanged on ``self.kwargs``.
        """
        self.aryn_url = client_params.aryn_url
        self.api_key = client_params.api_key
        self._client = client
        self.kwargs = kwargs

    def read_records(self, query_params: "BaseDBReader.QueryParams") -> "ArynQueryResponse":
        """
        Stream every document of the docset and return them in one response.

        The response body is a sequence of length-prefixed JSON records. Network
        chunks may split a record (or its length prefix) at any byte, so incoming
        bytes are accumulated and only complete records are decoded.

        Args:
            query_params: Must be an ``ArynQueryParams``; its ``docset_id``
                selects the docset to read.

        Returns:
            An ``ArynQueryResponse`` wrapping the decoded JSON documents.

        Raises:
            httpx.HTTPStatusError: If the server answers with a 4xx/5xx status.
            ValueError: If the stream is malformed or ends in the middle of a
                record.
        """
        assert isinstance(query_params, ArynQueryParams)
        headers = {"Authorization": f"Bearer {self.api_key}"}
        url = f"{self.aryn_url}/docsets/{query_params.docset_id}/read"

        logger.debug(f"Reading from docset: {query_params.docset_id}")
        docs: list[dict[str, Any]] = []
        pending = bytearray()  # bytes received but not yet decoded into a record
        t0 = time()
        # Both the client and the streamed response are closed on exit, even on error.
        with httpx.Client() as client, client.stream("POST", url, headers=headers) as response:
            response.raise_for_status()
            for chunk in response.iter_bytes():
                pending += chunk
                docs.extend(self._drain_complete_docs(pending))

        if pending:
            # Leftover bytes mean the server closed the stream mid-record.
            raise ValueError(
                f"Stream for docset {query_params.docset_id} ended with {len(pending)} bytes of an incomplete record"
            )

        logger.debug(f"Reading took: {time() - t0} seconds")
        return ArynQueryResponse(docs)

    @classmethod
    def _drain_complete_docs(cls, pending: bytearray) -> list[dict[str, Any]]:
        """Decode and remove every complete length-prefixed record at the front of ``pending``."""
        prefix = cls._LENGTH_PREFIX
        decoded = []
        offset = 0
        while len(pending) - offset >= prefix.size:
            (size,) = prefix.unpack_from(pending, offset)
            if size < 0:
                raise ValueError(f"Invalid record length in stream: {size}")
            end = offset + prefix.size + size
            if end > len(pending):
                # Payload not fully received yet; wait for the next chunk.
                break
            decoded.append(json.loads(pending[offset + prefix.size : end].decode()))
            offset = end
        # Drop consumed bytes once per call instead of once per record.
        del pending[:offset]
        return decoded

    def check_target_presence(self, query_params: "BaseDBReader.QueryParams") -> bool:
        """Always report the target as present; no remote check is performed."""
        return True

    @classmethod
    def from_client_params(cls, params: "BaseDBReader.ClientParams") -> "ArynReaderClient":
        """Build a reader client (and its underlying ``ArynClient``) from ``ArynClientParams``."""
        assert isinstance(params, ArynClientParams)
        client = ArynClient(params.aryn_url, params.api_key)
        return cls(client, params)


@experimental
class ArynReader(BaseDBReader):
    """
    Plan node that loads the documents of an Aryn docset into a Ray dataset.

    Document ids are listed up front; each document is then fetched and
    converted individually inside the dataset ``map`` step.
    """

    Client = ArynReaderClient
    Record = ArynQueryResponse
    ClientParams = ArynClientParams
    QueryParams = ArynQueryParams

    def __init__(
        self,
        client_params: ArynClientParams,
        query_params: ArynQueryParams,
        **kwargs,
    ):
        """Forward the connection and query parameters to ``BaseDBReader``."""
        super().__init__(client_params=client_params, query_params=query_params, **kwargs)

    def _to_doc(self, doc: dict[str, Any]) -> dict[str, Any]:
        """
        Fetch one document by id and return it in serialized form.

        Args:
            doc: Row carrying the ``"doc_id"`` of the document to fetch.

        Returns:
            A row whose ``"doc"`` entry is the serialized ``Document``.
        """
        assert isinstance(self._client_params, ArynClientParams)
        assert isinstance(self._query_params, ArynQueryParams)

        # A fresh client is built on every call from the stored params.
        reader_client = self.Client.from_client_params(self._client_params)
        rest_client = reader_client._client

        fetched = rest_client.get_doc(self._query_params.docset_id, doc["doc_id"])
        raw_elements = fetched.get("elements", [])
        document = Document(**fetched)
        # Replace the raw element dicts with element objects.
        document.data["elements"] = [create_element(**raw) for raw in raw_elements]
        return {"doc": Document.serialize(document)}

    def execute(self, **kwargs) -> "Dataset":
        """
        List the docset's document ids and map each one to a serialized document.

        Returns:
            A Ray ``Dataset`` of rows produced by ``_to_doc``.
        """
        assert isinstance(self._client_params, ArynClientParams)
        assert isinstance(self._query_params, ArynQueryParams)

        rest_client = self.Client.from_client_params(self._client_params)._client

        # TODO paginate
        doc_ids = rest_client.list_docs(self._query_params.docset_id)
        logger.debug(f"Found {len(doc_ids)} docs in docset: {self._query_params.docset_id}")

        from ray.data import from_items

        id_rows = [{"doc_id": doc_id} for doc_id in doc_ids]
        return from_items(id_rows).map(self._to_doc)
2 changes: 2 additions & 0 deletions lib/sycamore/sycamore/connectors/aryn/ArynWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from sycamore.connectors.base_writer import BaseDBWriter
from sycamore.data import Document
from sycamore.decorators import experimental


@dataclass
Expand Down Expand Up @@ -67,6 +68,7 @@ def get_existing_target_params(self, target_params: "BaseDBWriter.TargetParams")
pass


@experimental
class ArynWriter(BaseDBWriter):
Client = ArynWriterClient
Record = ArynWriterRecord
Expand Down
52 changes: 52 additions & 0 deletions lib/sycamore/sycamore/connectors/aryn/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import logging
from typing import Any

import requests

from sycamore.decorators import experimental

logger = logging.getLogger(__name__)


@experimental
class ArynClient:
    """
    Thin HTTP client for the Aryn docset REST endpoints.

    Every request is authenticated with a bearer token built from ``api_key``.
    All failures (transport errors, HTTP error statuses, unexpected payloads)
    are reported as ``ValueError`` with the original exception chained as the
    cause.
    """

    def __init__(self, aryn_url: str, api_key: str):
        """
        Args:
            aryn_url: Base URL of the Aryn service, without a trailing slash.
            api_key: API key sent as the bearer token.
        """
        self.aryn_url = aryn_url
        self.api_key = api_key

    def _auth_headers(self) -> dict[str, str]:
        """Return the authorization header shared by all requests."""
        return {"Authorization": f"Bearer {self.api_key}"}

    def list_docs(self, docset_id: str) -> list[str]:
        """
        Return the ids of the documents in a docset.

        Args:
            docset_id: Id of the docset to list.

        Raises:
            ValueError: If the request fails, the server returns an error
                status, or the payload lacks the expected ``items``/``doc_id``
                fields.
        """
        try:
            response = requests.get(f"{self.aryn_url}/docsets/{docset_id}/docs", headers=self._auth_headers())
            # Surface HTTP errors directly instead of as a confusing KeyError below.
            response.raise_for_status()
            items = response.json()["items"]
            return [item["doc_id"] for item in items]
        except Exception as e:
            raise ValueError(f"Error listing docs: {e}") from e

    def get_doc(self, docset_id: str, doc_id: str) -> dict[str, Any]:
        """
        Fetch a single document as a JSON-decoded dictionary.

        Args:
            docset_id: Id of the docset containing the document.
            doc_id: Id of the document to fetch.

        Raises:
            ValueError: If the request fails, the status is not 200, or the
                response body decodes to ``None``.
        """
        try:
            response = requests.get(
                f"{self.aryn_url}/docsets/{docset_id}/docs/{doc_id}",
                headers=self._auth_headers(),
            )
            if response.status_code != 200:
                raise ValueError(
                    f"Error getting doc {doc_id}, received {response.status_code} {response.text} {response.reason}"
                )
            doc = response.json()
            if doc is None:
                raise ValueError(f"Received None for doc {doc_id}")
            logger.debug(f"Got doc {doc}")
            return doc
        except Exception as e:
            raise ValueError(f"Error getting doc {doc_id}: {e}") from e

    def create_docset(self, name: str) -> str:
        """
        Create a new docset and return its id.

        Args:
            name: Name of the docset to create.

        Raises:
            ValueError: If the request fails, the server returns an error
                status, or the payload lacks ``docset_id``.
        """
        try:
            response = requests.post(f"{self.aryn_url}/docsets", json={"name": name}, headers=self._auth_headers())
            # Any 2xx is accepted; 4xx/5xx raise here rather than as a KeyError below.
            response.raise_for_status()
            return response.json()["docset_id"]
        except Exception as e:
            raise ValueError(f"Error creating docset: {e}") from e
15 changes: 15 additions & 0 deletions lib/sycamore/sycamore/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import warnings


def experimental(cls):
    """
    Decorator to mark a class (or a function/method) as experimental.

    A ``FutureWarning`` is emitted each time the decorated class is
    instantiated or the decorated callable is invoked.

    When applied to a class, the class object itself is returned (only its
    ``__init__`` is wrapped), so ``isinstance`` checks, subclassing and class
    attribute access keep working. When applied to any other callable, a
    wrapper preserving the callable's metadata is returned.

    Args:
        cls: The class or callable to mark as experimental.

    Returns:
        The same class with a warning ``__init__``, or a wrapped callable.
    """
    import functools
    import inspect

    if inspect.isclass(cls):
        message = f"Class {cls.__name__} is experimental and may change in the future."
        original_init = cls.__init__

        @functools.wraps(original_init)
        def init_with_warning(self, *args, **kwargs):
            # stacklevel=2 attributes the warning to the code instantiating the class.
            warnings.warn(message, FutureWarning, stacklevel=2)
            original_init(self, *args, **kwargs)

        cls.__init__ = init_with_warning
        return cls

    name = getattr(cls, "__name__", repr(cls))
    message = f"Function {name} is experimental and may change in the future."

    @functools.wraps(cls)
    def wrapper(*args, **kwargs):
        warnings.warn(message, FutureWarning, stacklevel=2)
        return cls(*args, **kwargs)

    return wrapper
2 changes: 2 additions & 0 deletions lib/sycamore/sycamore/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from sycamore.connectors.doc_reconstruct import DocumentReconstructor
from sycamore.context import context_params
from sycamore.decorators import experimental
from sycamore.plan_nodes import Node
from sycamore import Context, DocSet
from sycamore.data import Document
Expand Down Expand Up @@ -634,6 +635,7 @@ def qdrant(self, client_params: dict, query_params: dict, **kwargs) -> DocSet:
)
return DocSet(self._context, wr)

@experimental
def aryn(
self, docset_id: str, aryn_api_key: Optional[str] = None, aryn_url: Optional[str] = None, **kwargs
) -> DocSet:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import os

import pytest

from sycamore.connectors.aryn.client import ArynClient


aryn_endpoint = os.getenv("ARYN_ENDPOINT")


@pytest.mark.skip(reason="For manual testing only")
def test_list_docs():
    """Manual check: print every document id of a docset (fill in ``docset_id`` before running)."""
    api_key = os.getenv("ARYN_TEST_API_KEY")
    aryn_client = ArynClient(aryn_url=f"{aryn_endpoint}", api_key=api_key)
    docset_id = ""
    for listed_id in aryn_client.list_docs(docset_id):
        print(listed_id)


@pytest.mark.skip(reason="For manual testing only")
def test_get_doc():
    """Manual check: print each document id of a docset, then the fetched document itself."""
    api_key = os.getenv("ARYN_TEST_API_KEY")
    aryn_client = ArynClient(aryn_url=f"{aryn_endpoint}", api_key=api_key)
    docset_id = ""
    listed_ids = aryn_client.list_docs(docset_id)
    for listed_id in listed_ids:
        print(listed_id)
        fetched = aryn_client.get_doc(docset_id, listed_id)
        print(fetched)
18 changes: 11 additions & 7 deletions lib/sycamore/sycamore/writer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import logging
from typing import Any, Callable, Optional, Union, TYPE_CHECKING

import requests
from pyarrow.fs import FileSystem

from sycamore.connectors.aryn.client import ArynClient
from sycamore.context import Context, ExecMode, context_params
from sycamore.connectors.common import HostAndPort
from sycamore.connectors.file.file_writer import default_doc_to_bytes, default_filename, FileWriter, JsonWriter
from sycamore.data import Document
from sycamore.decorators import experimental
from sycamore.executor import Execution
from sycamore.plan_nodes import Node
from sycamore.docset import DocSet
Expand Down Expand Up @@ -543,6 +544,7 @@ def elasticsearch(
)
return self._maybe_execute(es_docs, execute)

@experimental
@requires_modules("neo4j", extra="neo4j")
def neo4j(
self,
Expand Down Expand Up @@ -811,6 +813,7 @@ def json(

self._maybe_execute(node, True)

@experimental
def aryn(
self,
docset_id: Optional[str] = None,
Expand All @@ -824,8 +827,6 @@ def aryn(

Args:
docset_id: The id of the docset to write to. If not provided, a new docset will be created.
create_new_docset: If true, a new docset will be created. If false, the docset with the provided
id will be used.
name: The name of the new docset to create. Required if docset_id is not provided.
aryn_api_key: The api key to use for authentication. If not provided, the api key from the config
file will be used.
Expand All @@ -848,10 +849,13 @@ def aryn(
raise ValueError("Either docset_id or name must be provided")

if docset_id is None and name is not None:
headers = {"Authorization": f"Bearer {aryn_api_key}"}
res = requests.post(url=f"{aryn_url}/docsets", data={"name": name}, headers=headers)
docset_id = res.json()["docset_id"]

try:
aryn_client = ArynClient(aryn_url, aryn_api_key)
docset_id = aryn_client.create_docset(name)
logger.info(f"Created new docset with id {docset_id} and name {name}")
except Exception as e:
logger.error(f"Error creating new docset: {e}")
raise e
client_params = ArynWriterClientParams(aryn_url, aryn_api_key)
target_params = ArynWriterTargetParams(docset_id)
ds = ArynWriter(self.plan, client_params=client_params, target_params=target_params, **kwargs)
Expand Down
Loading