From deb175ffc8490886f3de986ff031b2cc8ec3f045 Mon Sep 17 00:00:00 2001 From: Dhruv Kaliraman <112497058+dhruvkaliraman7@users.noreply.github.com> Date: Thu, 6 Feb 2025 16:14:43 -0800 Subject: [PATCH 01/11] mock using patch (#1160) --- .../sycamore/tests/unit/test_materialize.py | 73 ++++++++++--------- 1 file changed, 38 insertions(+), 35 deletions(-) diff --git a/lib/sycamore/sycamore/tests/unit/test_materialize.py b/lib/sycamore/sycamore/tests/unit/test_materialize.py index 179ba51a7..1b37632f1 100644 --- a/lib/sycamore/sycamore/tests/unit/test_materialize.py +++ b/lib/sycamore/sycamore/tests/unit/test_materialize.py @@ -703,47 +703,50 @@ def failing_map(doc): assert retry_counter.x == 23 # 2 successful, 21 unsuccessful def test_mrr_path_handling(self): + from unittest.mock import patch from pyarrow.fs import S3FileSystem, LocalFileSystem from sycamore.docset import DocSet - """Test MaterializeReadReliability path handling for both local and S3 paths""" ctx = sycamore.init(exec_mode=ExecMode.LOCAL) mrr = MaterializeReadReliability(max_batch=3) mrr._refresh_seen_files = lambda: None mrr.seen = set() ctx.rewrite_rules.append(mrr) - # Test various path formats - test_cases = [ - # Local paths - {"path": "/tmp/local/path", "expected_fs": "LocalFileSystem"}, - {"path": Path("/tmp/local/path2"), "expected_fs": "LocalFileSystem"}, - {"path": {"root": "/tmp/local/path3"}, "expected_fs": "LocalFileSystem"}, - {"path": {"root": Path("/tmp/local/path4")}, "expected_fs": "LocalFileSystem"}, - # S3 paths - {"path": "s3://test-example/path", "should_execute": True, "expected_fs": "S3FileSystem"}, - {"path": {"root": "s3://test-example/a/path"}, "should_execute": True, "expected_fs": "S3FileSystem"}, - ] - - MaterializeReadReliability.execute_reliably = lambda context, plan, mrr, **kwargs: None - for case in test_cases: - # Create a dummy materialize plan - plan = Materialize(None, ctx, path=case["path"]) - - # Test should_execute_reliably - - 
MaterializeReadReliability.maybe_execute_reliably(DocSet(context=ctx, plan=plan)) - - # Verify the path was properly initialized in mrr_instance - assert hasattr(mrr, "path"), f"mrr_instance missing path attribute for {case['path']}" - assert hasattr(mrr, "fs"), f"mrr_instance missing fs attribute for {case['path']}" - - # Verify correct filesystem type - if case["expected_fs"] == "S3FileSystem": - assert isinstance( - mrr.fs, S3FileSystem - ), f"Expected S3FileSystem for path {case['path']}, got {type(mrr.fs)}" - else: - assert isinstance( - mrr.fs, LocalFileSystem - ), f"Expected LocalFileSystem for path {case['path']}, got {type(mrr.fs)}" + # Use patch instead of modifying class + with patch.object(MaterializeReadReliability, "execute_reliably", return_value=None): + + # Test various path formats + test_cases = [ + # Local paths + {"path": "/tmp/local/path", "expected_fs": "LocalFileSystem"}, + {"path": Path("/tmp/local/path2"), "expected_fs": "LocalFileSystem"}, + {"path": {"root": "/tmp/local/path3"}, "expected_fs": "LocalFileSystem"}, + {"path": {"root": Path("/tmp/local/path4")}, "expected_fs": "LocalFileSystem"}, + # S3 paths + {"path": "s3://test-example/path", "should_execute": True, "expected_fs": "S3FileSystem"}, + {"path": {"root": "s3://test-example/a/path"}, "should_execute": True, "expected_fs": "S3FileSystem"}, + ] + + MaterializeReadReliability.execute_reliably = lambda context, plan, mrr, **kwargs: None + for case in test_cases: + # Create a dummy materialize plan + plan = Materialize(None, ctx, path=case["path"]) + + # Test should_execute_reliably + + MaterializeReadReliability.maybe_execute_reliably(DocSet(context=ctx, plan=plan)) + + # Verify the path was properly initialized in mrr_instance + assert hasattr(mrr, "path"), f"mrr_instance missing path attribute for {case['path']}" + assert hasattr(mrr, "fs"), f"mrr_instance missing fs attribute for {case['path']}" + + # Verify correct filesystem type + if case["expected_fs"] == "S3FileSystem": + 
assert isinstance( + mrr.fs, S3FileSystem + ), f"Expected S3FileSystem for path {case['path']}, got {type(mrr.fs)}" + else: + assert isinstance( + mrr.fs, LocalFileSystem + ), f"Expected LocalFileSystem for path {case['path']}, got {type(mrr.fs)}" From f58e564569684907469ddd9f7fdcf119295b4b39 Mon Sep 17 00:00:00 2001 From: Austin Lee Date: Thu, 6 Feb 2025 17:07:50 -0800 Subject: [PATCH 02/11] Ensure parent docs are collected during doc reconstruct (#1159) --- .../opensearch/opensearch_reader.py | 11 +- .../opensearch/test_opensearch_read.py | 146 ++++++++++++++++-- 2 files changed, 143 insertions(+), 14 deletions(-) diff --git a/lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py b/lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py index 0e5ebb431..b1ba3325c 100644 --- a/lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py +++ b/lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py @@ -333,8 +333,17 @@ def _to_parent_doc(self, slice_query: dict[str, Any]) -> List[dict[str, Any]]: and hit["_source"]["parent_id"] is not None and hit["_source"]["parent_id"] not in parent_ids ): - results.append(hit) + # Only add a child doc whose parent_id has not been found, yet. parent_ids.add(hit["_source"]["parent_id"]) + results.append(hit) + elif ("parent_id" not in hit["_source"] or hit["_source"]["parent_id"] is None) and hit[ + "_id" + ] not in parent_ids: + # Add a parent doc if it's a match. 
+ parent_id = hit["_id"] + parent_ids.add(parent_id) + hit["_source"]["parent_id"] = parent_id + results.append(hit) page += 1 diff --git a/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py b/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py index 15845ba73..de44f0fb5 100644 --- a/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py +++ b/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py @@ -78,19 +78,6 @@ def get_doc_count(os_client, index_name: str, query: Optional[Dict[str, Any]] = return res["count"] -""" -class MockLLM(LLM): - def __init__(self): - super().__init__(model_name="mock_model") - - def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: - return str(uuid.uuid4()) - - def is_chat_mode(self): - return True -""" - - class TestOpenSearchRead: INDEX_SETTINGS = { "body": { @@ -575,6 +562,139 @@ def test_parallel_query_with_pit(self, setup_index_large, os_client): if doc.parent_id is not None: assert doc.parent_id == expected_docs[doc.doc_id]["parent_id"] + def test_parallel_query_on_property_with_pit(self, setup_index, os_client): + context = sycamore.init(exec_mode=ExecMode.RAY) + key = "property1" + hidden = str(uuid.uuid4()) + query = {"query": {"match": {f"properties.{key}": hidden}}} + # make sure we read from pickle files -- this part won't be written into opensearch. 
+ dicts = [ + { + "doc_id": "1", + "properties": {key: hidden}, + "elements": [ + {"properties": {"_element_index": 1}, "text_representation": "here is an animal that meows"}, + ], + }, + { + "doc_id": "2", + "elements": [ + {"id": 7, "properties": {"_element_index": 7}, "text_representation": "this is a cat"}, + { + "id": 1, + "properties": {"_element_index": 1}, + "text_representation": "here is an animal that moos", + }, + ], + }, + { + "doc_id": "3", + "elements": [ + {"properties": {"_element_index": 1}, "text_representation": "here is an animal that moos"}, + ], + }, + { + "doc_id": "4", + "elements": [ + {"id": 1, "properties": {"_element_index": 1}}, + ], + }, + { + "doc_id": "5", + "elements": [ + { + "properties": {"_element_index": 1}, + "text_representation": "the number of pages in this document are 253", + } + ], + }, + { + "doc_id": "6", + "elements": [ + {"id": 1, "properties": {"_element_index": 1}}, + ], + }, + ] + docs = [Document(item) for item in dicts] + + original_docs = ( + context.read.document(docs) + # .materialize(path={"root": cache_dir, "name": doc_to_name}) + .explode() + .write.opensearch( + os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS, + index_name=setup_index, + index_settings=TestOpenSearchRead.INDEX_SETTINGS, + execute=False, + ) + .take_all() + ) + + os_client.indices.refresh(setup_index) + + expected_count = len(original_docs) + actual_count = get_doc_count(os_client, setup_index) + # refresh should have made all ingested docs immediately available for search + assert actual_count == expected_count, f"Expected {expected_count} documents, found {actual_count}" + + t0 = time.time() + retrieved_docs = context.read.opensearch( + os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS, + index_name=setup_index, + query=query, + reconstruct_document=True, + ).take_all() + t1 = time.time() + + print(f"Retrieved {len(retrieved_docs)} documents in {t1 - t0} seconds") + expected_docs = self.get_ids(os_client, setup_index, True, query) + 
assert len(retrieved_docs) == len(expected_docs) + assert "1" == retrieved_docs[0].doc_id + assert hidden == retrieved_docs[0].properties[key] + + def test_parallel_query_on_extracted_property_with_pit(self, setup_index, os_client): + + path = str(TEST_DIR / "resources/data/pdfs/Ray.pdf") + context = sycamore.init(exec_mode=ExecMode.RAY) + llm = OpenAI(OpenAIModels.GPT_4O_MINI) + extractor = OpenAIEntityExtractor("title", llm=llm) + original_docs = ( + context.read.binary(path, binary_format="pdf") + .partition(ArynPartitioner(aryn_api_key=ARYN_API_KEY)) + .extract_entity(extractor) + # .materialize(path={"root": materialized_dir, "name": doc_to_name}) + .explode() + .write.opensearch( + os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS, + index_name=setup_index, + index_settings=TestOpenSearchRead.INDEX_SETTINGS, + execute=False, + ) + .take_all() + ) + + os_client.indices.refresh(setup_index) + + expected_count = len(original_docs) + actual_count = get_doc_count(os_client, setup_index) + # refresh should have made all ingested docs immediately available for search + assert actual_count == expected_count, f"Expected {expected_count} documents, found {actual_count}" + + query = {"query": {"match": {"properties.title": "ray"}}} + + t0 = time.time() + retrieved_docs = context.read.opensearch( + os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS, + index_name=setup_index, + query=query, + reconstruct_document=True, + ).take_all() + t1 = time.time() + + print(f"Retrieved {len(retrieved_docs)} documents in {t1 - t0} seconds") + expected_docs = self.get_ids(os_client, setup_index, True, query) + assert len(retrieved_docs) == len(expected_docs) + @staticmethod def get_ids( os_client, index_name, parents_only: bool = False, query: Optional[Dict[str, Any]] = None From 7d6fa5d0f5ed5f5897a16c3a8bbb2c2e2f49b4dd Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Fri, 7 Feb 2025 09:56:54 -0800 Subject: [PATCH 03/11] apparently I just didn't finish this???? 
(#1162) Signed-off-by: Henry Lindeman --- lib/sycamore/sycamore/transforms/extract_schema.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lib/sycamore/sycamore/transforms/extract_schema.py b/lib/sycamore/sycamore/transforms/extract_schema.py index 1f8ec66a5..a890b7ef4 100644 --- a/lib/sycamore/sycamore/transforms/extract_schema.py +++ b/lib/sycamore/sycamore/transforms/extract_schema.py @@ -114,7 +114,9 @@ def render_document(self, doc: Document) -> RenderedPrompt: format_args["entity"] = doc.properties.get("_schema_class", "entity") flat_props = flatten_data(doc.properties, prefix="doc_property", separator="_") format_args.update(flat_props) - format_args["elements"] = self._render_element_list_to_string + format_args["elements"] = self._render_element_list_to_string(doc) + if doc.text_representation is None: + format_args["doc_text"] = format_args["elements"] messages = _build_format_str(self.system, self.user, format_args) result = RenderedPrompt(messages=messages, response_format=schema) From 34c3fa06fb09b87f9ffab76358f4a1c386e7a29c Mon Sep 17 00:00:00 2001 From: Mark Lindblad Date: Fri, 7 Feb 2025 12:09:08 -0800 Subject: [PATCH 04/11] Make async DocParse methods in `aryn-sdk` not operate on non-DocParse async tasks (#1156) --- lib/aryn-sdk/aryn_sdk/partition/partition.py | 24 ++++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/lib/aryn-sdk/aryn_sdk/partition/partition.py b/lib/aryn-sdk/aryn_sdk/partition/partition.py index 64a419f83..27ffcaa58 100644 --- a/lib/aryn-sdk/aryn_sdk/partition/partition.py +++ b/lib/aryn-sdk/aryn_sdk/partition/partition.py @@ -22,6 +22,7 @@ _logger.addHandler(logging.StreamHandler(sys.stderr)) g_version = "0.1.13" +g_parameters = {"path_filter": "^/v1/document/partition$"} class PartitionError(Exception): @@ -447,7 +448,9 @@ def partition_file_async_result( specific_task_url = f"{async_result_url.rstrip('/')}/{task_id}" headers = _generate_headers(aryn_config.api_key()) - 
response = requests.get(specific_task_url, headers=headers, stream=_should_stream(), verify=ssl_verify) + response = requests.get( + specific_task_url, params=g_parameters, headers=headers, stream=_should_stream(), verify=ssl_verify + ) if response.status_code == 200: return {"status": "done", "status_code": response.status_code, "result": response.json()} @@ -485,7 +488,9 @@ def partition_file_async_cancel( specific_task_url = f"{async_cancel_url.rstrip('/')}/{task_id}" headers = _generate_headers(aryn_config.api_key()) - response = requests.post(specific_task_url, headers=headers, stream=_should_stream(), verify=ssl_verify) + response = requests.post( + specific_task_url, params=g_parameters, headers=headers, stream=_should_stream(), verify=ssl_verify + ) if response.status_code == 200: return True elif response.status_code == 404: @@ -522,14 +527,13 @@ def partition_file_async_list( aryn_config = _process_config(aryn_api_key, aryn_config) headers = _generate_headers(aryn_config.api_key()) - response = requests.get(async_list_url, headers=headers, stream=_should_stream(), verify=ssl_verify) - - all_tasks = response.json()["tasks"] - result = {} - for task_id in all_tasks.keys(): - if all_tasks[task_id]["path"] == "/v1/document/partition": - del all_tasks[task_id]["path"] - result[task_id] = all_tasks[task_id] + response = requests.get( + async_list_url, params=g_parameters, headers=headers, stream=_should_stream(), verify=ssl_verify + ) + + result = response.json()["tasks"] + for v in result.values(): + del v["path"] return result From afd15c7cfe189cf4ea37d5b47be00969bcde73ed Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Fri, 7 Feb 2025 14:26:51 -0800 Subject: [PATCH 05/11] [llm unify 5a/n] Add JinjaPompt and re-convert extract entities (#1161) * add jinja prompts and convert extract entity to use it Signed-off-by: Henry Lindeman * delete commented out / dead code Signed-off-by: Henry Lindeman * JinjaPrompt docstring Signed-off-by: Henry Lindeman * add 
comments bc otherwise this is very dense Signed-off-by: Henry Lindeman * add norender() directive to jinja prompts when they shouldn't render Signed-off-by: Henry Lindeman * change FANCY_BATCHED_LIST to BATCHED_LIST_WITH_METADATA Signed-off-by: Henry Lindeman * pr comments Signed-off-by: Henry Lindeman --------- Signed-off-by: Henry Lindeman --- .../sycamore/llms/prompts/__init__.py | 8 +- .../sycamore/llms/prompts/default_prompts.py | 29 +- .../sycamore/llms/prompts/jinja_fragments.py | 31 ++ lib/sycamore/sycamore/llms/prompts/prompts.py | 119 ++++++++ .../sycamore/tests/unit/llms/test_llms.py | 14 - .../sycamore/tests/unit/test_docset.py | 5 +- .../unit/transforms/test_extract_entity.py | 55 ++-- .../tests/unit/transforms/test_llm_filter.py | 3 +- .../sycamore/transforms/extract_entity.py | 268 ++++++++++-------- 9 files changed, 366 insertions(+), 166 deletions(-) create mode 100644 lib/sycamore/sycamore/llms/prompts/jinja_fragments.py diff --git a/lib/sycamore/sycamore/llms/prompts/__init__.py b/lib/sycamore/sycamore/llms/prompts/__init__.py index 6261099a6..daa0d08c3 100644 --- a/lib/sycamore/sycamore/llms/prompts/__init__.py +++ b/lib/sycamore/sycamore/llms/prompts/__init__.py @@ -4,8 +4,8 @@ from sycamore.llms.prompts.default_prompts import ( SimplePrompt, - EntityExtractorZeroShotGuidancePrompt, - EntityExtractorFewShotGuidancePrompt, + EntityExtractorZeroShotJinjaPrompt, + EntityExtractorFewShotJinjaPrompt, TextSummarizerGuidancePrompt, SchemaZeroShotGuidancePrompt, PropertiesZeroShotGuidancePrompt, @@ -26,8 +26,8 @@ prompts = [ "SimplePrompt", - "EntityExtractorZeroShotGuidancePrompt", - "EntityExtractorFewShotGuidancePrompt", + "EntityExtractorZeroShotJinjaPrompt", + "EntityExtractorFewShotJinjaPrompt", "TextSummarizerGuidancePrompt", "SchemaZeroShotGuidancePrompt", "PropertiesZeroShotGuidancePrompt", diff --git a/lib/sycamore/sycamore/llms/prompts/default_prompts.py b/lib/sycamore/sycamore/llms/prompts/default_prompts.py index d54d6260f..7e72813dc 
100644 --- a/lib/sycamore/sycamore/llms/prompts/default_prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/default_prompts.py @@ -3,7 +3,7 @@ from typing import Any, Optional, Type import textwrap -from sycamore.llms.prompts.prompts import ElementListPrompt, ElementPrompt, StaticPrompt +from sycamore.llms.prompts.prompts import ElementListPrompt, ElementPrompt, StaticPrompt, JinjaPrompt logger = logging.getLogger(__name__) @@ -49,11 +49,14 @@ class _EntityExtractorZeroShotGuidancePrompt(_SimplePrompt): """ -EntityExtractorZeroShotGuidancePrompt = ElementListPrompt( +EntityExtractorZeroShotJinjaPrompt = JinjaPrompt( system="You are a helpful entity extractor", - user="""You are given a few text elements of a document. The {entity} of the document is in these few text elements. - Using this context, FIND, COPY, and RETURN the {entity}. DO NOT REPHRASE OR MAKE UP AN ANSWER. - {elements}""", + user="""You are given a few text elements of a document. The {{ entity }} of the document is in these few text elements. + Using this context, FIND, COPY, and RETURN the {{ entity }}. DO NOT REPHRASE OR MAKE UP AN ANSWER. + {% for elt in doc.elements[:num_elements] %} ELEMENT {{ elt.element_index }}: {{ elt.field_to_value(field) }} + {% endfor %}""", + field="text_representation", + num_elements=35, ) @@ -69,15 +72,17 @@ class _EntityExtractorFewShotGuidancePrompt(SimplePrompt): """ -EntityExtractorFewShotGuidancePrompt = ElementListPrompt( +EntityExtractorFewShotJinjaPrompt = JinjaPrompt( system="You are a helpful entity extractor", - user="""You are given a few text elements of a document. The {entity} of the document is in these few text elements. Here are - some example groups of text elements where the {entity} has been identified. - {examples} - Using the context from the document and the provided examples, FIND, COPY, and RETURN the {entity}. Only return the {entity} as part + user="""You are given a few text elements of a document. 
The {{ entity }} of the document is in these few text elements. Here are + some example groups of text elements where the {{ entity }} has been identified. + {{ examples }} + Using the context from the document and the provided examples, FIND, COPY, and RETURN the {{ entity }}. Only return the {{ entity }} as part of your answer. DO NOT REPHRASE OR MAKE UP AN ANSWER. - {elements} - """, + {% for elt in doc.elements[:num_elements] %} ELEMENT {{ elt.element_index }}: {{ elt.field_to_value(field) }} + {% endfor %}""", + field="text_representation", + num_elements=35, ) diff --git a/lib/sycamore/sycamore/llms/prompts/jinja_fragments.py b/lib/sycamore/sycamore/llms/prompts/jinja_fragments.py new file mode 100644 index 000000000..29d87e11c --- /dev/null +++ b/lib/sycamore/sycamore/llms/prompts/jinja_fragments.py @@ -0,0 +1,31 @@ +J_ELEMENT_LIST = """\ +{% for elt in doc.elements[:num_elements] %}ELEMENT {{ elt.element_index }}: {{ elt.field_to_value(field) }} +{% endfor %}""" + + +# Directive to not render the template if the iteration var has +# surpassed the number of batches. 
+J_BATCH_OOB_CHECK = ( + "{%- if doc.properties[iteration_var] >= doc.properties[batch_key]|count -%}{{ norender() }}{%- endif -%}\n" +) + +J_ELEMENT_BATCHED_LIST = ( + J_BATCH_OOB_CHECK + + """\ +{% for i in doc.properties[batch_key][doc.properties[iteration_var]] -%} +{%- set elt = doc.elements[i] -%} +ELEMENT {{ loop.index }}: {{ elt.field_to_value(field) }} +{% endfor -%}""" +) + +J_ELEMENT_BATCHED_LIST_WITH_METADATA = ( + J_BATCH_OOB_CHECK + + """\ +{% for i in doc.properties[batch_key][doc.properties[iteration_var]] -%} +{%- set elt = doc.elements[i] -%} +{% if "type" in elt %}Element type: {{ elt.type }}{% endif %} +{% if "page_number" in elt.properties %}Page_number: {{ elt.properties["page_number"] }}{% endif %} +{% if "_element_index" in elt.properties %}Element_index: {{ elt.properties["_element_index"] }}{% endif %} +Text: {{ elt.field_to_value(field) }} +{% endfor -%}""" +) diff --git a/lib/sycamore/sycamore/llms/prompts/prompts.py b/lib/sycamore/sycamore/llms/prompts/prompts.py index e800fc7a6..39dbd1598 100644 --- a/lib/sycamore/sycamore/llms/prompts/prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/prompts.py @@ -5,6 +5,7 @@ import pydantic from PIL import Image from sycamore.data.document import Document, Element +from sycamore.functions.tokenizer import Tokenizer from sycamore.connectors.common import flatten_data @@ -36,6 +37,10 @@ class RenderedPrompt: messages: list[RenderedMessage] response_format: Union[None, dict[str, Any], type[pydantic.BaseModel]] = None + def token_count(self, tokenizer: Tokenizer) -> int: + total_text = " ".join(m.content for m in self.messages) + return len(tokenizer.tokenize(total_text)) + class SycamorePrompt: """Base class/API for all Sycamore LLM Prompt objects. 
Sycamore Prompts @@ -451,3 +456,117 @@ def render_document(self, doc: Document) -> RenderedPrompt: def render_multiple_documents(self, docs: list[Document]) -> RenderedPrompt: return self.render_generic() + + +class NoRender(Exception): + def __init__(self): + super().__init__() + + +def raise_no_render(): + raise NoRender() + + +def _deserialize_jinja_prompt(kwargs): + return JinjaPrompt(**kwargs) + + +class JinjaPrompt(SycamorePrompt): + """A prompt that uses the Jinja templating system to render documents, with + a system and user prompt. + + Args: + system: The system prompt template, using Jinja syntax. + user: The user prompt template or prompt templates, using Jinja syntax. + kwargs: Additional key-value pairs that will be made available to the + rendering engine. + + Example: + .. code-block:: python + + prompt = JinjaPrompt( + system="You are a helpful entity extractor that extracts a json object or list to" + " populate a data processing system", + user='''Below, you will be given a series of segments of an NTSB report and a question. + Your job is to provide the answer to the question based on the value provided. + Your response should ONLY contain the answer to the question. If you are not able + to extract the new field given the information, respond with "None". The type + of your response should be a JSON list of strings. 
+ Field value: + {% for elt in doc.elements[:10] %} + ELEMENT {{ elt.element_index }}: {{ elt.field_to_value(field) }} + {% endfor %} + Answer the question "{{ question }}":''', + question="What aircraft parts were damaged in this report?", + field="text_representation", + ) + ds.llm_map(prompt, output_field="damaged_parts", llm=OpenAI(OpenAIModels.GPT_4O)) + + """ + + def __init__(self, *, system: Optional[str] = None, user: Union[None, str, list[str]] = None, **kwargs): + from jinja2.sandbox import SandboxedEnvironment + from jinja2 import Template + + super().__init__() + self.system = system + self.user = user + self.kwargs = kwargs + self._env = SandboxedEnvironment() + self._sys_template: Optional[Template] = None + self._user_templates: Union[None, Template, list[Template]] = None + + def __reduce__(self): + # Cannot serialize compiled templates - so force recompilation + return _deserialize_jinja_prompt, ({"system": self.system, "user": self.user, **self.kwargs},) + + def render_document(self, doc: Document) -> RenderedPrompt: + """Render this document using Jinja's template rendering system. + The template gets references to: + + - doc: the document + - **self.kwargs: other keyword arguments held by this prompt are + available by name. + + Args: + doc: The document to render + + Returns: + A rendered prompt containing information from the document. 
+ """ + if self._sys_template is None and self.system is not None: + self._sys_template = self._env.from_string(source=self.system, globals={"norender": raise_no_render}) + if self._user_templates is None and self.user is not None: + if isinstance(self.user, str): + self._user_templates = self._env.from_string(source=self.user, globals={"norender": raise_no_render}) + else: + self._user_templates = [ + self._env.from_string(source=u, globals={"norender": raise_no_render}) for u in self.user + ] + + render_args = copy.deepcopy(self.kwargs) + render_args["doc"] = doc + + messages = [] + if self._sys_template is not None: + try: + system = self._sys_template.render(render_args) + messages.append(RenderedMessage(role="system", content=system)) + except NoRender: + return RenderedPrompt(messages=[]) + if self._user_templates is not None: + if isinstance(self._user_templates, list): + for t in self._user_templates: + try: + content = t.render(render_args) + messages.append(RenderedMessage(role="user", content=content)) + except NoRender: + return RenderedPrompt(messages=[]) + else: + try: + content = self._user_templates.render(render_args) + messages.append(RenderedMessage(role="user", content=content)) + except NoRender: + return RenderedPrompt(messages=[]) + + return RenderedPrompt(messages=messages) diff --git a/lib/sycamore/sycamore/tests/unit/llms/test_llms.py b/lib/sycamore/sycamore/tests/unit/llms/test_llms.py index d03e9ea55..c88092ca7 100644 --- a/lib/sycamore/sycamore/tests/unit/llms/test_llms.py +++ b/lib/sycamore/sycamore/tests/unit/llms/test_llms.py @@ -1,11 +1,9 @@ from pathlib import Path from unittest.mock import patch -import pytest from sycamore.llms import OpenAI, OpenAIModels, Bedrock, BedrockModels, get_llm, MODELS from sycamore.llms.llms import FakeLLM from sycamore.llms.prompts import RenderedPrompt, RenderedMessage -from sycamore.llms.prompts import EntityExtractorFewShotGuidancePrompt, EntityExtractorZeroShotGuidancePrompt from 
sycamore.utils.cache import DiskCache import datetime from sycamore.utils.thread_local import ThreadLocalAccess @@ -46,18 +44,6 @@ def test_openai_davinci_fallback(): assert llm._model_name == OpenAIModels.GPT_3_5_TURBO_INSTRUCT.value.name -# Skip bc prompts are changing entirely -@pytest.mark.skip -def test_deprecated_prompt_fallback(): - from sycamore.llms.prompts.default_prompts import ENTITY_EXTRACTOR_ZERO_SHOT_GUIDANCE_PROMPT - - assert isinstance(ENTITY_EXTRACTOR_ZERO_SHOT_GUIDANCE_PROMPT, EntityExtractorZeroShotGuidancePrompt) - - from sycamore.llms.prompts import ENTITY_EXTRACTOR_FEW_SHOT_GUIDANCE_PROMPT - - assert isinstance(ENTITY_EXTRACTOR_FEW_SHOT_GUIDANCE_PROMPT, EntityExtractorFewShotGuidancePrompt) - - def test_model_list(): assert "openai." + OpenAIModels.TEXT_DAVINCI.value.name in MODELS assert "bedrock." + BedrockModels.CLAUDE_3_5_SONNET.value.name in MODELS diff --git a/lib/sycamore/sycamore/tests/unit/test_docset.py b/lib/sycamore/sycamore/tests/unit/test_docset.py index 3a56b8252..4c69cbc61 100644 --- a/lib/sycamore/sycamore/tests/unit/test_docset.py +++ b/lib/sycamore/sycamore/tests/unit/test_docset.py @@ -28,7 +28,7 @@ Query, ) from sycamore.transforms import Filter -from sycamore.transforms.base import get_name_from_callable +from sycamore.transforms.base import get_name_from_callable, CompositeTransform from sycamore.transforms.base_llm import LLMMap from sycamore.transforms.extract_entity import OpenAIEntityExtractor from sycamore.transforms.extract_schema import SchemaExtractor, LLMPropertyExtractor @@ -43,6 +43,7 @@ def __init__(self): super().__init__(model_name="mock_model") def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: + print(prompt) if llm_kwargs is None: llm_kwargs = {} if prompt.messages[-1].content.endswith("Element_index: 1\nText: third element\n"): @@ -182,7 +183,7 @@ def test_llm_extract_entity(self, mocker): llm = mocker.Mock(spec=LLM) docset = DocSet(context, None) docset = 
docset.extract_entity(entity_extractor=OpenAIEntityExtractor("title", llm=llm, prompt_template="")) - assert isinstance(docset.lineage(), LLMMap) + assert isinstance(docset.lineage(), CompositeTransform) def test_query(self, mocker): context = mocker.Mock(spec=Context) diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py b/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py index db420d6af..a2d2588bd 100644 --- a/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_extract_entity.py @@ -18,11 +18,15 @@ def __init__(self): super().__init__(model_name="mock_model") def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: + print(prompt) if len(prompt.messages) == 1: usermessage = prompt.messages[0].content else: usermessage = prompt.messages[1].content + if "returnnone" in usermessage: + return "None" + if usermessage.startswith("Hi"): return usermessage @@ -56,7 +60,7 @@ class TestEntityExtraction: "content": {"binary": None, "text": "text"}, "parent_id": None, "properties": {"path": "s3://path"}, - "embedding": {"binary": None, "text": None}, + "embedding": None, "elements": [ { "type": "title", @@ -78,14 +82,14 @@ def test_extract_entity_zero_shot(self, mocker): llm = MockLLM() extractor = OpenAIEntityExtractor("title", llm=llm) llm_map = extractor.as_llm_map(None) - out_docs = llm_map.run([self.doc]) + out_docs = llm_map._local_process([self.doc]) assert out_docs[0].properties.get("title") == "title1" def test_extract_entity_zero_shot_custom_field(self, mocker): llm = MockLLM() extractor = OpenAIEntityExtractor("title", llm=llm, field="properties.entity.author") llm_map = extractor.as_llm_map(None) - out_docs = llm_map.run([self.doc]) + out_docs = llm_map._local_process([self.doc]) assert out_docs[0].properties.get("title") == "Jack Black" def test_extract_entity_with_context_llm(self, mocker): @@ -97,35 +101,35 @@ def 
test_extract_entity_with_context_llm(self, mocker): ) extractor = OpenAIEntityExtractor("title") llm_map = extractor.as_llm_map(None, context=context) - out_docs = llm_map.run([self.doc]) + out_docs = llm_map._local_process([self.doc]) assert out_docs[0].properties.get("title") == "title1" def test_extract_entity_few_shot(self, mocker): llm = MockLLM() extractor = OpenAIEntityExtractor("title", llm=llm, prompt_template="title") llm_map = extractor.as_llm_map(None) - out_docs = llm_map.run([self.doc]) + out_docs = llm_map._local_process([self.doc]) assert out_docs[0].properties.get("title") == "title2" def test_extract_entity_document_field_messages(self, mocker): llm = MockLLM() extractor = OpenAIEntityExtractor("title", llm=llm, use_elements=False, prompt=[], field="properties.path") llm_map = extractor.as_llm_map(None) - out_docs = llm_map.run([self.doc]) + out_docs = llm_map._local_process([self.doc]) assert out_docs[0].properties.get("title") == "alt_title" def test_extract_entity_document_field_string(self, mocker): llm = MockLLM() extractor = OpenAIEntityExtractor("title", llm=llm, use_elements=False, prompt="", field="properties.path") llm_map = extractor.as_llm_map(None) - out_docs = llm_map.run([self.doc]) + out_docs = llm_map._local_process([self.doc]) assert out_docs[0].properties.get("title") == "alt_title" def test_extract_entity_with_elements_and_string_prompt(self, mocker): llm = MockLLM() extractor = OpenAIEntityExtractor("title", llm=llm, use_elements=True, prompt="Hi ") llm_map = extractor.as_llm_map(None) - outdocs = llm_map.run([self.doc]) + outdocs = llm_map._local_process([self.doc]) assert outdocs[0].properties.get("title").startswith("Hi") assert "text1" in outdocs[0].properties.get("title") assert "text2" in outdocs[0].properties.get("title") @@ -135,11 +139,28 @@ def test_extract_entity_with_elements_and_messages_prompt(self, mocker): prompt_messages = [{"role": "system", "content": "Yo"}, {"role": "user", "content": "ho!"}] extractor = 
OpenAIEntityExtractor("title", llm=llm, use_elements=True, prompt=prompt_messages) llm_map = extractor.as_llm_map(None) - outdocs = llm_map.run([self.doc]) + outdocs = llm_map._local_process([self.doc]) assert outdocs[0].properties.get("title").startswith("ho there!") assert "text1" in outdocs[0].properties.get("title") assert "text2" in outdocs[0].properties.get("title") + def test_extract_entity_iteration_var_oob(self, mocker): + llm = MockLLM() + llm.generate = MagicMock(wraps=llm.generate) + extractor = OpenAIEntityExtractor( + "returnnone", + llm=llm, + field="properties.entity.author", + tokenizer=MockTokenizer(), + max_tokens=10, + prompt="{{ entity }}", + ) + llm_map = extractor.as_llm_map(None) + out_docs = llm_map._local_process([self.doc]) + + assert llm.generate.call_count == 2 + assert out_docs[0].properties["returnnone"] == "None" + def test_extract_entity_with_similarity_sorting(self, mocker): doc_list = [ Document( @@ -222,7 +243,7 @@ def test_extract_entity_with_tokenizer(self, mocker): prompt=[], field="text_representation", tokenizer=mock_tokenizer, - max_tokens=20, # Low token limit to test windowing + max_tokens=42, # Low token limit to test windowing ) entity_docset = docset.extract_entity( @@ -230,12 +251,12 @@ def test_extract_entity_with_tokenizer(self, mocker): ) taken = entity_docset.take() - assert taken[0].properties[f"{new_field}_source_element_index"] == {0, 1, 2} - assert taken[1].properties[f"{new_field}_source_element_index"] == {2} + assert taken[0].properties[f"{new_field}_source_indices"] == [0, 1, 2] + assert taken[1].properties[f"{new_field}_source_indices"] == [1] # set to array index, not element_index assert taken[0].properties[new_field] == "4" assert taken[1].properties[new_field] == "5" - assert taken[0].elements[0]["properties"]["_autogen_LLMExtractEntityOutput_source_element_index"] == {0, 1, 2} - assert taken[0].elements[1]["properties"]["_autogen_LLMExtractEntityOutput_source_element_index"] == {0, 1, 2} - assert 
taken[0].elements[2]["properties"]["_autogen_LLMExtractEntityOutput_source_element_index"] == {0, 1, 2} - assert taken[1].elements[0]["properties"]["_autogen_LLMExtractEntityOutput_source_element_index"] == {1} - assert taken[1].elements[1]["properties"]["_autogen_LLMExtractEntityOutput_source_element_index"] == {2} + assert taken[0].elements[0]["properties"]["_autogen_LLMExtractEntityOutput_source_indices"] == [0, 1, 2] + assert taken[0].elements[1]["properties"]["_autogen_LLMExtractEntityOutput_source_indices"] == [0, 1, 2] + assert taken[0].elements[2]["properties"]["_autogen_LLMExtractEntityOutput_source_indices"] == [0, 1, 2] + assert taken[1].elements[0]["properties"]["_autogen_LLMExtractEntityOutput_source_indices"] == [0] + assert taken[1].elements[1]["properties"]["_autogen_LLMExtractEntityOutput_source_indices"] == [1] diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_llm_filter.py b/lib/sycamore/sycamore/tests/unit/transforms/test_llm_filter.py index 5e195a2ca..129d1332f 100644 --- a/lib/sycamore/sycamore/tests/unit/transforms/test_llm_filter.py +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_llm_filter.py @@ -31,7 +31,8 @@ properties={"_element_index": 2}, text_representation="very long element with many words that might exceed token limit." " Specifically, it has so many words that even with the additional contextualization" - " like 'Element type' and 'page number' it still overflows", + " like 'Element type' and 'page number' it still overflows. 
So many words, in fact," + " that even rendering the entire prompt still overflows it.", ), # llm_filter result = 5 ], ), diff --git a/lib/sycamore/sycamore/transforms/extract_entity.py b/lib/sycamore/sycamore/transforms/extract_entity.py index 3cd1839df..9d7ce8397 100644 --- a/lib/sycamore/sycamore/transforms/extract_entity.py +++ b/lib/sycamore/sycamore/transforms/extract_entity.py @@ -1,23 +1,24 @@ from abc import ABC, abstractmethod -from typing import Callable, Any, Optional, Union, cast +from typing import Callable, Any, Optional, Union from sycamore.context import Context, context_params, OperationTypes from sycamore.data import Element, Document from sycamore.llms import LLM from sycamore.llms.prompts.default_prompts import ( - EntityExtractorZeroShotGuidancePrompt, - EntityExtractorFewShotGuidancePrompt, + EntityExtractorZeroShotJinjaPrompt, + EntityExtractorFewShotJinjaPrompt, _EntityExtractorZeroShotGuidancePrompt, _EntityExtractorFewShotGuidancePrompt, ) from sycamore.llms.prompts.prompts import ( - ElementListIterPrompt, - ElementListPrompt, RenderedMessage, SycamorePrompt, RenderedPrompt, + JinjaPrompt, ) +from sycamore.llms.prompts.jinja_fragments import J_ELEMENT_BATCHED_LIST, J_ELEMENT_BATCHED_LIST_WITH_METADATA from sycamore.plan_nodes import Node +from sycamore.transforms.base import CompositeTransform, BaseMapTransform from sycamore.transforms.base_llm import LLMMap from sycamore.transforms.map import Map from sycamore.utils.time_trace import timetrace @@ -129,133 +130,168 @@ def __init__( self._similarity_query = similarity_query self._similarity_scorer = similarity_scorer - @context_params(OperationTypes.INFORMATION_EXTRACTOR) - def as_llm_map( - self, child: Optional[Node], context: Optional[Context] = None, llm: Optional[LLM] = None, **kwargs - ) -> Node: - if llm is None: - llm = self._llm - assert llm is not None, "Could not find an LLM to use" - prompt: SycamorePrompt # grr mypy + def _get_const_variables(self) -> dict[str, str]: + # 
These kept popping up in various places across the transforms + return { + "similarity_field_name": f"{self._field}_similarity_score", + "source_idx_key": f"{self._entity_name}_source_indices", + "batch_key": f"{self._entity_name}_batches", + "iteration_var_name": f"{self._entity_name}_i", + } + + def _get_prompt(self) -> SycamorePrompt: + # there's like a million paths to cover but I think I have + # them all + vars = self._get_const_variables() + if self._prompt_formatter is not element_list_formatter: + j_elements = "{{ formatter(doc.elements) }}" + elif self._tokenizer is not None: + j_elements = J_ELEMENT_BATCHED_LIST_WITH_METADATA + else: + j_elements = J_ELEMENT_BATCHED_LIST + if not self._use_elements: + if self._prompt is None: + raise ValueError("prompt must be specified if use_elements is False") + j_elements = "{{ doc.field_to_value(field) }}" + + common_params = { + "field": self._field, + "num_elements": self._num_of_elements, + "batch_key": vars["batch_key"], + "iteration_var": vars["iteration_var_name"], + "entity": self._entity_name, + } + if self._prompt is not None: if isinstance(self._prompt, str): - prompt = ElementListPrompt(user=self._prompt + "\n{elements}") + return JinjaPrompt(system=None, user=self._prompt + "\n" + j_elements, **common_params) else: system = None if len(self._prompt) > 0 and self._prompt[0]["role"] == "system": system = self._prompt[0]["content"] - user = [p["content"] for p in self._prompt[1:]] + ["{elements}"] + user = [p["content"] for p in self._prompt[1:]] + [j_elements] else: - user = [p["content"] for p in self._prompt] + ["{elements}"] - prompt = ElementListPrompt(system=system, user=user) + user = [p["content"] for p in self._prompt] + [j_elements] + return JinjaPrompt(system=system, user=user, **common_params) elif self._prompt_template is not None: - prompt = EntityExtractorFewShotGuidancePrompt - prompt = cast(ElementListPrompt, prompt.set(examples=self._prompt_template)) + return 
EntityExtractorFewShotJinjaPrompt.set(examples=self._prompt_template, **common_params) else: - prompt = EntityExtractorZeroShotGuidancePrompt - - if self._tokenizer is not None: - - def validate(d: Document) -> bool: - return d.properties.get(self._entity_name, "None") != "None" - - def elt_list_ctor(elts: list[Element]) -> str: - if self._prompt_formatter is not element_list_formatter: - return self._prompt_formatter(elts, self._field) - combined_text = "" - for element in elts: - if "type" in element: - combined_text += f"Element type: {element['type']}\n" - if "page_number" in element["properties"]: - combined_text += f"Page_number: {element['properties']['page_number']}\n" - if "_element_index" in element["properties"]: - combined_text += f"Element_index: {element['properties']['_element_index']}\n" - combined_text += f"Text: {element.field_to_value(self._field)}\n" - return combined_text - - source_idx_key = f"{self._entity_name}_source_element_index" - - def eb(elts: list[Element]) -> list[list[Element]]: + return EntityExtractorZeroShotJinjaPrompt.set(**common_params) + + def _make_preprocess_fn(self, prompt: SycamorePrompt) -> Callable[[Document], Document]: + vars = self._get_const_variables() + + def sort_and_batch_elements(doc: Document) -> Document: + if self._similarity_query is not None and self._similarity_scorer is not None: + # If we did similarity scoring sort the elements (keep track of their original + # locations though) + elements = sorted( + [(e, i) for i, e in enumerate(doc.elements)], + key=(lambda e_i: e_i[0].properties.get(vars["similarity_field_name"], float("-inf"))), + reverse=True, + ) + else: + elements = [(e, i) for i, e in enumerate(doc.elements)] + + batches = [] + if self._tokenizer is not None: + curr_club = [] curr_tks = 0 - curr_batch: list[Element] = [] - batches = [] - source_indices = set() - assert ( - self._tokenizer is not None - ), "Cannot batch elements based on token counts because tokenizer is None" - for e in elts: 
- eltl = cast(ElementListPrompt, prompt).element_list_constructor([e]) - tks = len(self._tokenizer.tokenize(eltl)) + # We'll create a dummy document and consecutively + # add more elements to it, rendering out to a prompt + # at each step and counting tokens to find breakpoints. + dummy = doc.copy() + dummy.properties = doc.properties.copy() + dummy.properties[vars["iteration_var_name"]] = 0 + dummy.elements = [] + for e, i in elements: + dummy.elements.append(e) + curr_club.append(i) + dummy.properties[vars["batch_key"]] = [curr_club] + rendered = prompt.render_document(dummy) + tks = rendered.token_count(self._tokenizer) if tks + curr_tks > self._max_tokens: - batches.append(curr_batch) - curr_tks = tks - curr_batch = [e] - source_indices = {e.element_index} - e.properties[source_idx_key] = source_indices + curr_club.pop() + batches.append(curr_club) + curr_club = [i] + e.properties[vars["source_idx_key"]] = curr_club + dummy.elements = [e] + curr_tks = 0 else: - e.properties[source_idx_key] = source_indices - source_indices.add(e.element_index) - curr_batch.append(e) + e.properties[vars["source_idx_key"]] = curr_club curr_tks += tks - batches.append(curr_batch) - return batches - - iteration_var_name = f"{self._entity_name}_i" - - def postprocess(d: Document) -> Document: - last_eclub: set[int] = set() - club_idx = 0 - target_club_idx = d.properties[iteration_var_name] - for e in d.elements: - if len(last_eclub) > 0 and e.properties[source_idx_key] != last_eclub: - club_idx += 1 - last_eclub = e.properties[source_idx_key] - if club_idx == target_club_idx: - d.properties[source_idx_key] = last_eclub - break - return d + batches.append(curr_club) + else: + # If no tokenizer, we run a single batch with the first num_of_elements. 
+ batches = [[i for e, i in elements[: self._num_of_elements]]] + for i in batches[0]: + doc.elements[i].properties[vars["source_idx_key"]] = batches[0] - prompt = ElementListIterPrompt( - system=prompt.system, - user=prompt.user, - element_list_constructor=elt_list_ctor, - element_batcher=eb, - entity=self._entity_name, - examples=self._prompt_template, - iteration_var_name=iteration_var_name, - ) + doc.properties[vars["batch_key"]] = batches + return doc - llm_map = LLMMap( - child, prompt, self._entity_name, llm, iteration_var=iteration_var_name, validate=validate, **kwargs - ) - ppmap = Map(llm_map, f=postprocess) - return ppmap + return sort_and_batch_elements - elif not self._use_elements: - if self._prompt is None: - raise ValueError("prompt must be specified if use_elements is False") - if isinstance(self._prompt, str): - prompt = FieldToValuePrompt( - messages=[RenderedMessage(role="user", content=self._prompt + "{value}")], field=self._field - ) - elif isinstance(self._prompt, list): - ms = [RenderedMessage(role=m["role"], content=m["content"]) for m in self._prompt] - ms.append(RenderedMessage(role="user", content="{value}")) - prompt = FieldToValuePrompt(messages=ms, field=self._field) - return LLMMap(child, prompt, self._entity_name, llm, **kwargs) - - def elt_sorter(elts: list[Element]) -> list[Element]: - sorter_inner = make_element_sorter_fn(self._field, self._similarity_query, self._similarity_scorer) - dummy_doc = Document(elements=elts) - sorter_inner(dummy_doc) - return dummy_doc.elements - - prompt = prompt.set(element_select=lambda e: elt_sorter(e)[: self._num_of_elements]) - prompt = prompt.set(element_list_constructor=lambda e: self._prompt_formatter(e, self._field)) - prompt = prompt.set(entity=self._entity_name) - - llm_map = LLMMap(child, prompt, self._entity_name, llm, **kwargs) - return llm_map + @context_params(OperationTypes.INFORMATION_EXTRACTOR) + def as_llm_map( + self, child: Optional[Node], context: Optional[Context] = None, llm: 
Optional[LLM] = None, **kwargs + ) -> Node: + # represent this EntityExtractor as a CompositeTransform consisting of some + # preprocessing (set up batches, sort elements, etc), the central LLMMap, + # and some postprocessing (derive the source_indices property) + if llm is None: + llm = self._llm + assert llm is not None, "Could not find an LLM to use" + + prompt = self._get_prompt() + preprocess = self._make_preprocess_fn(prompt) + vars = self._get_const_variables() + + def validate(d: Document) -> bool: + return self._tokenizer is None or d.properties.get(self._entity_name, "None") != "None" + + def postprocess(d: Document) -> Document: + target_club_idx = d.properties[vars["iteration_var_name"]] + if target_club_idx >= len(d.properties[vars["batch_key"]]): + return d + batch = d.properties[vars["batch_key"]][target_club_idx] + d.properties[vars["source_idx_key"]] = batch + return d + + nodes: list[BaseMapTransform] = [] + head_node: Node + if self._similarity_query is not None and self._similarity_scorer is not None: + # If similarity we add a ScoreSimilarity node to the sub-pipeline + from sycamore.transforms.similarity import ScoreSimilarity + + head_node = ScoreSimilarity( + child, # type: ignore + similarity_scorer=self._similarity_scorer, + query=self._similarity_query, + score_property_name=vars["similarity_field_name"], + ) + nodes.append(head_node) + else: + head_node = child # type: ignore + + head_node = Map(head_node, f=preprocess) + nodes.append(head_node) + head_node = LLMMap( + head_node, + prompt, + self._entity_name, + llm, + validate=validate, + iteration_var=vars["iteration_var_name"], + max_tries=100, + **kwargs, + ) + nodes.append(head_node) + head_node = Map(head_node, f=postprocess) + nodes.append(head_node) + comptransform = CompositeTransform(child, []) # type: ignore + comptransform.nodes = nodes + return comptransform @context_params(OperationTypes.INFORMATION_EXTRACTOR) @timetrace("OaExtract") From 
574b150a5846298fbd1343f3f3d98f96189c64ed Mon Sep 17 00:00:00 2001 From: Dhruv Kaliraman <112497058+dhruvkaliraman7@users.noreply.github.com> Date: Fri, 7 Feb 2025 15:01:43 -0800 Subject: [PATCH 06/11] Add list to cast types (#1163) * Add list to cast types * Modify test --- lib/sycamore/sycamore/tests/unit/transforms/test_schema.py | 5 ++++- lib/sycamore/sycamore/transforms/extract_schema.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_schema.py b/lib/sycamore/sycamore/tests/unit/transforms/test_schema.py index 80f4a4f47..bb5f22f89 100644 --- a/lib/sycamore/sycamore/tests/unit/transforms/test_schema.py +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_schema.py @@ -258,7 +258,8 @@ def test_extract_properties_with_schema(self, mocker): '"someOtherDate": "2024-01--1 00:01:01", ' '"accidentNumber": "FTW95FA129", ' '"latitude": "10.00353", ' - '"injuryCount": "5"}' + '"injuryCount": "5", ' + '"location": ["Fort Worth, TX", "Dallas, TX"]}' ) doc = Document() @@ -275,6 +276,7 @@ def test_extract_properties_with_schema(self, mocker): SchemaField(name="accidentNumber", field_type="str"), SchemaField(name="injuryCount", field_type="int"), SchemaField(name="latitude", field_type="float"), + SchemaField(name="location", field_type="list"), ] ) property_extractor = LLMPropertyExtractor(llm, schema=schema) @@ -293,3 +295,4 @@ def test_extract_properties_with_schema(self, mocker): assert doc.properties["entity"]["someOtherDate"] == "2024-01--1 00:01:01" assert doc.properties["entity"]["injuryCount"] == 5 assert doc.properties["entity"]["latitude"] == 10.00353 + assert doc.properties["entity"]["location"] == ["Fort Worth, TX", "Dallas, TX"] diff --git a/lib/sycamore/sycamore/transforms/extract_schema.py b/lib/sycamore/sycamore/transforms/extract_schema.py index a890b7ef4..bca5dc1bc 100644 --- a/lib/sycamore/sycamore/transforms/extract_schema.py +++ b/lib/sycamore/sycamore/transforms/extract_schema.py @@ 
-267,6 +267,7 @@ def cast_types(self, fields: dict) -> dict: "bool": bool, "date": lambda x: dateparser.parse(x), "datetime": lambda x: dateparser.parse(x), + "list": list, } for field in self._schema.fields: From 8e24f97d58bb7b8f60c268c530363d244f7723c2 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Mon, 10 Feb 2025 10:44:31 -0800 Subject: [PATCH 07/11] [llm unify 5b/n] Jinja summarize images (#1166) * add jinja prompts and convert extract entity to use it Signed-off-by: Henry Lindeman * delete commented out / dead code Signed-off-by: Henry Lindeman * JinjaPrompt docstring Signed-off-by: Henry Lindeman * add comments bc otherwise this is very dense Signed-off-by: Henry Lindeman * add norender() directive to jinja prompts when they shouldn't render Signed-off-by: Henry Lindeman * change FANCY_BATCHED_LIST to BATCHED_LIST_WITH_METADATA Signed-off-by: Henry Lindeman * pr comments Signed-off-by: Henry Lindeman * branch switch Signed-off-by: Henry Lindeman * move summarizeImages to jinja Signed-off-by: Henry Lindeman * delete old bespoke prompt class Signed-off-by: Henry Lindeman * make it actually use the jinja prompt (and fix the jinja) Signed-off-by: Henry Lindeman * mypy: Signed-off-by: Henry Lindeman * add summarize_images unittest (mostly a prompt ut) Signed-off-by: Henry Lindeman * adjust prop extraction its to count the correct number of lineage metadatadocs Signed-off-by: Henry Lindeman --------- Signed-off-by: Henry Lindeman --- .../sycamore/llms/prompts/default_prompts.py | 63 ++++++++- lib/sycamore/sycamore/llms/prompts/prompts.py | 133 +++++++++++++----- .../transforms/test_data_extraction.py | 12 +- .../unit/transforms/test_summarize_images.py | 59 ++++++++ .../sycamore/transforms/summarize_images.py | 52 +------ lib/sycamore/sycamore/utils/extract_json.py | 7 + lib/sycamore/sycamore/utils/pdf_utils.py | 4 +- 7 files changed, 241 insertions(+), 89 deletions(-) create mode 100644 lib/sycamore/sycamore/tests/unit/transforms/test_summarize_images.py 
diff --git a/lib/sycamore/sycamore/llms/prompts/default_prompts.py b/lib/sycamore/sycamore/llms/prompts/default_prompts.py index 7e72813dc..6dd619278 100644 --- a/lib/sycamore/sycamore/llms/prompts/default_prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/default_prompts.py @@ -3,7 +3,13 @@ from typing import Any, Optional, Type import textwrap -from sycamore.llms.prompts.prompts import ElementListPrompt, ElementPrompt, StaticPrompt, JinjaPrompt +from sycamore.llms.prompts.prompts import ( + ElementListPrompt, + ElementPrompt, + StaticPrompt, + JinjaPrompt, + JinjaElementPrompt, +) logger = logging.getLogger(__name__) @@ -86,6 +92,61 @@ class _EntityExtractorFewShotGuidancePrompt(SimplePrompt): ) +SummarizeImagesJinjaPrompt = JinjaElementPrompt( + user=textwrap.dedent( + """ + You are given an image from a PDF document along with with some snippets of text preceding + and following the image on the page. Based on this context, please decide whether the image is a + graph or not. An image is a graph if it is a bar chart or a line graph. If the image is a graph, + please summarize the axes, including their units, and provide a summary of the results in no more + than 5 sentences. + + Return the results in the following JSON schema: + + { + "is_graph": true, + "x-axis": string, + "y-axis": string, + "summary": string + } + + If the image is not a graph, please summarize the contents of the image in no more than five sentences + in the following JSON format: + + { + "is_graph": false, + "summary": string + } + + In all cases return only JSON and check your work. 
+ + {% if include_context -%} + {%- set posns = namespace(pos=-1) -%} + {%- for e in doc.elements -%} + {%- if e is sameas elt -%} + {%- set posns.pos = loop.index0 -%} + {% break %} + {%- endif -%} + {%- endfor -%} + {%- if posns.pos > 0 -%} + {%- set pe = doc.elements[posns.pos - 1] -%} + {%- if pe.type in ["Section-header", "Caption", "Text"] -%} + The text preceding the image is: {{ pe.text_representation }} + {%- endif -%} + {%- endif %} + {% if posns.pos != -1 and posns.pos < doc.elements|count - 1 -%} + {%- set fe = doc.elements[posns.pos + 1] -%} + {%- if fe.type in ["Caption", "Text"] -%} + The text following the image is: {{ fe.text_representation }} + {%- endif -%} + {%- endif -%} + {%- endif -%} + """ + ), + include_image=True, +) + + class _TextSummarizerGuidancePrompt(SimplePrompt): system = "You are a helpful text summarizer." user = """Write a summary of the following. Use only the information provided. diff --git a/lib/sycamore/sycamore/llms/prompts/prompts.py b/lib/sycamore/sycamore/llms/prompts/prompts.py index 39dbd1598..7995e6775 100644 --- a/lib/sycamore/sycamore/llms/prompts/prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/prompts.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Union, Optional, Callable +from typing import Any, Union, Optional, Callable, TYPE_CHECKING import copy import pydantic @@ -8,6 +8,10 @@ from sycamore.functions.tokenizer import Tokenizer from sycamore.connectors.common import flatten_data +if TYPE_CHECKING: + from jinja2.sandbox import SandboxedEnvironment + from jinja2 import Template + @dataclass class RenderedMessage: @@ -467,8 +471,35 @@ def raise_no_render(): raise NoRender() +def compile_templates(templates: list[Optional[str]], env: "SandboxedEnvironment") -> list[Optional["Template"]]: + return [ + env.from_string(source=t, globals={"norender": raise_no_render}) if t is not None else None for t in templates + ] + + +def render_templates(sys: Optional["Template"], user: 
list["Template"], render_args: dict[str, Any]) -> RenderedPrompt: + messages = [] + if sys is not None: + try: + system = sys.render(render_args) + messages.append(RenderedMessage(role="system", content=system)) + except NoRender: + return RenderedPrompt(messages=[]) + for ut in user: + try: + content = ut.render(render_args) + messages.append(RenderedMessage(role="user", content=content)) + except NoRender: + return RenderedPrompt(messages=[]) + return RenderedPrompt(messages=messages) + + def _deserialize_jinja_prompt(kwargs): - return JinjaPrompt(**kwargs) + cls = kwargs.pop("class") + if cls == "JinjaPrompt": + return JinjaPrompt(**kwargs) + if cls == "JinjaElementPrompt": + return JinjaElementPrompt(**kwargs) class JinjaPrompt(SycamorePrompt): @@ -512,13 +543,15 @@ def __init__(self, *, system: Optional[str] = None, user: Union[None, str, list[ self.system = system self.user = user self.kwargs = kwargs - self._env = SandboxedEnvironment() + self._env = SandboxedEnvironment(extensions=["jinja2.ext.loopcontrols"]) self._sys_template: Optional[Template] = None - self._user_templates: Union[None, Template, list[Template]] = None + self._user_templates: Union[None, list[Template]] = None def __reduce__(self): # Cannot serialize compiled templates - so force recompilation - return _deserialize_jinja_prompt, ({"system": self.system, "user": self.user, **self.kwargs},) + return _deserialize_jinja_prompt, ( + {"system": self.system, "user": self.user, "class": self.__class__.__name__, **self.kwargs}, + ) def render_document(self, doc: Document) -> RenderedPrompt: """Render this document using Jinja's template rendering system. @@ -534,39 +567,67 @@ def render_document(self, doc: Document) -> RenderedPrompt: Returns: A rendered prompt containing information from the document. 
""" - if self._sys_template is None and self.system is not None: - self._sys_template = self._env.from_string(source=self.system, globals={"norender": raise_no_render}) - if self._user_templates is None and self.user is not None: - if isinstance(self.user, str): - self._user_templates = self._env.from_string(source=self.user, globals={"norender": raise_no_render}) - else: - self._user_templates = [ - self._env.from_string(source=u, globals={"norender": raise_no_render}) for u in self.user - ] + if self._user_templates is None: + userlist = self.user if isinstance(self.user, list) else [self.user] # type: ignore + templates = compile_templates([self.system] + userlist, self._env) # type: ignore + self._sys_template = templates[0] + self._user_templates = [t for t in templates[1:] if t is not None] render_args = copy.deepcopy(self.kwargs) render_args["doc"] = doc - messages = [] - if self._sys_template is not None: - try: - system = self._sys_template.render(render_args) - messages.append(RenderedMessage(role="system", content=system)) - except NoRender: - return RenderedPrompt(messages=[]) - if self._user_templates is not None: - if isinstance(self._user_templates, list): - for t in self._user_templates: - try: - content = t.render(render_args) - messages.append(RenderedMessage(role="user", content=content)) - except NoRender: - return RenderedPrompt(messages=[]) - else: - try: - content = self._user_templates.render(render_args) - messages.append(RenderedMessage(role="user", content=content)) - except NoRender: - return RenderedPrompt(messages=[]) + rendered = render_templates(self._sys_template, self._user_templates, render_args) + return rendered + + +class JinjaElementPrompt(SycamorePrompt): + def __init__( + self, + *, + system: Optional[str] = None, + user: Union[None, str, list[str]] = None, + include_image: bool = False, + **kwargs, + ): + from jinja2.sandbox import SandboxedEnvironment + from jinja2 import Template - return RenderedPrompt(messages=messages) 
+ super().__init__() + self.system = system + self.user = user + self.include_image = include_image + self.kwargs = kwargs + self._env = SandboxedEnvironment(extensions=["jinja2.ext.loopcontrols"]) + self._sys_template: Optional[Template] = None + self._user_templates: Union[None, list[Template]] = None + + def __reduce__(self): + # Cannot serialize compiled templates - so force recompilation + return _deserialize_jinja_prompt, ( + { + "system": self.system, + "user": self.user, + "include_image": self.include_image, + "class": self.__class__.__name__, + **self.kwargs, + }, + ) + + def render_element(self, elt: Element, doc: Document) -> RenderedPrompt: + if self._user_templates is None: + userlist = self.user if isinstance(self.user, list) else [self.user] # type: ignore + templates = compile_templates([self.system] + userlist, self._env) # type: ignore + self._sys_template = templates[0] + self._user_templates = [t for t in templates[1:] if t is not None] + + render_args = copy.deepcopy(self.kwargs) + render_args["elt"] = elt + render_args["doc"] = doc + + result = render_templates(self._sys_template, self._user_templates, render_args) + if self.include_image and len(result.messages) > 0: + from sycamore.utils.pdf_utils import get_element_image + + result.messages[-1].images = [get_element_image(elt, doc)] + print(result) + return result diff --git a/lib/sycamore/sycamore/tests/integration/transforms/test_data_extraction.py b/lib/sycamore/sycamore/tests/integration/transforms/test_data_extraction.py index e3f415417..28178135c 100644 --- a/lib/sycamore/sycamore/tests/integration/transforms/test_data_extraction.py +++ b/lib/sycamore/sycamore/tests/integration/transforms/test_data_extraction.py @@ -49,9 +49,9 @@ def test_extract_properties_from_dict_schema(llm): assert taken[0].properties["entity"]["age"] == 74 assert "Honolulu" in taken[0].properties["entity"]["from_location"] - assert len(taken) == 3 - assert taken[2].metadata["usage"]["prompt_tokens"] > 0 - 
assert taken[2].metadata["usage"]["completion_tokens"] > 0 + assert len(taken) == 4 + assert taken[3].metadata["usage"]["prompt_tokens"] > 0 + assert taken[3].metadata["usage"]["completion_tokens"] > 0 @pytest.mark.parametrize("llm", llms) @@ -97,8 +97,8 @@ def test_extract_properties_from_schema(llm): assert taken[1].properties["entity"]["from_location"] == "New Delhi" assert taken[1].properties["entity"]["date"] == "2014-01-11" - assert len(taken) == 5 - assert taken[3].metadata["usage"]["prompt_tokens"] > 0 - assert taken[3].metadata["usage"]["completion_tokens"] > 0 + assert len(taken) == 6 assert taken[4].metadata["usage"]["prompt_tokens"] > 0 assert taken[4].metadata["usage"]["completion_tokens"] > 0 + assert taken[5].metadata["usage"]["prompt_tokens"] > 0 + assert taken[5].metadata["usage"]["completion_tokens"] > 0 diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_summarize_images.py b/lib/sycamore/sycamore/tests/unit/transforms/test_summarize_images.py new file mode 100644 index 000000000..186eedf1e --- /dev/null +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_summarize_images.py @@ -0,0 +1,59 @@ +from typing import Optional, Any +import json +from sycamore.data.document import Document +from sycamore.data.element import Element, ImageElement +from sycamore.llms.prompts.prompts import RenderedPrompt +from sycamore.tests.config import TEST_DIR +from sycamore.llms import LLM +from sycamore.transforms.summarize_images import LLMImageSummarizer, SummarizeImages + + +def image_element() -> ImageElement: + with open(TEST_DIR / "resources/data/imgs/sample-detr-image.png", "rb") as f: + return ImageElement(binary_representation=f.read(), image_format="png") + + +class MockLLM(LLM): + def __init__(self): + pass + + def is_chat_mode(self): + return True + + def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict[str, Any]] = None) -> str: + promptstr = "\n".join(m.content for m in prompt.messages) + return json.dumps({"summary": 
promptstr}) + + +class TestSummarizeImages: + @staticmethod + def doc(): + return Document( + elements=[ + image_element(), + Element(type="Text", text_representation="text"), + image_element(), + Element(type="Section-header", text_representation="section-header"), + image_element(), + Element(type="Caption", text_representation="caption"), + image_element(), + ] + ) + + def test_summarize_images(self, mocker): + d = self.doc() + sum_images = LLMImageSummarizer(llm=MockLLM()) + si_transform = SummarizeImages(None, summarizer=sum_images) + out = si_transform._local_process([d])[0] + + assert "The text preceding the image is: " not in out.elements[0].properties["summary"]["summary"] + assert "The text following the image is: text" in out.elements[0].properties["summary"]["summary"] + + assert "The text preceding the image is: text" in out.elements[2].properties["summary"]["summary"] + assert "The text following the image is: " not in out.elements[2].properties["summary"]["summary"] + + assert "The text preceding the image is: section-header" in out.elements[4].properties["summary"]["summary"] + assert "The text following the image is: caption" in out.elements[4].properties["summary"]["summary"] + + assert "The text preceding the image is: caption" in out.elements[6].properties["summary"]["summary"] + assert "The text following the image is: " not in out.elements[6].properties["summary"]["summary"] diff --git a/lib/sycamore/sycamore/transforms/summarize_images.py b/lib/sycamore/sycamore/transforms/summarize_images.py index 68a2e775a..7be07fb40 100644 --- a/lib/sycamore/sycamore/transforms/summarize_images.py +++ b/lib/sycamore/sycamore/transforms/summarize_images.py @@ -1,10 +1,10 @@ from typing import Optional -import textwrap -from sycamore.data import Document, ImageElement, Element +from sycamore.data import Document, Element from sycamore.llms.openai import LLM, OpenAI, OpenAIClientWrapper, OpenAIModels -from sycamore.llms.prompts.prompts import SycamorePrompt, 
RenderedPrompt, RenderedMessage +from sycamore.llms.prompts.default_prompts import SummarizeImagesJinjaPrompt +from sycamore.llms.prompts.prompts import SycamorePrompt from sycamore.plan_nodes import Node from sycamore.transforms.base import CompositeTransform from sycamore.transforms.base_llm import LLMMapElements @@ -12,47 +12,6 @@ from sycamore.utils.extract_json import extract_json -class SummarizeImagesPrompt(SycamorePrompt): - """A prompt for summarizing image elements. If given a non-image element - or an image element without image data, will render an empty prompt (which - is skipped by LLMMapElements). - - Args: - user: Base user prompt. Defaults to LLMImageSummarizer.DEFAULT_PROMPT - include_context: Whether to include the text of the elements before - and after the image in the prompt. Only takes Section-headers, - Captions, and Text before the image and only Captions and Text - after the image. - """ - - def __init__(self, user: Optional[str] = None, include_context: bool = True): - self.include_context = include_context - self.user = user or textwrap.dedent(" " * 12 + LLMImageSummarizer.DEFAULT_PROMPT) - self.preceding = "\nThe text preceding the image is {preceding_context}" - self.following = "\nThe text following the image is {following_context}" - - def render_element(self, elt: Element, doc: Document) -> RenderedPrompt: - if not isinstance(elt, ImageElement): - return RenderedPrompt(messages=[]) - im = elt.as_image() - if im is None: - return RenderedPrompt(messages=[]) - text = self.user - if self.include_context: - for i, e in enumerate(doc.elements): - if e.element_index == elt.element_index: - if i > 0: - pe = doc.elements[i - 1] - if pe.type in {"Section-header", "Caption", "Text"}: - text += self.preceding.format(preceding_context=pe.text_representation) - if i < len(doc.elements) - 1: - fe = doc.elements[i + 1] - if fe.type in {"Caption", "Text"}: - text += self.following.format(following_context=fe.text_representation) - - return 
RenderedPrompt(messages=[RenderedMessage(role="user", content=text, images=[im])]) - - def parse_summary_json(e: Element) -> Element: if "summary" in e.properties and isinstance(e.properties["summary"], str): e.properties["summary"] = extract_json(e.properties["summary"]) @@ -171,7 +130,10 @@ class SummarizeImages(CompositeTransform): def __init__(self, child: Node, summarizer=OpenAIImageSummarizer(), **resource_args): super().__init__(child, [], **resource_args) - prompt = SummarizeImagesPrompt(user=summarizer.prompt, include_context=summarizer.include_context) + prompt: SycamorePrompt = SummarizeImagesJinjaPrompt + if summarizer.prompt != LLMImageSummarizer.DEFAULT_PROMPT: + prompt = prompt.set(user=summarizer.prompt) + prompt = prompt.set(include_context=summarizer.include_context) llm_map = LLMMapElements( child, prompt, output_field="summary", llm=summarizer.llm, filter=lambda e: e.type == "Image" ) diff --git a/lib/sycamore/sycamore/utils/extract_json.py b/lib/sycamore/sycamore/utils/extract_json.py index a2a45de07..a3c21c8df 100644 --- a/lib/sycamore/sycamore/utils/extract_json.py +++ b/lib/sycamore/sycamore/utils/extract_json.py @@ -10,6 +10,13 @@ def extract_json(payload: str) -> Any: try: return json.loads(payload) except (ValueError, TypeError, JSONDecodeError) as exc: + # Sometimes the LLM makes up an escape code. In that case, + # replace the escape char with its representation (e.g. \\x07) + # and recurse. + if isinstance(exc, JSONDecodeError) and "Invalid \\escape" in exc.msg: + c = payload[exc.pos] + payload = payload[: exc.pos] + repr(c)[1:-1] + payload[exc.pos + 1 :] + return extract_json(payload) # It is possible that the LLM response includes a code block with JSON data. # Pull the JSON content out from it. 
pattern = r"```json([\s\S]*?)```" diff --git a/lib/sycamore/sycamore/utils/pdf_utils.py b/lib/sycamore/sycamore/utils/pdf_utils.py index 92331b578..8d1bbb692 100644 --- a/lib/sycamore/sycamore/utils/pdf_utils.py +++ b/lib/sycamore/sycamore/utils/pdf_utils.py @@ -10,7 +10,7 @@ from sycamore import DocSet from sycamore.functions.document import DrawBoxes, split_and_convert_to_image from sycamore.utils.image_utils import show_images, crop_to_bbox -from sycamore.data import Document, Element +from sycamore.data import Document, Element, ImageElement import json logger = logging.getLogger(__name__) @@ -184,6 +184,8 @@ def promote_title(elements: list[Element], title_candidate_elements=["Section-he def get_element_image(element: Element, document: Document) -> Image.Image: + if isinstance(element, ImageElement) and (im := element.as_image()) is not None: + return im assert document.type == "pdf", "Cannot get picture of element from non-pdf" assert document.binary_representation is not None, "Cannot get image since there is not binary representation" assert element.bbox is not None, "Cannot get picture of element if it has no BBox" From ad5de1badb381ee7572c00f04b5bd689e2a5890f Mon Sep 17 00:00:00 2001 From: Mark Lindblad Date: Mon, 10 Feb 2025 15:18:14 -0800 Subject: [PATCH 08/11] Rename async list endpoints to "action" from "path" (#1170) --- lib/aryn-sdk/aryn_sdk/partition/partition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/aryn-sdk/aryn_sdk/partition/partition.py b/lib/aryn-sdk/aryn_sdk/partition/partition.py index 27ffcaa58..6fb7724c2 100644 --- a/lib/aryn-sdk/aryn_sdk/partition/partition.py +++ b/lib/aryn-sdk/aryn_sdk/partition/partition.py @@ -533,7 +533,7 @@ def partition_file_async_list( result = response.json()["tasks"] for v in result.values(): - del v["path"] + v.pop("action", None) return result From aa0d7a36c0b1972b0c54c05cdb499970cc01a149 Mon Sep 17 00:00:00 2001 From: Henry Lindeman Date: Mon, 10 Feb 2025 15:51:21 -0800 
Subject: [PATCH 09/11] [llm unify 5c/n] jinjify extract properties (#1169) * add jinja prompts and convert extract entity to use it Signed-off-by: Henry Lindeman * delete commented out / dead code Signed-off-by: Henry Lindeman * JinjaPrompt docstring Signed-off-by: Henry Lindeman * add comments bc otherwise this is very dense Signed-off-by: Henry Lindeman * add norender() directive to jinja prompts when they shouldn't render Signed-off-by: Henry Lindeman * change FANCY_BATCHED_LIST to BATCHED_LIST_WITH_METADATA Signed-off-by: Henry Lindeman * pr comments Signed-off-by: Henry Lindeman * branch switch Signed-off-by: Henry Lindeman * move summarizeImages to jinja Signed-off-by: Henry Lindeman * delete old bespoke prompt class Signed-off-by: Henry Lindeman * make it actually use the jinja prompt (and fix the jinja) Signed-off-by: Henry Lindeman * mypy: Signed-off-by: Henry Lindeman * add summarize_images unittest (mostly a prompt ut) Signed-off-by: Henry Lindeman * adjust prop extraction its to count the correct number of lineage metadatadocs Signed-off-by: Henry Lindeman * extract properties -> jinja Signed-off-by: Henry Lindeman * set prompt_formatter when supplied a non-default Signed-off-by: Henry Lindeman * fix mypy. why is this fix a fix? I don't understand python. 
Signed-off-by: Henry Lindeman * drop a print statement Signed-off-by: Henry Lindeman --------- Signed-off-by: Henry Lindeman --- .../sycamore/llms/prompts/default_prompts.py | 43 ++++++++ .../sycamore/llms/prompts/jinja_fragments.py | 31 +++++- lib/sycamore/sycamore/llms/prompts/prompts.py | 22 +++- .../transforms/test_data_extraction.py | 4 +- .../sycamore/transforms/extract_entity.py | 6 +- .../sycamore/transforms/extract_schema.py | 104 ++---------------- 6 files changed, 111 insertions(+), 99 deletions(-) diff --git a/lib/sycamore/sycamore/llms/prompts/default_prompts.py b/lib/sycamore/sycamore/llms/prompts/default_prompts.py index 6dd619278..e31221e56 100644 --- a/lib/sycamore/sycamore/llms/prompts/default_prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/default_prompts.py @@ -10,6 +10,12 @@ JinjaPrompt, JinjaElementPrompt, ) +from sycamore.llms.prompts.jinja_fragments import ( + J_DYNAMIC_DOC_TEXT, + J_FORMAT_SCHEMA_MACRO, + J_SET_ENTITY, + J_SET_SCHEMA, +) logger = logging.getLogger(__name__) @@ -357,6 +363,43 @@ class {entity} from the document. The class only has properties {properties}. Us ) +PropertiesZeroShotJinjaPrompt = JinjaPrompt( + system="You are a helpful property extractor. You only return JSON.", + user=J_SET_SCHEMA + + J_SET_ENTITY + + textwrap.dedent( + """\ + You are given some text of a document. Extract JSON representing one entity of + class {{ entity }} from the document. The class only has properties {{ schema }}. Using + this context, FIND, FORMAT, and RETURN the JSON representing one {{ entity }}. + Only return JSON as part of your answer. If no entity is in the text, return "None". 
+ + Document: + """ + ) + + J_DYNAMIC_DOC_TEXT, +) + +PropertiesFromSchemaJinjaPrompt = JinjaPrompt( + system="You are given text contents from a document.", + user=( + J_FORMAT_SCHEMA_MACRO + + """\ +Extract values for the following fields: +{{ format_schema(schema) }} + +Document text:""" + + J_DYNAMIC_DOC_TEXT + + """ + +Don't return extra information. +If you cannot find a value for a requested property, use the provided default or the value 'None'. +Return your answers as a valid json dictionary that will be parsed in python. +""" + ), +) + + class EntityExtractorMessagesPrompt(SimplePrompt): def __init__(self, question: str, field: str, format: Optional[str], discrete: bool = False): super().__init__() diff --git a/lib/sycamore/sycamore/llms/prompts/jinja_fragments.py b/lib/sycamore/sycamore/llms/prompts/jinja_fragments.py index 29d87e11c..857957fba 100644 --- a/lib/sycamore/sycamore/llms/prompts/jinja_fragments.py +++ b/lib/sycamore/sycamore/llms/prompts/jinja_fragments.py @@ -1,7 +1,11 @@ -J_ELEMENT_LIST = """\ +J_ELEMENT_LIST_CAPPED = """\ {% for elt in doc.elements[:num_elements] %}ELEMENT {{ elt.element_index }}: {{ elt.field_to_value(field) }} {% endfor %}""" +J_ELEMENT_LIST_UNCAPPED = """\ +{% for elt in doc.elements %}ELEMENT {{ elt.element_index }}: {{ elt.field_to_value(field) }} +{% endfor %}""" + # Directive to not render the template if the iteration var has # surpassed the number of batches. 
@@ -29,3 +33,28 @@
 Text: {{ elt.field_to_value(field) }}
 {% endfor -%}"""
 )
+
+J_SET_SCHEMA = """{%- if schema is not defined %}{% set schema = doc.properties["_schema"] %}{% endif -%}\n"""
+J_SET_ENTITY = (
+    """{%- if entity is not defined %}{% set entity = doc.properties.get("_schema_class", "entity") %}{% endif -%}\n"""
+)
+
+J_DYNAMIC_DOC_TEXT = (
+    """{%- set field = "text_representation" -%}
+{% if doc.text_representation is not none %}{{ doc.text_representation }}
+{% elif prompt_formatter is defined %}{{ prompt_formatter(doc.elements) }}
+{% elif num_elements is defined %}"""
+    + J_ELEMENT_LIST_CAPPED
+    + "{% else %}"
+    + J_ELEMENT_LIST_UNCAPPED
+    + "{% endif %}"
+)
+
+J_FORMAT_SCHEMA_MACRO = """{% macro format_schema(schema) -%}
+{% for field in schema.fields %}
+{{ loop.index }} {{ field.name }}: {{ field.field_type }}: default={{ field.default }}
+{% if field.description %} Description: {{ field.description }}{% endif %}
+{% if field.examples %} Example values: {{ field.examples }}{% endif %}
+{%- endfor -%}
+{%- endmacro %}
+"""
diff --git a/lib/sycamore/sycamore/llms/prompts/prompts.py b/lib/sycamore/sycamore/llms/prompts/prompts.py
index 7995e6775..d18b35128 100644
--- a/lib/sycamore/sycamore/llms/prompts/prompts.py
+++ b/lib/sycamore/sycamore/llms/prompts/prompts.py
@@ -13,6 +13,9 @@
 from jinja2 import Template
 
 
+ResponseFormat = Union[None, dict[str, Any], type[pydantic.BaseModel]]
+
+
 @dataclass
 class RenderedMessage:
     """Represents a message per the LLM messages interface - i.e. 
a role and a content string @@ -39,7 +42,7 @@ class RenderedPrompt: """ messages: list[RenderedMessage] - response_format: Union[None, dict[str, Any], type[pydantic.BaseModel]] = None + response_format: ResponseFormat = None def token_count(self, tokenizer: Tokenizer) -> int: total_text = " ".join(m.content for m in self.messages) @@ -535,13 +538,21 @@ class JinjaPrompt(SycamorePrompt): """ - def __init__(self, *, system: Optional[str] = None, user: Union[None, str, list[str]] = None, **kwargs): + def __init__( + self, + *, + system: Optional[str] = None, + user: Union[None, str, list[str]] = None, + response_format: ResponseFormat = None, + **kwargs, + ): from jinja2.sandbox import SandboxedEnvironment from jinja2 import Template super().__init__() self.system = system self.user = user + self.response_format = response_format self.kwargs = kwargs self._env = SandboxedEnvironment(extensions=["jinja2.ext.loopcontrols"]) self._sys_template: Optional[Template] = None @@ -577,6 +588,8 @@ def render_document(self, doc: Document) -> RenderedPrompt: render_args["doc"] = doc rendered = render_templates(self._sys_template, self._user_templates, render_args) + if self.response_format is not None: + rendered.response_format = self.response_format return rendered @@ -587,6 +600,7 @@ def __init__( system: Optional[str] = None, user: Union[None, str, list[str]] = None, include_image: bool = False, + response_format: ResponseFormat = None, **kwargs, ): from jinja2.sandbox import SandboxedEnvironment @@ -596,6 +610,7 @@ def __init__( self.system = system self.user = user self.include_image = include_image + self.response_format = response_format self.kwargs = kwargs self._env = SandboxedEnvironment(extensions=["jinja2.ext.loopcontrols"]) self._sys_template: Optional[Template] = None @@ -629,5 +644,6 @@ def render_element(self, elt: Element, doc: Document) -> RenderedPrompt: from sycamore.utils.pdf_utils import get_element_image result.messages[-1].images = [get_element_image(elt, 
doc)] - print(result) + if self.response_format is not None: + result.response_format = self.response_format return result diff --git a/lib/sycamore/sycamore/tests/integration/transforms/test_data_extraction.py b/lib/sycamore/sycamore/tests/integration/transforms/test_data_extraction.py index 28178135c..0ae0111d9 100644 --- a/lib/sycamore/sycamore/tests/integration/transforms/test_data_extraction.py +++ b/lib/sycamore/sycamore/tests/integration/transforms/test_data_extraction.py @@ -68,7 +68,9 @@ def test_extract_properties_from_schema(llm): default="null", ), SchemaField(name="age", field_type="int", default=999), - SchemaField(name="date", field_type="str", description="Any date in the doc in YYYY-MM-DD format"), + SchemaField( + name="date", field_type="str", description="Any date in the doc, extracted in YYYY-MM-DD format" + ), SchemaField( name="from_location", field_type="str", diff --git a/lib/sycamore/sycamore/transforms/extract_entity.py b/lib/sycamore/sycamore/transforms/extract_entity.py index 9d7ce8397..d685bef9a 100644 --- a/lib/sycamore/sycamore/transforms/extract_entity.py +++ b/lib/sycamore/sycamore/transforms/extract_entity.py @@ -164,7 +164,9 @@ def _get_prompt(self) -> SycamorePrompt: if self._prompt is not None: if isinstance(self._prompt, str): - return JinjaPrompt(system=None, user=self._prompt + "\n" + j_elements, **common_params) + return JinjaPrompt( + system=None, user=self._prompt + "\n" + j_elements, response_format=None, **common_params + ) else: system = None if len(self._prompt) > 0 and self._prompt[0]["role"] == "system": @@ -172,7 +174,7 @@ def _get_prompt(self) -> SycamorePrompt: user = [p["content"] for p in self._prompt[1:]] + [j_elements] else: user = [p["content"] for p in self._prompt] + [j_elements] - return JinjaPrompt(system=system, user=user, **common_params) + return JinjaPrompt(system=system, user=user, response_format=None, **common_params) elif self._prompt_template is not None: return 
EntityExtractorFewShotJinjaPrompt.set(examples=self._prompt_template, **common_params) else: diff --git a/lib/sycamore/sycamore/transforms/extract_schema.py b/lib/sycamore/sycamore/transforms/extract_schema.py index bca5dc1bc..fec4edf47 100644 --- a/lib/sycamore/sycamore/transforms/extract_schema.py +++ b/lib/sycamore/sycamore/transforms/extract_schema.py @@ -1,19 +1,16 @@ from abc import ABC, abstractmethod from typing import Callable, Any, Optional, Union import json -import textwrap -import copy from sycamore.data import Element, Document -from sycamore.connectors.common import flatten_data -from sycamore.llms.prompts.prompts import ElementListPrompt from sycamore.schema import Schema from sycamore.llms import LLM from sycamore.llms.prompts.default_prompts import ( _SchemaZeroShotGuidancePrompt, + PropertiesZeroShotJinjaPrompt, + PropertiesFromSchemaJinjaPrompt, ) -from sycamore.llms.prompts import SycamorePrompt, RenderedPrompt -from sycamore.llms.prompts.prompts import _build_format_str +from sycamore.llms.prompts import SycamorePrompt from sycamore.plan_nodes import Node from sycamore.transforms.map import Map from sycamore.transforms.base_llm import LLMMap @@ -51,78 +48,6 @@ def as_llm_map(self, child: Optional[Node], **kwargs) -> Node: pass -class PropertyExtractionFromSchemaPrompt(ElementListPrompt): - default_system = "You are given text contents from a document." - default_user = textwrap.dedent( - """\ - Extract values for the following fields: - {schema} - - Document text: - {doc_text} - - Don't return extra information. - If you cannot find a value for a requested property, use the provided default or the value 'None'. - Return your answers as a valid json dictionary that will be parsed in python. 
- """ - ) - - def __init__(self, schema: Schema): - super().__init__(system=self.default_system, user=self.default_user) - self.schema = schema - self.kwargs["schema"] = self._format_schema(schema) - - @staticmethod - def _format_schema(schema: Schema) -> str: - text = "" - for i, field in enumerate(schema.fields): - text += f"{i} {field.name}: type={field.field_type}: default={field.default}\n" - if field.description is not None: - text += f" {field.description}\n" - if field.examples is not None: - text += f" Examples values: {field.examples}\n" - return text - - def set(self, **kwargs) -> SycamorePrompt: - if "schema" in kwargs: - new = copy.deepcopy(self) - new.schema = kwargs["schema"] - kwargs["schema"] = self._format_schema(new.schema) - return new.set(**kwargs) - return super().set(**kwargs) - - def render_document(self, doc: Document) -> RenderedPrompt: - rp = super().render_document(doc) - rp.response_format = self.schema.model_dump() - return rp - - -class PropertyExtractionFromDictPrompt(ElementListPrompt): - def __init__(self, schema: Optional[dict] = None, **kwargs): - super().__init__(**kwargs) - self.schema = schema - - def render_document(self, doc: Document) -> RenderedPrompt: - format_args = copy.deepcopy(self.kwargs) - format_args["doc_text"] = doc.text_representation - if self.schema is None: - schema = doc.properties.get("_schema") - else: - schema = self.schema - format_args["schema"] = schema - if "entity" not in format_args: - format_args["entity"] = doc.properties.get("_schema_class", "entity") - flat_props = flatten_data(doc.properties, prefix="doc_property", separator="_") - format_args.update(flat_props) - format_args["elements"] = self._render_element_list_to_string(doc) - if doc.text_representation is None: - format_args["doc_text"] = format_args["elements"] - - messages = _build_format_str(self.system, self.user, format_args) - result = RenderedPrompt(messages=messages, response_format=schema) - return result - - class 
LLMSchemaExtractor(SchemaExtractor): """ The LLMSchemaExtractor uses the specified LLM object to extract a schema. @@ -287,23 +212,18 @@ def cast_types(self, fields: dict) -> dict: def as_llm_map(self, child: Optional[Node], **kwargs) -> Node: prompt: SycamorePrompt # mypy grr if isinstance(self._schema, Schema): - prompt = PropertyExtractionFromSchemaPrompt(self._schema) + prompt = PropertiesFromSchemaJinjaPrompt + prompt = prompt.set(schema=self._schema, response_format=self._schema.model_dump()) else: - prompt = PropertyExtractionFromDictPrompt( - schema=self._schema, - system="You are a helpful property extractor. You only return JSON.", - user=textwrap.dedent( - """\ - You are given a few text elements of a document. Extract JSON representing one entity of - class {entity} from the document. The class only has properties {schema}. Using - this context, FIND, FORMAT, and RETURN the JSON representing one {entity}. - Only return JSON as part of your answer. If no entity is in the text, return "None". 
- {doc_text} - """ - ), - ) + prompt = PropertiesZeroShotJinjaPrompt + if self._schema is not None: + prompt = prompt.set(schema=self._schema) + if self._schema_name is not None: prompt = prompt.set(entity=self._schema_name) + prompt = prompt.set(num_elements=self._num_of_elements) + if self._prompt_formatter is not element_list_formatter: + prompt = prompt.set(prompt_formatter=self._prompt_formatter) def parse_json_and_cast(d: Document) -> Document: entity_name = self._schema_name or "_entity" From 43758ca85b1120e28df73d56242051babadb01cd Mon Sep 17 00:00:00 2001 From: Karan Sampath <176953591+karanataryn@users.noreply.github.com> Date: Mon, 10 Feb 2025 17:45:30 -0800 Subject: [PATCH 10/11] Bump Beautiful Soup (#1167) * bump * check lint * fix action * change action * update * assert Tag * fix table and partition * revert workflow changes --- lib/sycamore/poetry.lock | 13 +++++++------ lib/sycamore/pyproject.toml | 2 +- lib/sycamore/sycamore/data/table.py | 14 +++++++++----- lib/sycamore/sycamore/transforms/partition.py | 3 ++- poetry.lock | 13 +++++++------ 5 files changed, 26 insertions(+), 19 deletions(-) diff --git a/lib/sycamore/poetry.lock b/lib/sycamore/poetry.lock index 11b4f1b70..f8856ebca 100644 --- a/lib/sycamore/poetry.lock +++ b/lib/sycamore/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. 
[[package]] name = "aiohappyeyeballs" @@ -525,17 +525,18 @@ files = [ [[package]] name = "beautifulsoup4" -version = "4.12.3" +version = "4.13.3" description = "Screen-scraping library" optional = false -python-versions = ">=3.6.0" +python-versions = ">=3.7.0" files = [ - {file = "beautifulsoup4-4.12.3-py3-none-any.whl", hash = "sha256:b80878c9f40111313e55da8ba20bdba06d8fa3969fc68304167741bbf9e082ed"}, - {file = "beautifulsoup4-4.12.3.tar.gz", hash = "sha256:74e3d1928edc070d21748185c46e3fb33490f22f52a3addee9aee0f4f7781051"}, + {file = "beautifulsoup4-4.13.3-py3-none-any.whl", hash = "sha256:99045d7d3f08f91f0d656bc9b7efbae189426cd913d830294a15eefa0ea4df16"}, + {file = "beautifulsoup4-4.13.3.tar.gz", hash = "sha256:1bd32405dacc920b42b83ba01644747ed77456a65760e285fbc47633ceddaf8b"}, ] [package.dependencies] soupsieve = ">1.2" +typing-extensions = ">=4.0.0" [package.extras] cchardet = ["cchardet"] @@ -9976,4 +9977,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "ec9791de75120499fa5627b69728ed2bda2d84adc76e4d92e216ae381b9356ec" +content-hash = "9b5e9ff0312b48fa4261aaeee39a557e20496228a07d519049ea46c5af827180" diff --git a/lib/sycamore/pyproject.toml b/lib/sycamore/pyproject.toml index 55601c5ee..9088c5407 100644 --- a/lib/sycamore/pyproject.toml +++ b/lib/sycamore/pyproject.toml @@ -22,7 +22,7 @@ ray = { extras = ["default"], version = "^2.36.0" } pyarrow = "^14.0.2" numpy = "<2.0.0" openai = "^1.60.2" -beautifulsoup4 = "^4.12.2" +beautifulsoup4 = "^4.13.1" amazon-textract-textractor = "^1.3.2" boto3 = "^1.28.70" boto3-stubs = {extras = ["essential"], version = "^1.35.12"} diff --git a/lib/sycamore/sycamore/data/table.py b/lib/sycamore/sycamore/data/table.py index 2ec195dd9..4ed0b3d7a 100644 --- a/lib/sycamore/sycamore/data/table.py +++ b/lib/sycamore/sycamore/data/table.py @@ -202,7 +202,7 @@ def from_html(cls, html_str: Optional[str] = None, html_tag: Optional[Tag] = Non if (html_str is not None and 
html_tag is not None) or (html_str is None and html_tag is None): raise ValueError("Exactly one of html_str and html_tag must be specified.") - + root: Union[Tag, BeautifulSoup] if html_str is not None: html_str = html_str.strip() if not html_str.startswith(""): @@ -222,9 +222,10 @@ def from_html(cls, html_str: Optional[str] = None, html_tag: Optional[Tag] = Non cells = [] caption = None - + assert isinstance(root, Tag), "Expected root to be a Tag" # Traverse the tree of elements in a pre-order traversal. for tag in root.find_all(recursive=True): + assert isinstance(tag, Tag), "Expected root to be a Tag" if tag.name == "tr": cur_row += 1 # TODO: Should this be based on rowspan? cur_col = 0 @@ -234,9 +235,12 @@ def from_html(cls, html_str: Optional[str] = None, html_tag: Optional[Tag] = Non # they have a thead. if cur_row < 0: cur_row += 1 - - rowspan = int(tag.attrs.get("rowspan", "1")) - colspan = int(tag.attrs.get("colspan", "1")) + if rowspan_str := tag.attrs.get("rowspan", "1"): + assert isinstance(rowspan_str, str) # For mypy + rowspan = int(rowspan_str) + if colspan_str := tag.attrs.get("colspan", "1"): + assert isinstance(colspan_str, str) # For mypy + colspan = int(colspan_str) content = tag.get_text() diff --git a/lib/sycamore/sycamore/transforms/partition.py b/lib/sycamore/sycamore/transforms/partition.py index 0487cb881..c39a5c1b2 100644 --- a/lib/sycamore/sycamore/transforms/partition.py +++ b/lib/sycamore/sycamore/transforms/partition.py @@ -2,7 +2,7 @@ import io from typing import Any, Literal, Optional, Union -from bs4 import BeautifulSoup +from bs4 import BeautifulSoup, Tag from sycamore.functions import TextOverlapChunker, Chunker from sycamore.functions import CharacterTokenizer, Tokenizer @@ -312,6 +312,7 @@ def partition(self, document: Document) -> Document: if self._extract_tables: for table in soup.find_all("table"): # ignore nested tables + assert isinstance(table, Tag) if len(table.find_all("table")) > 0: continue diff --git a/poetry.lock 
b/poetry.lock index b88aa8c29..4d96a2578 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiobotocore" @@ -597,17 +597,18 @@ files = [ [[package]] name = "beautifulsoup4" -version = "4.12.3" +version = "4.13.3" description = "Screen-scraping library" optional = false -python-versions = ">=3.6.0" +python-versions = ">=3.7.0" files = [ - {file = "beautifulsoup4-4.12.3-py3-none-any.whl", hash = "sha256:b80878c9f40111313e55da8ba20bdba06d8fa3969fc68304167741bbf9e082ed"}, - {file = "beautifulsoup4-4.12.3.tar.gz", hash = "sha256:74e3d1928edc070d21748185c46e3fb33490f22f52a3addee9aee0f4f7781051"}, + {file = "beautifulsoup4-4.13.3-py3-none-any.whl", hash = "sha256:99045d7d3f08f91f0d656bc9b7efbae189426cd913d830294a15eefa0ea4df16"}, + {file = "beautifulsoup4-4.13.3.tar.gz", hash = "sha256:1bd32405dacc920b42b83ba01644747ed77456a65760e285fbc47633ceddaf8b"}, ] [package.dependencies] soupsieve = ">1.2" +typing-extensions = ">=4.0.0" [package.extras] cchardet = ["cchardet"] @@ -8738,7 +8739,7 @@ amazon-textract-textractor = "^1.3.2" anthropic = {version = "^0.42.0", optional = true} apted = {version = "^1.0.3", optional = true} async-timeout = ">4.0.0" -beautifulsoup4 = "^4.12.2" +beautifulsoup4 = "^4.13.1" boto3 = "^1.28.70" boto3-stubs = {version = "^1.35.12", extras = ["essential"]} datasets = {version = "^2.16.1", optional = true} From 87f230ebd186277318ed3b1d511fa8699bc0fa84 Mon Sep 17 00:00:00 2001 From: Austin Lee Date: Mon, 10 Feb 2025 18:27:53 -0800 Subject: [PATCH 11/11] Serialize query strings to avoid Ray Dataset column imputation (#1171) * Serialize query strings to avoid Ray Dataset column imputation * Fix non reconstruct case, ensure all integ tests pass --- .../connectors/opensearch/opensearch_reader.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 
deletions(-) diff --git a/lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py b/lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py index b1ba3325c..463d02a18 100644 --- a/lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py +++ b/lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py @@ -1,3 +1,4 @@ +import json import logging from copy import deepcopy @@ -288,7 +289,7 @@ def __init__( logger.info(f"OpenSearchReader using PIT: {self.use_pit}") @timetrace("OpenSearchReader") - def _to_parent_doc(self, slice_query: dict[str, Any]) -> List[dict[str, Any]]: + def _to_parent_doc(self, doc: dict[str, Any]) -> List[dict[str, Any]]: """ Get all parent documents from a given slice. """ @@ -304,6 +305,7 @@ def _to_parent_doc(self, slice_query: dict[str, Any]) -> List[dict[str, Any]]: raise ValueError("Target is not present\n" f"Parameters: {self._query_params}\n") os_client = client._client + slice_query = json.loads(doc["doc"]) assert ( get_doc_count_for_slice(os_client, slice_query) < 10000 @@ -350,7 +352,7 @@ def _to_parent_doc(self, slice_query: dict[str, Any]) -> List[dict[str, Any]]: logger.info(f"Read {len(results)} documents from {self._query_params.index_name}") except Exception as e: - raise ValueError(f"Error reading from target: {e}") + raise ValueError(f"Error reading from target: {e}, query: {slice_query}") finally: if client is not None: client.close() @@ -359,7 +361,7 @@ def _to_parent_doc(self, slice_query: dict[str, Any]) -> List[dict[str, Any]]: return ret @timetrace("OpenSearchReader") - def _to_doc(self, slice_query: dict[str, Any]) -> List[dict[str, Any]]: + def _to_doc(self, doc: dict[str, Any]) -> List[dict[str, Any]]: """ Get all documents from a given slice. 
""" @@ -377,6 +379,7 @@ def _to_doc(self, slice_query: dict[str, Any]) -> List[dict[str, Any]]: raise ValueError("Target is not present\n" f"Parameters: {self._query_params}\n") os_client = client._client + slice_query = json.loads(doc["doc"]) slice_count = get_doc_count_for_slice(os_client, slice_query) assert slice_count <= 10000, f"Slice count ({slice_count}) should return <= 10,000 documents" @@ -556,7 +559,8 @@ def _execute_pit(self, **kwargs) -> "Dataset": } if "query" in query: _query["query"] = query["query"] - docs.append(_query) + + docs.append({"doc": json.dumps(_query)}) logger.debug(f"Added slice {i} to the query {_query}") except Exception as e: raise ValueError(f"Error reading from target: {e}")