Skip to content

Commit

Permalink
Merge branch 'main' into fix-os-reader
Browse files Browse the repository at this point in the history
  • Loading branch information
austintlee committed Feb 11, 2025
2 parents 5c73af3 + 87f230e commit 8a22bb9
Show file tree
Hide file tree
Showing 23 changed files with 730 additions and 345 deletions.
24 changes: 14 additions & 10 deletions lib/aryn-sdk/aryn_sdk/partition/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
_logger.addHandler(logging.StreamHandler(sys.stderr))

g_version = "0.1.13"
g_parameters = {"path_filter": "^/v1/document/partition$"}


class PartitionError(Exception):
Expand Down Expand Up @@ -447,7 +448,9 @@ def partition_file_async_result(

specific_task_url = f"{async_result_url.rstrip('/')}/{task_id}"
headers = _generate_headers(aryn_config.api_key())
response = requests.get(specific_task_url, headers=headers, stream=_should_stream(), verify=ssl_verify)
response = requests.get(
specific_task_url, params=g_parameters, headers=headers, stream=_should_stream(), verify=ssl_verify
)

if response.status_code == 200:
return {"status": "done", "status_code": response.status_code, "result": response.json()}
Expand Down Expand Up @@ -485,7 +488,9 @@ def partition_file_async_cancel(

specific_task_url = f"{async_cancel_url.rstrip('/')}/{task_id}"
headers = _generate_headers(aryn_config.api_key())
response = requests.post(specific_task_url, headers=headers, stream=_should_stream(), verify=ssl_verify)
response = requests.post(
specific_task_url, params=g_parameters, headers=headers, stream=_should_stream(), verify=ssl_verify
)
if response.status_code == 200:
return True
elif response.status_code == 404:
Expand Down Expand Up @@ -522,14 +527,13 @@ def partition_file_async_list(
aryn_config = _process_config(aryn_api_key, aryn_config)

headers = _generate_headers(aryn_config.api_key())
response = requests.get(async_list_url, headers=headers, stream=_should_stream(), verify=ssl_verify)

all_tasks = response.json()["tasks"]
result = {}
for task_id in all_tasks.keys():
if all_tasks[task_id]["path"] == "/v1/document/partition":
del all_tasks[task_id]["path"]
result[task_id] = all_tasks[task_id]
response = requests.get(
async_list_url, params=g_parameters, headers=headers, stream=_should_stream(), verify=ssl_verify
)

result = response.json()["tasks"]
for v in result.values():
v.pop("action", None)
return result


Expand Down
13 changes: 7 additions & 6 deletions lib/sycamore/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion lib/sycamore/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ ray = { extras = ["default"], version = "^2.36.0" }
pyarrow = "^14.0.2"
numpy = "<2.0.0"
openai = "^1.60.2"
beautifulsoup4 = "^4.12.2"
beautifulsoup4 = "^4.13.1"
amazon-textract-textractor = "^1.3.2"
boto3 = "^1.28.70"
boto3-stubs = {extras = ["essential"], version = "^1.35.12"}
Expand Down
12 changes: 8 additions & 4 deletions lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging
from copy import deepcopy

Expand Down Expand Up @@ -288,7 +289,7 @@ def __init__(
logger.info(f"OpenSearchReader using PIT: {self.use_pit}")

@timetrace("OpenSearchReader")
def _to_parent_doc(self, slice_query: dict[str, Any]) -> List[dict[str, Any]]:
def _to_parent_doc(self, doc: dict[str, Any]) -> List[dict[str, Any]]:
"""
Get all parent documents from a given slice.
"""
Expand All @@ -304,6 +305,7 @@ def _to_parent_doc(self, slice_query: dict[str, Any]) -> List[dict[str, Any]]:
raise ValueError("Target is not present\n" f"Parameters: {self._query_params}\n")

os_client = client._client
slice_query = json.loads(doc["doc"])

assert (
get_doc_count_for_slice(os_client, slice_query) < 10000
Expand Down Expand Up @@ -350,7 +352,7 @@ def _to_parent_doc(self, slice_query: dict[str, Any]) -> List[dict[str, Any]]:
logger.info(f"Read {len(results)} documents from {self._query_params.index_name}")

except Exception as e:
raise ValueError(f"Error reading from target: {e}")
raise ValueError(f"Error reading from target: {e}, query: {slice_query}")
finally:
if client is not None:
client.close()
Expand All @@ -359,7 +361,7 @@ def _to_parent_doc(self, slice_query: dict[str, Any]) -> List[dict[str, Any]]:
return ret

@timetrace("OpenSearchReader")
def _to_doc(self, slice_query: dict[str, Any]) -> List[dict[str, Any]]:
def _to_doc(self, doc: dict[str, Any]) -> List[dict[str, Any]]:
"""
Get all documents from a given slice.
"""
Expand All @@ -377,6 +379,7 @@ def _to_doc(self, slice_query: dict[str, Any]) -> List[dict[str, Any]]:
raise ValueError("Target is not present\n" f"Parameters: {self._query_params}\n")

os_client = client._client
slice_query = json.loads(doc["doc"])
slice_count = get_doc_count_for_slice(os_client, slice_query)
assert slice_count <= 10000, f"Slice count ({slice_count}) should return <= 10,000 documents"

Expand Down Expand Up @@ -556,7 +559,8 @@ def _execute_pit(self, **kwargs) -> "Dataset":
}
if "query" in query:
_query["query"] = query["query"]
docs.append(_query)

docs.append({"doc": json.dumps(_query)})
logger.debug(f"Added slice {i} to the query {_query}")
except Exception as e:
raise ValueError(f"Error reading from target: {e}")
Expand Down
14 changes: 9 additions & 5 deletions lib/sycamore/sycamore/data/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def from_html(cls, html_str: Optional[str] = None, html_tag: Optional[Tag] = Non

if (html_str is not None and html_tag is not None) or (html_str is None and html_tag is None):
raise ValueError("Exactly one of html_str and html_tag must be specified.")

root: Union[Tag, BeautifulSoup]
if html_str is not None:
html_str = html_str.strip()
if not html_str.startswith("<table") or not html_str.endswith("</table>"):
Expand All @@ -222,9 +222,10 @@ def from_html(cls, html_str: Optional[str] = None, html_tag: Optional[Tag] = Non

cells = []
caption = None

assert isinstance(root, Tag), "Expected root to be a Tag"
# Traverse the tree of elements in a pre-order traversal.
for tag in root.find_all(recursive=True):
assert isinstance(tag, Tag), "Expected root to be a Tag"
if tag.name == "tr":
cur_row += 1 # TODO: Should this be based on rowspan?
cur_col = 0
Expand All @@ -234,9 +235,12 @@ def from_html(cls, html_str: Optional[str] = None, html_tag: Optional[Tag] = Non
# they have a thead.
if cur_row < 0:
cur_row += 1

rowspan = int(tag.attrs.get("rowspan", "1"))
colspan = int(tag.attrs.get("colspan", "1"))
if rowspan_str := tag.attrs.get("rowspan", "1"):
assert isinstance(rowspan_str, str) # For mypy
rowspan = int(rowspan_str)
if colspan_str := tag.attrs.get("colspan", "1"):
assert isinstance(colspan_str, str) # For mypy
colspan = int(colspan_str)

content = tag.get_text()

Expand Down
8 changes: 4 additions & 4 deletions lib/sycamore/sycamore/llms/prompts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

from sycamore.llms.prompts.default_prompts import (
SimplePrompt,
EntityExtractorZeroShotGuidancePrompt,
EntityExtractorFewShotGuidancePrompt,
EntityExtractorZeroShotJinjaPrompt,
EntityExtractorFewShotJinjaPrompt,
TextSummarizerGuidancePrompt,
SchemaZeroShotGuidancePrompt,
PropertiesZeroShotGuidancePrompt,
Expand All @@ -26,8 +26,8 @@

prompts = [
"SimplePrompt",
"EntityExtractorZeroShotGuidancePrompt",
"EntityExtractorFewShotGuidancePrompt",
"EntityExtractorZeroShotJinjaPrompt",
"EntityExtractorFewShotJinjaPrompt",
"TextSummarizerGuidancePrompt",
"SchemaZeroShotGuidancePrompt",
"PropertiesZeroShotGuidancePrompt",
Expand Down
133 changes: 121 additions & 12 deletions lib/sycamore/sycamore/llms/prompts/default_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,19 @@
from typing import Any, Optional, Type
import textwrap

from sycamore.llms.prompts.prompts import ElementListPrompt, ElementPrompt, StaticPrompt
from sycamore.llms.prompts.prompts import (
ElementListPrompt,
ElementPrompt,
StaticPrompt,
JinjaPrompt,
JinjaElementPrompt,
)
from sycamore.llms.prompts.jinja_fragments import (
J_DYNAMIC_DOC_TEXT,
J_FORMAT_SCHEMA_MACRO,
J_SET_ENTITY,
J_SET_SCHEMA,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -49,11 +61,14 @@ class _EntityExtractorZeroShotGuidancePrompt(_SimplePrompt):
"""


EntityExtractorZeroShotGuidancePrompt = ElementListPrompt(
EntityExtractorZeroShotJinjaPrompt = JinjaPrompt(
system="You are a helpful entity extractor",
user="""You are given a few text elements of a document. The {entity} of the document is in these few text elements.
Using this context, FIND, COPY, and RETURN the {entity}. DO NOT REPHRASE OR MAKE UP AN ANSWER.
{elements}""",
user="""You are given a few text elements of a document. The {{ entity }} of the document is in these few text elements.
Using this context, FIND, COPY, and RETURN the {{ entity }}. DO NOT REPHRASE OR MAKE UP AN ANSWER.
{% for elt in doc.elements[:num_elements] %} ELEMENT {{ elt.element_index }}: {{ elt.field_to_value(field) }}
{% endfor %}""",
field="text_representation",
num_elements=35,
)


Expand All @@ -69,15 +84,72 @@ class _EntityExtractorFewShotGuidancePrompt(SimplePrompt):
"""


EntityExtractorFewShotGuidancePrompt = ElementListPrompt(
EntityExtractorFewShotJinjaPrompt = JinjaPrompt(
system="You are a helpful entity extractor",
user="""You are given a few text elements of a document. The {entity} of the document is in these few text elements. Here are
some example groups of text elements where the {entity} has been identified.
{examples}
Using the context from the document and the provided examples, FIND, COPY, and RETURN the {entity}. Only return the {entity} as part
user="""You are given a few text elements of a document. The {{ entity }} of the document is in these few text elements. Here are
some example groups of text elements where the {{ entity }} has been identified.
{{ examples }}
Using the context from the document and the provided examples, FIND, COPY, and RETURN the {{ entity }}. Only return the {{ entity }} as part
of your answer. DO NOT REPHRASE OR MAKE UP AN ANSWER.
{elements}
""",
{% for elt in doc.elements[:num_elements] %} ELEMENT {{ elt.element_index }}: {{ elt.field_to_value(field) }}
{% endfor %}""",
field="text_representation",
num_elements=35,
)


# Prompt asking an LLM to summarize an image element of a document.
#
# The user template instructs the model to decide whether the image is a graph
# (bar chart or line graph) and to answer with one of two JSON shapes:
#   - graph:     {"is_graph": true, "x-axis", "y-axis", "summary"}
#   - non-graph: {"is_graph": false, "summary"}
#
# When the template variable `include_context` is truthy, the template scans
# `doc.elements` for the element that is the same object as `elt` and, if found:
#   - appends the text of the element just before it, when that element's type
#     is "Section-header", "Caption" or "Text";
#   - appends the text of the element just after it, when that element's type
#     is "Caption" or "Text".
# A Jinja `namespace` holds the found index because a plain `set` inside a
# `for` loop does not survive the loop scope.
#
# NOTE(review): `{% break %}` is only available with Jinja's loop-controls
# extension; confirm the environment used by JinjaElementPrompt enables it.
# NOTE(review): `doc`, `elt` and `include_context` are presumably supplied at
# render time by JinjaElementPrompt / its caller -- verify against that class.
SummarizeImagesJinjaPrompt = JinjaElementPrompt(
user=textwrap.dedent(
"""
You are given an image from a PDF document along with some snippets of text preceding
and following the image on the page. Based on this context, please decide whether the image is a
graph or not. An image is a graph if it is a bar chart or a line graph. If the image is a graph,
please summarize the axes, including their units, and provide a summary of the results in no more
than 5 sentences.
Return the results in the following JSON schema:
{
"is_graph": true,
"x-axis": string,
"y-axis": string,
"summary": string
}
If the image is not a graph, please summarize the contents of the image in no more than five sentences
in the following JSON format:
{
"is_graph": false,
"summary": string
}
In all cases return only JSON and check your work.
{% if include_context -%}
{%- set posns = namespace(pos=-1) -%}
{%- for e in doc.elements -%}
{%- if e is sameas elt -%}
{%- set posns.pos = loop.index0 -%}
{% break %}
{%- endif -%}
{%- endfor -%}
{%- if posns.pos > 0 -%}
{%- set pe = doc.elements[posns.pos - 1] -%}
{%- if pe.type in ["Section-header", "Caption", "Text"] -%}
The text preceding the image is: {{ pe.text_representation }}
{%- endif -%}
{%- endif %}
{% if posns.pos != -1 and posns.pos < doc.elements|count - 1 -%}
{%- set fe = doc.elements[posns.pos + 1] -%}
{%- if fe.type in ["Caption", "Text"] -%}
The text following the image is: {{ fe.text_representation }}
{%- endif -%}
{%- endif -%}
{%- endif -%}
"""
),
include_image=True,
)


Expand Down Expand Up @@ -291,6 +363,43 @@ class {entity} from the document. The class only has properties {properties}. Us
)


# Zero-shot property-extraction prompt.
#
# The user template is built by concatenating, in order:
#   1. the J_SET_SCHEMA and J_SET_ENTITY fragments (imported from
#      sycamore.llms.prompts.jinja_fragments),
#   2. an instruction paragraph that references the template variables
#      `entity` and `schema` and asks for JSON only (or "None" when no entity
#      is present in the text),
#   3. the J_DYNAMIC_DOC_TEXT fragment, placed right after the "Document:" line.
#
# NOTE(review): presumably J_SET_SCHEMA / J_SET_ENTITY define the `schema` and
# `entity` variables used below and J_DYNAMIC_DOC_TEXT renders the document
# text -- their contents are not visible here; confirm in jinja_fragments.
PropertiesZeroShotJinjaPrompt = JinjaPrompt(
system="You are a helpful property extractor. You only return JSON.",
user=J_SET_SCHEMA
+ J_SET_ENTITY
+ textwrap.dedent(
"""\
You are given some text of a document. Extract JSON representing one entity of
class {{ entity }} from the document. The class only has properties {{ schema }}. Using
this context, FIND, FORMAT, and RETURN the JSON representing one {{ entity }}.
Only return JSON as part of your answer. If no entity is in the text, return "None".
Document:
"""
)
+ J_DYNAMIC_DOC_TEXT,
)

# Schema-driven property-extraction prompt.
#
# The user template is built by concatenating, in order:
#   1. the J_FORMAT_SCHEMA_MACRO fragment (imported from
#      sycamore.llms.prompts.jinja_fragments),
#   2. a header that lists the requested fields via `format_schema(schema)`,
#   3. the J_DYNAMIC_DOC_TEXT fragment after the "Document text:" label,
#   4. closing instructions: no extra information, fall back to the provided
#      default or 'None' for missing values, and answer as a JSON dictionary.
#
# NOTE(review): `format_schema` is presumably the macro defined by
# J_FORMAT_SCHEMA_MACRO, and `schema` is presumably supplied at render time --
# neither is visible here; confirm in jinja_fragments and the caller.
PropertiesFromSchemaJinjaPrompt = JinjaPrompt(
system="You are given text contents from a document.",
user=(
J_FORMAT_SCHEMA_MACRO
+ """\
Extract values for the following fields:
{{ format_schema(schema) }}
Document text:"""
+ J_DYNAMIC_DOC_TEXT
+ """
Don't return extra information.
If you cannot find a value for a requested property, use the provided default or the value 'None'.
Return your answers as a valid json dictionary that will be parsed in python.
"""
),
)


class EntityExtractorMessagesPrompt(SimplePrompt):
def __init__(self, question: str, field: str, format: Optional[str], discrete: bool = False):
super().__init__()
Expand Down
Loading

0 comments on commit 8a22bb9

Please sign in to comment.