diff --git a/lib/sycamore/sycamore/llms/prompts/default_prompts.py b/lib/sycamore/sycamore/llms/prompts/default_prompts.py index 6dd619278..e31221e56 100644 --- a/lib/sycamore/sycamore/llms/prompts/default_prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/default_prompts.py @@ -10,6 +10,12 @@ JinjaPrompt, JinjaElementPrompt, ) +from sycamore.llms.prompts.jinja_fragments import ( + J_DYNAMIC_DOC_TEXT, + J_FORMAT_SCHEMA_MACRO, + J_SET_ENTITY, + J_SET_SCHEMA, +) logger = logging.getLogger(__name__) @@ -357,6 +363,43 @@ class {entity} from the document. The class only has properties {properties}. Us ) +PropertiesZeroShotJinjaPrompt = JinjaPrompt( + system="You are a helpful property extractor. You only return JSON.", + user=J_SET_SCHEMA + + J_SET_ENTITY + + textwrap.dedent( + """\ + You are given some text of a document. Extract JSON representing one entity of + class {{ entity }} from the document. The class only has properties {{ schema }}. Using + this context, FIND, FORMAT, and RETURN the JSON representing one {{ entity }}. + Only return JSON as part of your answer. If no entity is in the text, return "None". + + Document: + """ + ) + + J_DYNAMIC_DOC_TEXT, +) + +PropertiesFromSchemaJinjaPrompt = JinjaPrompt( + system="You are given text contents from a document.", + user=( + J_FORMAT_SCHEMA_MACRO + + """\ +Extract values for the following fields: +{{ format_schema(schema) }} + +Document text:""" + + J_DYNAMIC_DOC_TEXT + + """ + +Don't return extra information. +If you cannot find a value for a requested property, use the provided default or the value 'None'. +Return your answers as a valid json dictionary that will be parsed in python. 
+""" + ), +) + + class EntityExtractorMessagesPrompt(SimplePrompt): def __init__(self, question: str, field: str, format: Optional[str], discrete: bool = False): super().__init__() diff --git a/lib/sycamore/sycamore/llms/prompts/jinja_fragments.py b/lib/sycamore/sycamore/llms/prompts/jinja_fragments.py index 29d87e11c..857957fba 100644 --- a/lib/sycamore/sycamore/llms/prompts/jinja_fragments.py +++ b/lib/sycamore/sycamore/llms/prompts/jinja_fragments.py @@ -1,7 +1,11 @@ -J_ELEMENT_LIST = """\ +J_ELEMENT_LIST_CAPPED = """\ {% for elt in doc.elements[:num_elements] %}ELEMENT {{ elt.element_index }}: {{ elt.field_to_value(field) }} {% endfor %}""" +J_ELEMENT_LIST_UNCAPPED = """\ +{% for elt in doc.elements %}ELEMENT {{ elt.element_index }}: {{ elt.field_to_value(field) }} +{% endfor %}""" + # Directive to not render the template if the iteration var has # surpassed the number of batches. @@ -29,3 +33,28 @@ Text: {{ elt.field_to_value(field) }} {% endfor -%}""" ) + +J_SET_SCHEMA = """{%- if schema is not defined %}{% set schema = doc.properties["_schema"] %}{% endif -%}\n""" +J_SET_ENTITY = ( + """{%- if entity is not defined %}{% set entity = doc.properties.get("_schema_class", "entity") %}{% endif -%}\n""" +) + +J_DYNAMIC_DOC_TEXT = ( + """{%- set field = "text_representation" -%} +{% if doc.text_representation is not none %}{{ doc.text_representation }} +{% elif prompt_formatter is defined %}{{ prompt_formatter(doc.elements) }} +{% elif num_elements is defined %}""" + + J_ELEMENT_LIST_CAPPED + + "{% else %}" + + J_ELEMENT_LIST_UNCAPPED + + "{% endif %}" +) + +J_FORMAT_SCHEMA_MACRO = """{% macro format_schema(schema) -%} +{% for field in schema.fields %} +{{ loop.index }} {{ field.name }}: {{ field.field_type }}: default={{ field.default }} +{% if field.description %} Description: {{ field.description }}{% endif %} +{% if field.examples %} Example values: {{ field.examples }}{% endif %} +{%- endfor -%} +{%- endmacro %} +""" diff --git 
a/lib/sycamore/sycamore/llms/prompts/prompts.py b/lib/sycamore/sycamore/llms/prompts/prompts.py index 7995e6775..d18b35128 100644 --- a/lib/sycamore/sycamore/llms/prompts/prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/prompts.py @@ -13,6 +13,9 @@ from jinja2 import Template +ResponseFormat = Union[None, dict[str, Any], type[pydantic.BaseModel]] + + @dataclass class RenderedMessage: """Represents a message per the LLM messages interface - i.e. a role and a content string @@ -39,7 +42,7 @@ class RenderedPrompt: """ messages: list[RenderedMessage] - response_format: Union[None, dict[str, Any], type[pydantic.BaseModel]] = None + response_format: ResponseFormat = None def token_count(self, tokenizer: Tokenizer) -> int: total_text = " ".join(m.content for m in self.messages) @@ -535,13 +538,21 @@ class JinjaPrompt(SycamorePrompt): """ - def __init__(self, *, system: Optional[str] = None, user: Union[None, str, list[str]] = None, **kwargs): + def __init__( + self, + *, + system: Optional[str] = None, + user: Union[None, str, list[str]] = None, + response_format: ResponseFormat = None, + **kwargs, + ): from jinja2.sandbox import SandboxedEnvironment from jinja2 import Template super().__init__() self.system = system self.user = user + self.response_format = response_format self.kwargs = kwargs self._env = SandboxedEnvironment(extensions=["jinja2.ext.loopcontrols"]) self._sys_template: Optional[Template] = None @@ -577,6 +588,8 @@ def render_document(self, doc: Document) -> RenderedPrompt: render_args["doc"] = doc rendered = render_templates(self._sys_template, self._user_templates, render_args) + if self.response_format is not None: + rendered.response_format = self.response_format return rendered @@ -587,6 +600,7 @@ def __init__( system: Optional[str] = None, user: Union[None, str, list[str]] = None, include_image: bool = False, + response_format: ResponseFormat = None, **kwargs, ): from jinja2.sandbox import SandboxedEnvironment @@ -596,6 +610,7 @@ def __init__( 
self.system = system self.user = user self.include_image = include_image + self.response_format = response_format self.kwargs = kwargs self._env = SandboxedEnvironment(extensions=["jinja2.ext.loopcontrols"]) self._sys_template: Optional[Template] = None @@ -629,5 +644,6 @@ def render_element(self, elt: Element, doc: Document) -> RenderedPrompt: from sycamore.utils.pdf_utils import get_element_image result.messages[-1].images = [get_element_image(elt, doc)] - print(result) + if self.response_format is not None: + result.response_format = self.response_format return result diff --git a/lib/sycamore/sycamore/tests/integration/transforms/test_data_extraction.py b/lib/sycamore/sycamore/tests/integration/transforms/test_data_extraction.py index 28178135c..0ae0111d9 100644 --- a/lib/sycamore/sycamore/tests/integration/transforms/test_data_extraction.py +++ b/lib/sycamore/sycamore/tests/integration/transforms/test_data_extraction.py @@ -68,7 +68,9 @@ def test_extract_properties_from_schema(llm): default="null", ), SchemaField(name="age", field_type="int", default=999), - SchemaField(name="date", field_type="str", description="Any date in the doc in YYYY-MM-DD format"), + SchemaField( + name="date", field_type="str", description="Any date in the doc, extracted in YYYY-MM-DD format" + ), SchemaField( name="from_location", field_type="str", diff --git a/lib/sycamore/sycamore/transforms/extract_entity.py b/lib/sycamore/sycamore/transforms/extract_entity.py index 9d7ce8397..d685bef9a 100644 --- a/lib/sycamore/sycamore/transforms/extract_entity.py +++ b/lib/sycamore/sycamore/transforms/extract_entity.py @@ -164,7 +164,9 @@ def _get_prompt(self) -> SycamorePrompt: if self._prompt is not None: if isinstance(self._prompt, str): - return JinjaPrompt(system=None, user=self._prompt + "\n" + j_elements, **common_params) + return JinjaPrompt( + system=None, user=self._prompt + "\n" + j_elements, response_format=None, **common_params + ) else: system = None if len(self._prompt) > 0 and 
self._prompt[0]["role"] == "system": @@ -172,7 +174,7 @@ def _get_prompt(self) -> SycamorePrompt: user = [p["content"] for p in self._prompt[1:]] + [j_elements] else: user = [p["content"] for p in self._prompt] + [j_elements] - return JinjaPrompt(system=system, user=user, **common_params) + return JinjaPrompt(system=system, user=user, response_format=None, **common_params) elif self._prompt_template is not None: return EntityExtractorFewShotJinjaPrompt.set(examples=self._prompt_template, **common_params) else: diff --git a/lib/sycamore/sycamore/transforms/extract_schema.py b/lib/sycamore/sycamore/transforms/extract_schema.py index bca5dc1bc..fec4edf47 100644 --- a/lib/sycamore/sycamore/transforms/extract_schema.py +++ b/lib/sycamore/sycamore/transforms/extract_schema.py @@ -1,19 +1,16 @@ from abc import ABC, abstractmethod from typing import Callable, Any, Optional, Union import json -import textwrap -import copy from sycamore.data import Element, Document -from sycamore.connectors.common import flatten_data -from sycamore.llms.prompts.prompts import ElementListPrompt from sycamore.schema import Schema from sycamore.llms import LLM from sycamore.llms.prompts.default_prompts import ( _SchemaZeroShotGuidancePrompt, + PropertiesZeroShotJinjaPrompt, + PropertiesFromSchemaJinjaPrompt, ) -from sycamore.llms.prompts import SycamorePrompt, RenderedPrompt -from sycamore.llms.prompts.prompts import _build_format_str +from sycamore.llms.prompts import SycamorePrompt from sycamore.plan_nodes import Node from sycamore.transforms.map import Map from sycamore.transforms.base_llm import LLMMap @@ -51,78 +48,6 @@ def as_llm_map(self, child: Optional[Node], **kwargs) -> Node: pass -class PropertyExtractionFromSchemaPrompt(ElementListPrompt): - default_system = "You are given text contents from a document." - default_user = textwrap.dedent( - """\ - Extract values for the following fields: - {schema} - - Document text: - {doc_text} - - Don't return extra information. 
- If you cannot find a value for a requested property, use the provided default or the value 'None'. - Return your answers as a valid json dictionary that will be parsed in python. - """ - ) - - def __init__(self, schema: Schema): - super().__init__(system=self.default_system, user=self.default_user) - self.schema = schema - self.kwargs["schema"] = self._format_schema(schema) - - @staticmethod - def _format_schema(schema: Schema) -> str: - text = "" - for i, field in enumerate(schema.fields): - text += f"{i} {field.name}: type={field.field_type}: default={field.default}\n" - if field.description is not None: - text += f" {field.description}\n" - if field.examples is not None: - text += f" Examples values: {field.examples}\n" - return text - - def set(self, **kwargs) -> SycamorePrompt: - if "schema" in kwargs: - new = copy.deepcopy(self) - new.schema = kwargs["schema"] - kwargs["schema"] = self._format_schema(new.schema) - return new.set(**kwargs) - return super().set(**kwargs) - - def render_document(self, doc: Document) -> RenderedPrompt: - rp = super().render_document(doc) - rp.response_format = self.schema.model_dump() - return rp - - -class PropertyExtractionFromDictPrompt(ElementListPrompt): - def __init__(self, schema: Optional[dict] = None, **kwargs): - super().__init__(**kwargs) - self.schema = schema - - def render_document(self, doc: Document) -> RenderedPrompt: - format_args = copy.deepcopy(self.kwargs) - format_args["doc_text"] = doc.text_representation - if self.schema is None: - schema = doc.properties.get("_schema") - else: - schema = self.schema - format_args["schema"] = schema - if "entity" not in format_args: - format_args["entity"] = doc.properties.get("_schema_class", "entity") - flat_props = flatten_data(doc.properties, prefix="doc_property", separator="_") - format_args.update(flat_props) - format_args["elements"] = self._render_element_list_to_string(doc) - if doc.text_representation is None: - format_args["doc_text"] = 
format_args["elements"] - - messages = _build_format_str(self.system, self.user, format_args) - result = RenderedPrompt(messages=messages, response_format=schema) - return result - - class LLMSchemaExtractor(SchemaExtractor): """ The LLMSchemaExtractor uses the specified LLM object to extract a schema. @@ -287,23 +212,18 @@ def cast_types(self, fields: dict) -> dict: def as_llm_map(self, child: Optional[Node], **kwargs) -> Node: prompt: SycamorePrompt # mypy grr if isinstance(self._schema, Schema): - prompt = PropertyExtractionFromSchemaPrompt(self._schema) + prompt = PropertiesFromSchemaJinjaPrompt + prompt = prompt.set(schema=self._schema, response_format=self._schema.model_dump()) else: - prompt = PropertyExtractionFromDictPrompt( - schema=self._schema, - system="You are a helpful property extractor. You only return JSON.", - user=textwrap.dedent( - """\ - You are given a few text elements of a document. Extract JSON representing one entity of - class {entity} from the document. The class only has properties {schema}. Using - this context, FIND, FORMAT, and RETURN the JSON representing one {entity}. - Only return JSON as part of your answer. If no entity is in the text, return "None". - {doc_text} - """ - ), - ) + prompt = PropertiesZeroShotJinjaPrompt + if self._schema is not None: + prompt = prompt.set(schema=self._schema) + if self._schema_name is not None: prompt = prompt.set(entity=self._schema_name) + prompt = prompt.set(num_elements=self._num_of_elements) + if self._prompt_formatter is not element_list_formatter: + prompt = prompt.set(prompt_formatter=self._prompt_formatter) def parse_json_and_cast(d: Document) -> Document: entity_name = self._schema_name or "_entity"