[llm unify 5c/n] jinjify extract properties #1169
Changes from 20 commits
Commits: d5cae60, c56626f, 274ddfe, 2bb14dd, c6ed22f, d9fa8ab, d9d5079, 9b23fbd, 01cd387, a5b9b9a, 104a4e2, 5d2df13, f7ea173, bf6751f, 737f407, b6cdae8, 000a71c, fce9743, 6224607, fe0a40e, 363707e
@@ -15,6 +15,7 @@ def _infer_prompts(
     if llm_mode == LLMMode.SYNC:
         res = []
         for p in prompts:
+            print(p)
             s = llm.generate(prompt=p)
             res.append(s)
         return res

Review comment on the added print(p): remove
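For reference, applying the reviewer's suggestion just drops the debug print, leaving the SYNC branch as a plain per-prompt loop. A sketch of the surrounding lines only, not the full function:

    if llm_mode == LLMMode.SYNC:
        res = []
        for p in prompts:
            s = llm.generate(prompt=p)  # one blocking LLM call per rendered prompt
            res.append(s)
        return res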
@@ -1,19 +1,16 @@
 from abc import ABC, abstractmethod
 from typing import Callable, Any, Optional, Union
 import json
-import textwrap
-import copy

 from sycamore.data import Element, Document
-from sycamore.connectors.common import flatten_data
-from sycamore.llms.prompts.prompts import ElementListPrompt
 from sycamore.schema import Schema
 from sycamore.llms import LLM
 from sycamore.llms.prompts.default_prompts import (
     _SchemaZeroShotGuidancePrompt,
+    PropertiesZeroShotJinjaPrompt,
+    PropertiesFromSchemaJinjaPrompt,
 )
-from sycamore.llms.prompts import SycamorePrompt, RenderedPrompt
-from sycamore.llms.prompts.prompts import _build_format_str
+from sycamore.llms.prompts import SycamorePrompt
 from sycamore.plan_nodes import Node
 from sycamore.transforms.map import Map
 from sycamore.transforms.base_llm import LLMMap
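The two newly imported Jinja prompts are module-level SycamorePrompt objects that are specialized with .set(...) rather than subclassed, as the later hunk in this diff shows. A minimal usage sketch, assuming the field class in sycamore.schema is named SchemaField (not confirmed by this diff) and using an invented field name:

    from sycamore.schema import Schema, SchemaField  # SchemaField name assumed
    from sycamore.llms.prompts.default_prompts import PropertiesFromSchemaJinjaPrompt

    # Illustrative schema; the field is invented for the example.
    schema = Schema(fields=[SchemaField(name="airline", field_type="str")])

    # .set returns a configured copy of the prompt; response_format carries the
    # schema dump through to the rendered prompt, mirroring what the removed
    # PropertyExtractionFromSchemaPrompt.render_document did.
    prompt = PropertiesFromSchemaJinjaPrompt.set(
        schema=schema,
        response_format=schema.model_dump(),
    )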
@@ -51,78 +48,6 @@ def as_llm_map(self, child: Optional[Node], **kwargs) -> Node:
         pass


-class PropertyExtractionFromSchemaPrompt(ElementListPrompt):
-    default_system = "You are given text contents from a document."
-    default_user = textwrap.dedent(
-        """\
-        Extract values for the following fields:
-        {schema}
-
-        Document text:
-        {doc_text}
-
-        Don't return extra information.
-        If you cannot find a value for a requested property, use the provided default or the value 'None'.
-        Return your answers as a valid json dictionary that will be parsed in python.
-        """
-    )
-
-    def __init__(self, schema: Schema):
-        super().__init__(system=self.default_system, user=self.default_user)
-        self.schema = schema
-        self.kwargs["schema"] = self._format_schema(schema)
-
-    @staticmethod
-    def _format_schema(schema: Schema) -> str:
-        text = ""
-        for i, field in enumerate(schema.fields):
-            text += f"{i} {field.name}: type={field.field_type}: default={field.default}\n"
-            if field.description is not None:
-                text += f" {field.description}\n"
-            if field.examples is not None:
-                text += f" Examples values: {field.examples}\n"
-        return text
-
-    def set(self, **kwargs) -> SycamorePrompt:
-        if "schema" in kwargs:
-            new = copy.deepcopy(self)
-            new.schema = kwargs["schema"]
-            kwargs["schema"] = self._format_schema(new.schema)
-            return new.set(**kwargs)
-        return super().set(**kwargs)
-
-    def render_document(self, doc: Document) -> RenderedPrompt:
-        rp = super().render_document(doc)
-        rp.response_format = self.schema.model_dump()
-        return rp
-
-
-class PropertyExtractionFromDictPrompt(ElementListPrompt):
-    def __init__(self, schema: Optional[dict] = None, **kwargs):
-        super().__init__(**kwargs)
-        self.schema = schema
-
-    def render_document(self, doc: Document) -> RenderedPrompt:
-        format_args = copy.deepcopy(self.kwargs)
-        format_args["doc_text"] = doc.text_representation
-        if self.schema is None:
-            schema = doc.properties.get("_schema")
-        else:
-            schema = self.schema
-        format_args["schema"] = schema
-        if "entity" not in format_args:
-            format_args["entity"] = doc.properties.get("_schema_class", "entity")
-        flat_props = flatten_data(doc.properties, prefix="doc_property", separator="_")
-        format_args.update(flat_props)
-        format_args["elements"] = self._render_element_list_to_string(doc)
-        if doc.text_representation is None:
-            format_args["doc_text"] = format_args["elements"]
-
-        messages = _build_format_str(self.system, self.user, format_args)
-        result = RenderedPrompt(messages=messages, response_format=schema)
-        return result
-
-
 class LLMSchemaExtractor(SchemaExtractor):
     """
     The LLMSchemaExtractor uses the specified LLM object to extract a schema.
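For comparison with the Jinja replacement, the removed _format_schema rendered a Schema into numbered plain-text lines, one per field, with the description and examples indented underneath when present. Roughly, for a hypothetical single-field schema (field name, description, and example values invented here), its output looked like:

    0 airline: type=str: default=None
     The name of the operating airline.
     Examples values: ['Delta', 'United']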
@@ -287,23 +212,18 @@ def cast_types(self, fields: dict) -> dict:
     def as_llm_map(self, child: Optional[Node], **kwargs) -> Node:
         prompt: SycamorePrompt  # mypy grr
         if isinstance(self._schema, Schema):
-            prompt = PropertyExtractionFromSchemaPrompt(self._schema)
+            prompt = PropertiesFromSchemaJinjaPrompt
+            prompt = prompt.set(schema=self._schema, response_format=self._schema.model_dump())
         else:
-            prompt = PropertyExtractionFromDictPrompt(
-                schema=self._schema,
-                system="You are a helpful property extractor. You only return JSON.",
-                user=textwrap.dedent(
-                    """\
-                    You are given a few text elements of a document. Extract JSON representing one entity of
-                    class {entity} from the document. The class only has properties {schema}. Using
-                    this context, FIND, FORMAT, and RETURN the JSON representing one {entity}.
-                    Only return JSON as part of your answer. If no entity is in the text, return "None".
-                    {doc_text}
-                    """
-                ),
-            )
+            prompt = PropertiesZeroShotJinjaPrompt
+            if self._schema is not None:
+                prompt = prompt.set(schema=self._schema)
+
         if self._schema_name is not None:
             prompt = prompt.set(entity=self._schema_name)
+        prompt = prompt.set(num_elements=self._num_of_elements)
+        if self._prompt_formatter is not element_list_formatter:
+            prompt = prompt.set(prompt_formatter=self._prompt_formatter)
Review comment on the prompt_formatter lines:
Reviewer: Not setting for element_list_formatter because you're letting jinja fragments do the work? I'm assuming you would just remove element_list_formatter once you refactor LLM_schema_extractor?
Author: yep. will aim to drop it when it's no longer used
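The "jinja fragments do the work" remark refers to rendering the element list inside the prompt template itself instead of through a Python prompt_formatter callback. The actual template lives in default_prompts.py and is not reproduced in this diff; purely as an illustration of the idea, a fragment of that flavor could look like:

    {% for elt in doc.elements[:num_elements] %}
    ELEMENT {{ loop.index }}: {{ elt.text_representation }}
    {% endfor %}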
         def parse_json_and_cast(d: Document) -> Document:
             entity_name = self._schema_name or "_entity"
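One design note on the chained prompt.set(...) calls above: set returns a modified copy rather than mutating in place (the removed PropertyExtractionFromSchemaPrompt.set deep-copied before changing anything), so the module-level Jinja prompt objects are never altered. Assuming the Jinja prompts keep the same copy-on-set semantics, that behaves like:

    base = PropertiesZeroShotJinjaPrompt

    # Values here are illustrative; each .set layers parameters onto a fresh copy.
    configured = base.set(entity="airline_flight", num_elements=10)

    # The shared module-level prompt is untouched, so other extractors can reuse it safely.
    assert configured is not base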
Review comment:
Reviewer: Is this the correct system prompt?
Author: afaict, yep.
Reference: sycamore/lib/sycamore/sycamore/llms/prompts/default_prompts.py, Line 167 in ae1820f
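A quick local way to check the "is this the correct system prompt" question is to render the configured prompt against a toy document and inspect the resulting messages; the removed classes above show that SycamorePrompt subclasses expose render_document returning a RenderedPrompt. Sketch only: the Document construction, message attributes, and schema dict are assumed rather than taken from this diff.

    from sycamore.data import Document

    doc = Document()
    doc.text_representation = "Flight 123 departed Boston at 9:00 AM."  # toy text

    prompt = PropertiesZeroShotJinjaPrompt.set(schema={"departure_city": "str"}, entity="flight")
    rp = prompt.render_document(doc)
    for msg in rp.messages:  # assumes RenderedPrompt.messages carry role/content pairs
        print(msg.role, msg.content[:100])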