Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[llm unify 5c/n] jinjify extract properties #1169

Merged
merged 21 commits into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
d5cae60
add jinja prompts and convert extract entity to use it
HenryL27 Feb 7, 2025
c56626f
Merge branch 'main' of github.com:aryn-ai/sycamore into hml-jinjaprom…
HenryL27 Feb 7, 2025
274ddfe
delete commented out / dead code
HenryL27 Feb 7, 2025
2bb14dd
JinjaPrompt docstring
HenryL27 Feb 7, 2025
c6ed22f
add comments bc otherwise this is very dense
HenryL27 Feb 7, 2025
d9fa8ab
add norender() directive to jinja prompts when they shouldn't render
HenryL27 Feb 7, 2025
d9d5079
change FANCY_BATCHED_LIST to BATCHED_LIST_WITH_METADATA
HenryL27 Feb 7, 2025
9b23fbd
pr comments
HenryL27 Feb 7, 2025
01cd387
Merge branch 'main' of github.com:aryn-ai/sycamore into hml-jinjaprom…
HenryL27 Feb 7, 2025
a5b9b9a
branch switch
HenryL27 Feb 7, 2025
104a4e2
move summarizeImages to jinja
HenryL27 Feb 8, 2025
5d2df13
delete old bespoke prompt class
HenryL27 Feb 8, 2025
f7ea173
make it actually use the jinja prompt (and fix the jinja)
HenryL27 Feb 8, 2025
bf6751f
mypy:
HenryL27 Feb 8, 2025
737f407
add summarize_images unittest (mostly a prompt ut)
HenryL27 Feb 10, 2025
b6cdae8
adjust prop extraction its to count the correct number of lineage met…
HenryL27 Feb 10, 2025
000a71c
extract properties -> jinja
HenryL27 Feb 10, 2025
fce9743
set prompt_formatter when supplied a non-default
HenryL27 Feb 10, 2025
6224607
Merge branch 'main' of github.com:aryn-ai/sycamore into hml-jinjaprom…
HenryL27 Feb 10, 2025
fe0a40e
fix mypy. why is this fix a fix? I don't understand python.
HenryL27 Feb 10, 2025
363707e
drop a print statement
HenryL27 Feb 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions lib/sycamore/sycamore/llms/prompts/default_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
JinjaPrompt,
JinjaElementPrompt,
)
from sycamore.llms.prompts.jinja_fragments import (
J_DYNAMIC_DOC_TEXT,
J_FORMAT_SCHEMA_MACRO,
J_SET_ENTITY,
J_SET_SCHEMA,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -357,6 +363,43 @@ class {entity} from the document. The class only has properties {properties}. Us
)


PropertiesZeroShotJinjaPrompt = JinjaPrompt(
system="You are a helpful property extractor. You only return JSON.",
user=J_SET_SCHEMA
+ J_SET_ENTITY
+ textwrap.dedent(
"""\
You are given some text of a document. Extract JSON representing one entity of
class {{ entity }} from the document. The class only has properties {{ schema }}. Using
this context, FIND, FORMAT, and RETURN the JSON representing one {{ entity }}.
Only return JSON as part of your answer. If no entity is in the text, return "None".

Document:
"""
)
+ J_DYNAMIC_DOC_TEXT,
)

PropertiesFromSchemaJinjaPrompt = JinjaPrompt(
system="You are given text contents from a document.",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the correct system prompt?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

afaict, yep.

class ExtractPropertiesFromSchemaPrompt(SimplePrompt):

user=(
J_FORMAT_SCHEMA_MACRO
+ """\
Extract values for the following fields:
{{ format_schema(schema) }}

Document text:"""
+ J_DYNAMIC_DOC_TEXT
+ """

Don't return extra information.
If you cannot find a value for a requested property, use the provided default or the value 'None'.
Return your answers as a valid json dictionary that will be parsed in python.
"""
),
)


class EntityExtractorMessagesPrompt(SimplePrompt):
def __init__(self, question: str, field: str, format: Optional[str], discrete: bool = False):
super().__init__()
Expand Down
31 changes: 30 additions & 1 deletion lib/sycamore/sycamore/llms/prompts/jinja_fragments.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
J_ELEMENT_LIST = """\
J_ELEMENT_LIST_CAPPED = """\
{% for elt in doc.elements[:num_elements] %}ELEMENT {{ elt.element_index }}: {{ elt.field_to_value(field) }}
{% endfor %}"""

J_ELEMENT_LIST_UNCAPPED = """\
{% for elt in doc.elements %}ELEMENT {{ elt.element_index }}: {{ elt.field_to_value(field) }}
{% endfor %}"""


# Directive to not render the template if the iteration var has
# surpassed the number of batches.
Expand Down Expand Up @@ -29,3 +33,28 @@
Text: {{ elt.field_to_value(field) }}
{% endfor -%}"""
)

J_SET_SCHEMA = """{%- if schema is not defined %}{% set schema = doc.properties["_schema"] %}{% endif -%}\n"""
J_SET_ENTITY = (
"""{%- if entity is not defined %}{% set entity = doc.properties.get("_schema_class", "entity") %}{% endif -%}\n"""
)

J_DYNAMIC_DOC_TEXT = (
"""{%- set field = "text_representation" -%}
{% if doc.text_representation is not none %}{{ doc.text_representation }}
{% elif prompt_formatter is defined %}{{ prompt_formatter(doc.elements) }}
{% elif num_elements is defined %}"""
+ J_ELEMENT_LIST_CAPPED
+ "{% else %}"
+ J_ELEMENT_LIST_UNCAPPED
+ "{% endif %}"
)

J_FORMAT_SCHEMA_MACRO = """{% macro format_schema(schema) -%}
{% for field in schema.fields %}
{{ loop.index }} {{ field.name }}: {{ field.field_type }}: default={{ field.default }}
{% if field.description %} Decription: {{ field.description }}{% endif %}
{% if field.examples %} Example values: {{ field.examples }}{% endif %}
{%- endfor -%}
{%- endmacro %}
"""
22 changes: 19 additions & 3 deletions lib/sycamore/sycamore/llms/prompts/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from jinja2 import Template


ResponseFormat = Union[None, dict[str, Any], type[pydantic.BaseModel]]


@dataclass
class RenderedMessage:
"""Represents a message per the LLM messages interface - i.e. a role and a content string
Expand All @@ -39,7 +42,7 @@ class RenderedPrompt:
"""

messages: list[RenderedMessage]
response_format: Union[None, dict[str, Any], type[pydantic.BaseModel]] = None
response_format: ResponseFormat = None

def token_count(self, tokenizer: Tokenizer) -> int:
total_text = " ".join(m.content for m in self.messages)
Expand Down Expand Up @@ -535,13 +538,21 @@ class JinjaPrompt(SycamorePrompt):

"""

def __init__(self, *, system: Optional[str] = None, user: Union[None, str, list[str]] = None, **kwargs):
def __init__(
self,
*,
system: Optional[str] = None,
user: Union[None, str, list[str]] = None,
response_format: ResponseFormat = None,
**kwargs,
):
from jinja2.sandbox import SandboxedEnvironment
from jinja2 import Template

super().__init__()
self.system = system
self.user = user
self.response_format = response_format
self.kwargs = kwargs
self._env = SandboxedEnvironment(extensions=["jinja2.ext.loopcontrols"])
self._sys_template: Optional[Template] = None
Expand Down Expand Up @@ -577,6 +588,8 @@ def render_document(self, doc: Document) -> RenderedPrompt:
render_args["doc"] = doc

rendered = render_templates(self._sys_template, self._user_templates, render_args)
if self.response_format is not None:
rendered.response_format = self.response_format
return rendered


Expand All @@ -587,6 +600,7 @@ def __init__(
system: Optional[str] = None,
user: Union[None, str, list[str]] = None,
include_image: bool = False,
response_format: ResponseFormat = None,
**kwargs,
):
from jinja2.sandbox import SandboxedEnvironment
Expand All @@ -596,6 +610,7 @@ def __init__(
self.system = system
self.user = user
self.include_image = include_image
self.response_format = response_format
self.kwargs = kwargs
self._env = SandboxedEnvironment(extensions=["jinja2.ext.loopcontrols"])
self._sys_template: Optional[Template] = None
Expand Down Expand Up @@ -629,5 +644,6 @@ def render_element(self, elt: Element, doc: Document) -> RenderedPrompt:
from sycamore.utils.pdf_utils import get_element_image

result.messages[-1].images = [get_element_image(elt, doc)]
print(result)
if self.response_format is not None:
result.response_format = self.response_format
return result
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ def test_extract_properties_from_schema(llm):
default="null",
),
SchemaField(name="age", field_type="int", default=999),
SchemaField(name="date", field_type="str", description="Any date in the doc in YYYY-MM-DD format"),
SchemaField(
name="date", field_type="str", description="Any date in the doc, extracted in YYYY-MM-DD format"
),
SchemaField(
name="from_location",
field_type="str",
Expand Down
1 change: 1 addition & 0 deletions lib/sycamore/sycamore/transforms/base_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def _infer_prompts(
if llm_mode == LLMMode.SYNC:
res = []
for p in prompts:
print(p)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove

s = llm.generate(prompt=p)
res.append(s)
return res
Expand Down
6 changes: 4 additions & 2 deletions lib/sycamore/sycamore/transforms/extract_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,17 @@ def _get_prompt(self) -> SycamorePrompt:

if self._prompt is not None:
if isinstance(self._prompt, str):
return JinjaPrompt(system=None, user=self._prompt + "\n" + j_elements, **common_params)
return JinjaPrompt(
system=None, user=self._prompt + "\n" + j_elements, response_format=None, **common_params
)
else:
system = None
if len(self._prompt) > 0 and self._prompt[0]["role"] == "system":
system = self._prompt[0]["content"]
user = [p["content"] for p in self._prompt[1:]] + [j_elements]
else:
user = [p["content"] for p in self._prompt] + [j_elements]
return JinjaPrompt(system=system, user=user, **common_params)
return JinjaPrompt(system=system, user=user, response_format=None, **common_params)
elif self._prompt_template is not None:
return EntityExtractorFewShotJinjaPrompt.set(examples=self._prompt_template, **common_params)
else:
Expand Down
104 changes: 12 additions & 92 deletions lib/sycamore/sycamore/transforms/extract_schema.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
from abc import ABC, abstractmethod
from typing import Callable, Any, Optional, Union
import json
import textwrap
import copy

from sycamore.data import Element, Document
from sycamore.connectors.common import flatten_data
from sycamore.llms.prompts.prompts import ElementListPrompt
from sycamore.schema import Schema
from sycamore.llms import LLM
from sycamore.llms.prompts.default_prompts import (
_SchemaZeroShotGuidancePrompt,
PropertiesZeroShotJinjaPrompt,
PropertiesFromSchemaJinjaPrompt,
)
from sycamore.llms.prompts import SycamorePrompt, RenderedPrompt
from sycamore.llms.prompts.prompts import _build_format_str
from sycamore.llms.prompts import SycamorePrompt
from sycamore.plan_nodes import Node
from sycamore.transforms.map import Map
from sycamore.transforms.base_llm import LLMMap
Expand Down Expand Up @@ -51,78 +48,6 @@ def as_llm_map(self, child: Optional[Node], **kwargs) -> Node:
pass


class PropertyExtractionFromSchemaPrompt(ElementListPrompt):
default_system = "You are given text contents from a document."
default_user = textwrap.dedent(
"""\
Extract values for the following fields:
{schema}

Document text:
{doc_text}

Don't return extra information.
If you cannot find a value for a requested property, use the provided default or the value 'None'.
Return your answers as a valid json dictionary that will be parsed in python.
"""
)

def __init__(self, schema: Schema):
super().__init__(system=self.default_system, user=self.default_user)
self.schema = schema
self.kwargs["schema"] = self._format_schema(schema)

@staticmethod
def _format_schema(schema: Schema) -> str:
text = ""
for i, field in enumerate(schema.fields):
text += f"{i} {field.name}: type={field.field_type}: default={field.default}\n"
if field.description is not None:
text += f" {field.description}\n"
if field.examples is not None:
text += f" Examples values: {field.examples}\n"
return text

def set(self, **kwargs) -> SycamorePrompt:
if "schema" in kwargs:
new = copy.deepcopy(self)
new.schema = kwargs["schema"]
kwargs["schema"] = self._format_schema(new.schema)
return new.set(**kwargs)
return super().set(**kwargs)

def render_document(self, doc: Document) -> RenderedPrompt:
rp = super().render_document(doc)
rp.response_format = self.schema.model_dump()
return rp


class PropertyExtractionFromDictPrompt(ElementListPrompt):
def __init__(self, schema: Optional[dict] = None, **kwargs):
super().__init__(**kwargs)
self.schema = schema

def render_document(self, doc: Document) -> RenderedPrompt:
format_args = copy.deepcopy(self.kwargs)
format_args["doc_text"] = doc.text_representation
if self.schema is None:
schema = doc.properties.get("_schema")
else:
schema = self.schema
format_args["schema"] = schema
if "entity" not in format_args:
format_args["entity"] = doc.properties.get("_schema_class", "entity")
flat_props = flatten_data(doc.properties, prefix="doc_property", separator="_")
format_args.update(flat_props)
format_args["elements"] = self._render_element_list_to_string(doc)
if doc.text_representation is None:
format_args["doc_text"] = format_args["elements"]

messages = _build_format_str(self.system, self.user, format_args)
result = RenderedPrompt(messages=messages, response_format=schema)
return result


class LLMSchemaExtractor(SchemaExtractor):
"""
The LLMSchemaExtractor uses the specified LLM object to extract a schema.
Expand Down Expand Up @@ -287,23 +212,18 @@ def cast_types(self, fields: dict) -> dict:
def as_llm_map(self, child: Optional[Node], **kwargs) -> Node:
prompt: SycamorePrompt # mypy grr
if isinstance(self._schema, Schema):
prompt = PropertyExtractionFromSchemaPrompt(self._schema)
prompt = PropertiesFromSchemaJinjaPrompt
prompt = prompt.set(schema=self._schema, response_format=self._schema.model_dump())
else:
prompt = PropertyExtractionFromDictPrompt(
schema=self._schema,
system="You are a helpful property extractor. You only return JSON.",
user=textwrap.dedent(
"""\
You are given a few text elements of a document. Extract JSON representing one entity of
class {entity} from the document. The class only has properties {schema}. Using
this context, FIND, FORMAT, and RETURN the JSON representing one {entity}.
Only return JSON as part of your answer. If no entity is in the text, return "None".
{doc_text}
"""
),
)
prompt = PropertiesZeroShotJinjaPrompt
if self._schema is not None:
prompt = prompt.set(schema=self._schema)

if self._schema_name is not None:
prompt = prompt.set(entity=self._schema_name)
prompt = prompt.set(num_elements=self._num_of_elements)
if self._prompt_formatter is not element_list_formatter:
prompt = prompt.set(prompt_formatter=self._prompt_formatter)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not setting for element_list_formatter because you're letting jinja fragments do the work? I'm assuming you would just remove element_list_formatter once you refactor LLM_schema_extractor?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep. will aim to drop it when it's no longer used


def parse_json_and_cast(d: Document) -> Document:
entity_name = self._schema_name or "_entity"
Expand Down
Loading