diff --git a/lib/sycamore/sycamore/docset.py b/lib/sycamore/sycamore/docset.py index f29f39ea6..399895a3d 100644 --- a/lib/sycamore/sycamore/docset.py +++ b/lib/sycamore/sycamore/docset.py @@ -733,10 +733,8 @@ def summarize(self, summarizer: Summarizer, **kwargs) -> "DocSet": .partition(partitioner=ArynPartitioner()) .summarize(summarizer=summarizer) """ - from sycamore.transforms import Summarize - - summaries = Summarize(self.plan, summarizer=summarizer, **kwargs) - return DocSet(self.context, summaries) + map = summarizer.as_llm_map(self.plan, **kwargs) + return DocSet(self.context, map) def mark_bbox_preset(self, tokenizer: Tokenizer, token_limit: int = 512, **kwargs) -> "DocSet": """ diff --git a/lib/sycamore/sycamore/llms/prompts/default_prompts.py b/lib/sycamore/sycamore/llms/prompts/default_prompts.py index 4f50dddba..a0a641953 100644 --- a/lib/sycamore/sycamore/llms/prompts/default_prompts.py +++ b/lib/sycamore/sycamore/llms/prompts/default_prompts.py @@ -172,6 +172,15 @@ class _TextSummarizerGuidancePrompt(SimplePrompt): """, ) +TextSummarizerJinjaPrompt = JinjaElementPrompt( + system="You are a helpful text summarizer.", + user="""Write a summary of the following. Use only the information provided. + Include as many key details as possible. Do not make up your answer. Only return the summary as part of your answer. + + {{ elt.text_representation }} + """, +) + class _SchemaZeroShotGuidancePrompt(SimplePrompt): system = "You are a helpful entity extractor. You only return JSON Schema." diff --git a/lib/sycamore/sycamore/query/execution/operations.py b/lib/sycamore/sycamore/query/execution/operations.py index 033d0b7fb..ae147e3b0 100644 --- a/lib/sycamore/sycamore/query/execution/operations.py +++ b/lib/sycamore/sycamore/query/execution/operations.py @@ -5,21 +5,32 @@ from sycamore import DocSet from sycamore.context import context_params, Context -from sycamore.data import MetadataDocument -from sycamore.functions import CharacterTokenizer, Tokenizer +from sycamore.data import Document, Element +from sycamore.functions.tokenizer import OpenAITokenizer from sycamore.llms.llms import LLM +from sycamore.llms.prompts import RenderedPrompt, RenderedMessage from sycamore.llms.prompts.default_prompts import ( SummarizeDataMessagesPrompt, ) from sycamore.transforms.summarize import ( - NUM_TEXT_CHARS_GENERATE, - DocumentSummarizer, - collapse, - QuestionAnsweringSummarizer, - BASE_PROPS, + EtCetera, + MultiStepDocumentSummarizer, + OneStepDocumentSummarizer, + Summarizer, ) log = structlog.get_logger(__name__) +# multistep +DEFAULT_DOCSET_SUMMARIZER_CLS = MultiStepDocumentSummarizer # type: ignore + +DEFAULT_SUMMARIZER_KWARGS: dict[str, Any] = { + "fields": "*", + "tokenizer": OpenAITokenizer("gpt-4o"), + "max_tokens": 80_000, +} +# onestep +DEFAULT_DOCSET_SUMMARIZER_CLS = OneStepDocumentSummarizer # type: ignore +DEFAULT_SUMMARIZER_KWARGS = {"fields": [EtCetera], "tokenizer": OpenAITokenizer("gpt-4o"), "token_limit": 80_000} def math_operation(val1: int, val2: int, operator: str) -> Union[int, float]: @@ -52,14 +63,12 @@ def math_operation(val1: int, val2: int, operator: str) -> Union[int, float]: @context_params def summarize_data( llm: LLM, - question: str, + question: Optional[str], result_description: str, result_data: List[Any], - use_elements: bool = False, - num_elements: int = 5, - max_tokens: int = 120 * 1000, - tokenizer: Tokenizer = CharacterTokenizer(), + summaries_as_text: bool = False, context: Optional[Context] = None, + docset_summarizer: Optional[Summarizer] = None, 
**kwargs, ) -> str: """ @@ -71,123 +80,72 @@ def summarize_data( question: Question to answer. result_description: Description of each of the inputs in result_data. result_data: List of inputs. - use_elements: Use text contents from document.elements instead of document.text_representation. - num_elements: Number of elements whose text to use from each document. - max_tokens: Maximum number of tokens allowed in the summary to send to the LLM. - tokenizer: Tokenizer to use for counting against max_tokens. + summaries_as_text: If true, summarize all documents in the result_data docsets and treat + those summaries as the text representation for the final summarize step. context: Optional Context object to get default parameters from. + docset_summarizer: Summarizer class to use to summarize the docset. + Default is `DEFAULT_DOCSET_SUMMARIZER` + summarizer_kwargs: keyword arguments to pass to the docset summarizer constructor. e.g. + `tokenizer`, `token_limit`, and `element_batch_size` **kwargs: Additional keyword arguments. Returns: Conversational response to question. """ - text = _get_text_for_summarize_data( - result_description=result_description, - result_data=result_data, - use_elements=use_elements, - num_elements=num_elements, - max_tokens=max_tokens, - tokenizer=tokenizer, - **kwargs, - ) - messages = SummarizeDataMessagesPrompt(question=question, text=text).as_messages() - prompt_kwargs = {"messages": messages} - - # call to LLM - completion = llm.generate_old(prompt_kwargs=prompt_kwargs, llm_kwargs={"temperature": 0}) - - # LLM response + if docset_summarizer is None: + docset_summarizer = DEFAULT_DOCSET_SUMMARIZER_CLS( + llm=llm, question=question, **DEFAULT_SUMMARIZER_KWARGS # type: ignore + ) + + if all(isinstance(d, DocSet) for d in result_data): + return summarize_data_docsets( + llm, + question, + result_data, + docset_summarizer=docset_summarizer, + data_description=result_description, + summaries_as_text=summaries_as_text, + ) + + # If data is not DocSets, text is this list here + # TODO: Jinjify. + text = f"Data description: {result_description}\n" + for i, d in enumerate(result_data): + text += f"Input {i + 1}: {str(d)}\n" + + messages = SummarizeDataMessagesPrompt(question=question or "", text=text).as_messages() + prompt = RenderedPrompt(messages=[RenderedMessage(role=m["role"], content=m["content"]) for m in messages]) + completion = llm.generate(prompt=prompt) return completion -def _get_text_for_summarize_data( - result_description: str, - result_data: List[Any], - use_elements: bool, - num_elements: int, - max_tokens: Optional[int] = None, - tokenizer: Optional[Tokenizer] = None, - **kwargs, -) -> str: - text = f"Data description: {result_description}\n" - if (max_tokens is not None and tokenizer is None) or (max_tokens is None and tokenizer is not None): - raise ValueError("Both max_tokens and tokenizer must be provided together.") - - for i, result in enumerate(result_data): - text += f"Input {i + 1}:\n" - - # consolidates relevant properties to give to LLM - if isinstance(result, DocSet): - done = False - # For query result caching in the executor, we need to consume the documents - # so that the materialized data is complete, even if they are not all included - # in the input prompt to the LLM. 
- for di, doc in enumerate(result.take_all()): - if isinstance(doc, MetadataDocument): - continue - if done: - continue - props_dict = doc.properties.get("entity", {}) - props_dict.update({p: doc.properties[p] for p in set(doc.properties) - set(BASE_PROPS)}) - doc_text = f"Document {di}:\n" - for k, v in props_dict.items(): - doc_text += f"{k}: {v}\n" - - doc_text_representation = "" - if not use_elements: - if doc.text_representation is not None: - doc_text_representation += doc.text_representation[:NUM_TEXT_CHARS_GENERATE] - else: - for element in doc.elements[:num_elements]: - # Greedy fill doc level text length - if len(doc_text_representation) >= NUM_TEXT_CHARS_GENERATE: - break - doc_text_representation += (element.text_representation or "") + "\n" - doc_text += f"Text contents:\n{doc_text_representation}\n" - - if tokenizer is not None and max_tokens is not None: # for mypy - total_token_count = len(tokenizer.tokenize(text + doc_text)) - if total_token_count > max_tokens: - log.warn( - "Unable to add all text from to the LLM summary request due to token limit." - f" Sending text from {di + 1} docs." - ) - done = True - continue - text += doc_text + "\n" - else: - text += str(result) + "\n" - - return text +def sum_to_text(d: Document) -> Document: + if "summary" in d.properties: + d.text_representation = d.properties.pop("summary") + return d -@context_params -def summarize_map_reduce( +def summarize_data_docsets( llm: LLM, - question: str, - result_description: str, - result_data: List[Any], - use_elements: bool = False, - num_elements: int = 5, - max_tokens: int = 10 * 1000, - tokenizer: Tokenizer = CharacterTokenizer(), + question: Optional[str], + result_data: List[DocSet], + docset_summarizer: Summarizer, + data_description: Optional[str] = None, + summaries_as_text: bool = False, ) -> str: - """ """ - text = f"Data description: {result_description}\n" - for i, result in enumerate(result_data): - if isinstance(result, DocSet): - docs = ( - result.filter(lambda d: isinstance(d, MetadataDocument) is False) - .summarize( - summarizer=DocumentSummarizer(llm, question) - ) # document-level summarization can be parallelized (per DocSet) - .take_all() - ) - for doc in docs: - text += doc.properties["summary"] + "\n" - - else: - text += str(result) + "\n" - - final_summary = collapse(text, max_tokens, tokenizer, QuestionAnsweringSummarizer(llm, question)) - return final_summary + if summaries_as_text: + result_data = [ds.summarize(docset_summarizer).map(sum_to_text) for ds in result_data] + + single_docs = [_docset_to_singledoc(ds) for ds in result_data] + agged_ds = result_data[0].context.read.document(single_docs).summarize(docset_summarizer) + texts = [d.properties["summary"] for d in agged_ds.take_all()] + return "\n".join(texts) + + +def _docset_to_singledoc(ds: DocSet) -> Document: + """ + Converts a docset into a single document by turning every Document + into an Element of a global parent document. Essentially a reverse + explode. 
+ """ + return Document(elements=[Element(**d.data) for d in ds.take_all()]) diff --git a/lib/sycamore/sycamore/query/execution/sycamore_operator.py b/lib/sycamore/sycamore/query/execution/sycamore_operator.py index 187cd5ad0..2fdd44431 100644 --- a/lib/sycamore/sycamore/query/execution/sycamore_operator.py +++ b/lib/sycamore/sycamore/query/execution/sycamore_operator.py @@ -202,7 +202,6 @@ def execute(self) -> Any: result_description=description, result_data=self.inputs, context=self.context, - use_elements=True, **self.get_execute_args(), ) return result @@ -283,7 +282,7 @@ def script(self, input_var: Optional[str] = None, output_var: Optional[str] = No input_str = input_var or get_var_name(self.logical_node.input_nodes()[0]) output_str = output_var or get_var_name(self.logical_node) result = f""" -prompt = LlmFilterMessagesPrompt(filter_question='{self.logical_node.question}').as_messages() +prompt = LlmFilterMessagesJinjaPrompt.set(filter_question='{self.logical_node.question}') {output_str} = {input_str}.llm_filter( new_field='_autogen_LLMFilterOutput', prompt=prompt, @@ -293,7 +292,7 @@ def script(self, input_var: Optional[str] = None, output_var: Optional[str] = No ) """ return result, [ - "from sycamore.llms.prompts.default_prompts import LlmFilterMessagesPrompt", + "from sycamore.llms.prompts.default_prompts import LlmFilterMessagesJinjaPrompt", ] diff --git a/lib/sycamore/sycamore/tests/integration/query/execution/test_operations.py b/lib/sycamore/sycamore/tests/integration/query/execution/test_operations.py index 157bfa1ab..080d3ae7e 100644 --- a/lib/sycamore/sycamore/tests/integration/query/execution/test_operations.py +++ b/lib/sycamore/sycamore/tests/integration/query/execution/test_operations.py @@ -3,13 +3,10 @@ import sycamore from sycamore import EXEC_RAY from sycamore.data import Document -from sycamore.functions import CharacterTokenizer from sycamore.llms import OpenAI, OpenAIModels from sycamore.query.execution.operations import ( - QuestionAnsweringSummarizer, - collapse, - DocumentSummarizer, - summarize_map_reduce, + MultiStepDocumentSummarizer, + summarize_data, ) from sycamore.tests.config import TEST_DIR from sycamore.transforms.partition import UnstructuredPdfPartitioner @@ -24,39 +21,6 @@ def llm(): class TestOperations: - def test_collapse(self, llm): - question = "What is" - summarizer_fn = QuestionAnsweringSummarizer(llm, question) - - """ - Use this code to generate the text file. 
- - path = str(TEST_DIR / "resources/data/pdfs/Ray.pdf") - context = sycamore.init(exec_mode=EXEC_RAY) - result = ( - context.read.binary(path, binary_format="pdf") - .partition(partitioner=UnstructuredPdfPartitioner()) - .explode() - #.summarize(summarizer=LLMElementTextSummarizer(llm)) - .take_all() - ) - text = "" - for doc in result: - #for element in doc.elements: - if doc.text_representation: - text += doc.text_representation + "\n" - # text += "\n" - """ - - text_path = str(TEST_DIR / "resources/data/texts/Ray.txt") - text = open(text_path, "r").read() - - max_tokens = 10000 - tokenizer = CharacterTokenizer() - summary = collapse(text, max_tokens, tokenizer, summarizer_fn) - assert summary is not None - print(f"{len(summary)}\n\n{summary}") - def test_document_summarizer(self, llm): text_path = str(TEST_DIR / "resources/data/texts/Ray.txt") text = open(text_path, "r").read() @@ -109,7 +73,7 @@ def test_document_summarizer(self, llm): docs = [Document(item) for item in dicts] question = "What is" - doc_summarizer = DocumentSummarizer(llm, question) + doc_summarizer = MultiStepDocumentSummarizer(llm, question) docs[0].text_representation = text[:10000] doc = doc_summarizer.summarize(docs[0]) @@ -117,7 +81,7 @@ def test_document_summarizer(self, llm): def test_document_summarizer_in_sycamore(self, llm): question = "What is" - doc_summarizer = DocumentSummarizer(llm, question) + doc_summarizer = MultiStepDocumentSummarizer(llm, question) path = str(TEST_DIR / "resources/data/pdfs/Ray.pdf") context = sycamore.init(exec_mode=EXEC_RAY) result = ( @@ -138,7 +102,7 @@ def test_summarize_map_reduce(self, llm): docset = ( context.read.binary(path, binary_format="pdf").partition(partitioner=UnstructuredPdfPartitioner()).explode() ) + final_summary = summarize_data(llm, question, result_description="Ray paper", result_data=[docset]) - final_summary = summarize_map_reduce(llm, question, "summary", [docset]) print(final_summary) assert final_summary diff --git a/lib/sycamore/sycamore/tests/unit/query/execution/test_sycamore_executor.py b/lib/sycamore/sycamore/tests/unit/query/execution/test_sycamore_executor.py index 416f132ee..2fe08759e 100644 --- a/lib/sycamore/sycamore/tests/unit/query/execution/test_sycamore_executor.py +++ b/lib/sycamore/sycamore/tests/unit/query/execution/test_sycamore_executor.py @@ -1,11 +1,13 @@ import os import tempfile -from unittest.mock import patch, Mock +from unittest.mock import patch +from typing import Optional import pytest import sycamore from sycamore.llms import LLM +from sycamore.llms.prompts.prompts import RenderedPrompt from sycamore.query.execution.sycamore_executor import SycamoreExecutor from sycamore.query.logical_plan import LogicalPlan from sycamore.query.operators.count import Count @@ -14,6 +16,17 @@ from sycamore.query.result import SycamoreQueryResult +class MockLLM(LLM): + def __init__(self): + pass + + def is_chat_mode(self): + return True + + def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: + return "" + + @pytest.fixture def test_count_docs_query_plan() -> LogicalPlan: """A simple query plan which only counts the number of documents.""" @@ -124,7 +137,7 @@ def test_run_summarize_data_plan(mock_sycamore_docsetreader): ): context = sycamore.init( params={ - "default": {"llm": Mock(spec=LLM)}, + "default": {"llm": MockLLM()}, "opensearch": { "os_client_args": { "hosts": [{"host": "localhost", "port": 9200}], diff --git a/lib/sycamore/sycamore/tests/unit/query/execution/test_sycamore_operator.py 
b/lib/sycamore/sycamore/tests/unit/query/execution/test_sycamore_operator.py index 0f02afe6c..6b0790e71 100644 --- a/lib/sycamore/sycamore/tests/unit/query/execution/test_sycamore_operator.py +++ b/lib/sycamore/sycamore/tests/unit/query/execution/test_sycamore_operator.py @@ -219,7 +219,6 @@ def test_summarize_data(): question=logical_node.question, result_description=logical_node.description, result_data=[load_node], - use_elements=True, **sycamore_operator.get_execute_args(), ) diff --git a/lib/sycamore/sycamore/tests/unit/query/test_operations.py b/lib/sycamore/sycamore/tests/unit/query/test_operations.py index 52807315a..3157d2993 100644 --- a/lib/sycamore/sycamore/tests/unit/query/test_operations.py +++ b/lib/sycamore/sycamore/tests/unit/query/test_operations.py @@ -5,7 +5,6 @@ import sycamore from sycamore.data import Document, Element from sycamore.docset import DocSet -from sycamore.functions import CharacterTokenizer from sycamore.functions.basic_filters import MatchFilter, RangeFilter from sycamore.llms import LLM from sycamore.llms.prompts import RenderedPrompt @@ -16,16 +15,17 @@ from sycamore.query.execution.operations import ( summarize_data, math_operation, - _get_text_for_summarize_data, ) -from sycamore.transforms.summarize import NUM_DOCS_GENERATE +from sycamore.transforms.summarize import NUM_DOCS_GENERATE, MultiStepDocumentSummarizer class MockLLM(LLM): def __init__(self): super().__init__(model_name="mock_model") + self.capture = [] def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> str: + self.capture.append(prompt) if prompt.messages[0].content.endswith('"1, 2, one, two, 1, 3".'): return '{"groups": ["group1", "group2", "group3"]}' if ( @@ -49,6 +49,15 @@ def generate(self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) return "group2" elif value == "3" or value == "three": return "group3" + elif "unique cities" in prompt.messages[-1].content: + if "Summary: " in prompt.messages[-1].content: + return "merged summary" + else: + return "accumulated summary" + elif "elements of a document" in prompt.messages[-1].content: + return "element summary" + elif "element summary" in prompt.messages[-1].content: + return "document summary" else: return "" return "" @@ -100,6 +109,23 @@ def words_and_ids_docset(self, generate_docset) -> DocSet: } return generate_docset(texts) + @pytest.fixture + def big_words_and_ids_docset(self, words_and_ids_docset) -> DocSet: + import random + from copy import deepcopy + + words = ["some", "words", "are", "too", "long"] + docs = words_and_ids_docset.take_all() + big_docs = [] + for i in range(5): + big_docs.extend([deepcopy(d) for d in docs]) + for bd in big_docs: + for i in range(3): + word_choices = random.choices(words, k=10) + bd.elements.append(Element(text_representation=" ".join(word_choices))) + ctx = words_and_ids_docset.context + return ctx.read.document(big_docs) + @pytest.fixture def number_docset(self, generate_docset) -> DocSet: return generate_docset( @@ -115,80 +141,46 @@ def test_summarize_data(self, words_and_ids_docset): assert response == "" def test_get_text_for_summarize_data_docset(self, words_and_ids_docset): - response = _get_text_for_summarize_data( - result_description="List of unique cities", - result_data=[words_and_ids_docset], - use_elements=False, - num_elements=10, - ) - expected = "Data description: List of unique cities\nInput 1:\n" - for i, doc in enumerate(words_and_ids_docset.take(NUM_DOCS_GENERATE)): - expected += f"Document {i}:\n" - expected += f"Text 
contents:\n{doc.text_representation or ''}\n\n" - - assert response == expected - - def test_get_text_for_summarize_data_docset_with_token_limit(self, words_and_ids_docset): - response = _get_text_for_summarize_data( - result_description="List of unique cities", - result_data=[words_and_ids_docset], - use_elements=False, - num_elements=10, - tokenizer=CharacterTokenizer(), - max_tokens=150, # description + text of 2 documents (manually calculated) - ) - - expected = "Data description: List of unique cities\nInput 1:\n" - docs = words_and_ids_docset.take() - for i, doc in enumerate(docs[:2]): - expected += f"Document {i}:\n" - expected += f"Text contents:\n{doc.text_representation or ''}\n\n" - - assert response == expected - - # 1 more char allows another doc - response_3_docs = _get_text_for_summarize_data( + llm = MockLLM() + summarize_data( + llm=llm, + question=None, result_description="List of unique cities", result_data=[words_and_ids_docset], - use_elements=False, - num_elements=10, - tokenizer=CharacterTokenizer(), - max_tokens=151, # description + text of 3 documents (manually calculated) + docset_summarizer=MultiStepDocumentSummarizer( + llm=llm, question=None, data_description="List of unique cities" + ), ) + captured = llm.capture[-1] + mcontent = captured.messages[-1].content - expected_3_docs = expected + "Document 2:\n" - expected_3_docs += f"Text contents:\n{docs[2].text_representation or ''}\n\n" - - assert response_3_docs == expected_3_docs + assert "List of unique cities" in mcontent + for i, doc in enumerate(words_and_ids_docset.take(NUM_DOCS_GENERATE)): + assert f"Text: {doc.text_representation}" in mcontent - @pytest.mark.parametrize("num_elements", [1, None]) - def test_get_text_for_summarize_data_docset_with_elements(self, words_and_ids_docset, num_elements): - response = _get_text_for_summarize_data( + def test_get_text_for_summarize_data_docset_with_elements(self, big_words_and_ids_docset): + llm = MockLLM() + response = summarize_data( + llm=llm, + question=None, result_description="List of unique cities", - result_data=[words_and_ids_docset], - use_elements=True, - num_elements=num_elements, + result_data=[big_words_and_ids_docset], + summaries_as_text=True, + docset_summarizer=MultiStepDocumentSummarizer( + llm=llm, question=None, data_description="List of unique cities", max_tokens=1000 + ), ) - - expected = "Data description: List of unique cities\nInput 1:\n" - for i, doc in enumerate(words_and_ids_docset.take(NUM_DOCS_GENERATE)): - expected += f"Document {i}:\n" - expected += "Text contents:\n" - for e in doc.elements[: (num_elements or NUM_DOCS_GENERATE)]: - expected += f"{e.text_representation or ''}\n" - expected += "\n\n" - - assert response == expected + captured = llm.capture + assert len(captured) == 44 # 44 llm calls + assert response == "merged summary" def test_get_text_for_summarize_data_non_docset(self, words_and_ids_docset): - response = _get_text_for_summarize_data( - result_description="Count of unique cities", result_data=[20], use_elements=False, num_elements=5 - ) - assert response == "Data description: Count of unique cities\nInput 1:\n20\n" - response = _get_text_for_summarize_data( - result_description="Count of unique cities", result_data=[20], use_elements=True, num_elements=5 - ) - assert response == "Data description: Count of unique cities\nInput 1:\n20\n" + llm = MockLLM() + _ = summarize_data(llm=llm, question=None, result_description="Count of unique cities", result_data=[20]) + print(llm.capture) + captured = 
llm.capture[-1].messages[-1].content + assert "Count of unique cities" in captured + assert "Input 1: 20" in captured # Math def test_math(self): diff --git a/lib/sycamore/sycamore/tests/unit/test_docset.py b/lib/sycamore/sycamore/tests/unit/test_docset.py index 0536bc5d1..0aee35506 100644 --- a/lib/sycamore/sycamore/tests/unit/test_docset.py +++ b/lib/sycamore/sycamore/tests/unit/test_docset.py @@ -18,7 +18,6 @@ Embedder, Embed, Partitioner, - Summarize, FlatMap, Map, MapBatch, @@ -29,7 +28,7 @@ ) from sycamore.transforms import Filter from sycamore.transforms.base import get_name_from_callable, CompositeTransform -from sycamore.transforms.base_llm import LLMMap +from sycamore.transforms.base_llm import LLMMap, LLMMapElements from sycamore.transforms.extract_entity import OpenAIEntityExtractor from sycamore.transforms.extract_schema import SchemaExtractor, LLMPropertyExtractor from sycamore.transforms.query import QueryExecutor @@ -246,7 +245,7 @@ def test_summarize(self, mocker): llm = mocker.Mock(spec=LLM) docset = DocSet(context, None) docset = docset.summarize(llm=llm, summarizer=LLMElementTextSummarizer(llm)) - assert isinstance(docset.lineage(), Summarize) + assert isinstance(docset.lineage(), LLMMapElements) def test_filter(self, mocker): context = mocker.Mock(spec=Context) diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_summarize.py b/lib/sycamore/sycamore/tests/unit/transforms/test_summarize.py index f749280eb..8ece8102a 100644 --- a/lib/sycamore/sycamore/tests/unit/transforms/test_summarize.py +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_summarize.py @@ -3,7 +3,13 @@ from sycamore.data import Document, Element from sycamore.llms import LLM -from sycamore.transforms.summarize import LLMElementTextSummarizer +from sycamore.transforms.summarize import ( + LLMElementTextSummarizer, + OneStepDocumentSummarizer, + MultiStepDocumentSummarizer, + EtCetera, +) +from sycamore.transforms.standardizer import USStateStandardizer class TestSummarize: @@ -21,7 +27,7 @@ def test_summarize_text_does_not_call_llm(self, mocker): def test_summarize_text_calls_llm(self, mocker): llm = mocker.Mock(spec=LLM) - generate = mocker.patch.object(llm, "generate_old") + generate = mocker.patch.object(llm, "generate") generate.return_value = "this is the summary" doc = Document() element1 = Element() @@ -37,5 +43,260 @@ def test_summarize_text_calls_llm(self, mocker): assert doc.elements[1].properties == {"summary": "this is the summary"} +class TestMultiStepSummarize: + doc = Document( + elements=[ + Element(text_representation="aaaaaaaa", properties={"key": "m"}), + Element(text_representation="bbbbbbbb", properties={"key": "n"}), + Element(text_representation="cccccccc", properties={"key": "o"}), + Element(text_representation="dddddddd", properties={"key": "p"}), + Element(text_representation="eeeeeeee", properties={"key": "q"}), + ] + ) + + def test_base(self, mocker): + llm = mocker.Mock(spec=LLM) + generate = mocker.patch.object(llm, "generate") + generate.return_value = "sum" + summarizer = MultiStepDocumentSummarizer(llm=llm) + d = summarizer.summarize(self.doc) + + assert d.properties["summary"] == "sum" + assert generate.call_count == 1 + prompt = generate.call_args.kwargs["prompt"] + usermessage = prompt.messages[-1].content + for e in self.doc.elements: + assert f"Text: {e.text_representation}" in usermessage + assert f"properties.key: {e.properties['key']}" not in usermessage + + def test_multistep_set_fields(self, mocker): + llm = mocker.Mock(spec=LLM) + generate = 
mocker.patch.object(llm, "generate") + generate.return_value = "sum" + summarizer = MultiStepDocumentSummarizer(llm=llm, fields=["properties.key"]) + d = summarizer.summarize(self.doc) + + assert d.properties["summary"] == "sum" + assert generate.call_count == 1 + prompt = generate.call_args.kwargs["prompt"] + usermessage = prompt.messages[-1].content + for e in self.doc.elements: + assert f"Text: {e.text_representation}" in usermessage + assert f"properties.key: {e.properties['key']}" in usermessage + + def test_multistep_all_fields(self, mocker): + llm = mocker.Mock(spec=LLM) + generate = mocker.patch.object(llm, "generate") + generate.return_value = "sum" + summarizer = MultiStepDocumentSummarizer(llm=llm, fields="*") + d = summarizer.summarize(self.doc) + + assert d.properties["summary"] == "sum" + assert generate.call_count == 1 + prompt = generate.call_args.kwargs["prompt"] + usermessage = prompt.messages[-1].content + for e in self.doc.elements: + assert f"Text: {e.text_representation}" in usermessage + assert f"key: {e.properties['key']}" in usermessage + + def test_multistep_set_question(self, mocker): + llm = mocker.Mock(spec=LLM) + generate = mocker.patch.object(llm, "generate") + generate.return_value = "sum" + summarizer = MultiStepDocumentSummarizer(llm=llm, question="loser says what?") + d = summarizer.summarize(self.doc) + + assert d.properties["summary"] == "sum" + assert generate.call_count == 1 + prompt = generate.call_args.kwargs["prompt"] + usermessage = prompt.messages[-1].content + assert "loser says what?" in usermessage + for e in self.doc.elements: + assert f"Text: {e.text_representation}" in usermessage + assert f"properties.key: {e.properties['key']}" not in usermessage + + def test_small_token_limit(self, mocker): + llm = mocker.Mock(spec=LLM) + generate = mocker.patch.object(llm, "generate") + generate.return_value = "sum" + summarizer = MultiStepDocumentSummarizer(llm=llm, max_tokens=470) # 310 chars = first 3 elements + d = summarizer.summarize(self.doc) + + assert d.properties["summary"] == "sum" + assert generate.call_count == 3 + first_call = generate.call_args_list[0] + second_call = generate.call_args_list[1] + third_call = generate.call_args_list[2] + + prompt = first_call.kwargs["prompt"] + usermessage = prompt.messages[-1].content + for e in self.doc.elements[:3]: + assert f"Text: {e.text_representation}" in usermessage + for e in self.doc.elements[3:]: + assert f"Text: {e.text_representation}" not in usermessage + + prompt = second_call.kwargs["prompt"] + usermessage = prompt.messages[-1].content + for e in self.doc.elements[:3]: + assert f"Text: {e.text_representation}" not in usermessage + for e in self.doc.elements[3:]: + assert f"Text: {e.text_representation}" in usermessage + + prompt = third_call.kwargs["prompt"] + usermessage = prompt.messages[-1].content + assert occurrences(usermessage, "Summary: sum") == 2 + + +class TestOneStepSummarize: + doc = Document( + elements=[ + Element( + text_representation="Something very long", + properties={ + f"state_{i}": s + for i, s in enumerate(list(USStateStandardizer.state_abbreviations.values()) + ["Canada"]) + } + | {"title": "Title A"}, + elements=[Element(text_representation="subelement 1"), Element(text_representation="subelement 2")], + ), + Element( + text_representation="Something very long part 2", + properties={ + f"state_{i}": s + for i, s in enumerate( + sorted( + list(USStateStandardizer.state_abbreviations.values()) + ["Canada"], + key=lambda state: state[1:], + ) + ) + } + | {"title": 
"Title B"}, + elements=[ + Element(text_representation="subelement 1"), + Element(text_representation="subelement 2"), + Element(text_representation="subelement 3"), + Element(text_representation="subelement 4"), + Element(text_representation="subelement 5"), + ] + * 10, + ), + Element( + text_representation="Something very long part 3", + properties={ + f"state_{i}": s + for i, s in enumerate( + sorted( + list(USStateStandardizer.state_abbreviations.values()) + ["Canada"], + key=lambda state: state[3:], + ) + ) + } + | {"title": "Title C"}, + ), + Element( + text_representation="Something very long part 4", + properties={ + f"state_{i}": s + for i, s in enumerate(reversed(list(USStateStandardizer.state_abbreviations.values()) + ["Canada"])) + } + | {"title": "Title D"}, + ), + ] + ) + + def test_basic(self, mocker): + llm = mocker.Mock(spec=LLM) + generate = mocker.patch.object(llm, "generate") + generate.return_value = "sum" + summarizer = OneStepDocumentSummarizer(llm, question="say what?") + d = summarizer.summarize(self.doc) + + assert d.properties["summary"] == "sum" + assert generate.call_count == 1 + prompt = generate.call_args.kwargs["prompt"] + usermessage = prompt.messages[-1].content + assert occurrences(usermessage, "subelement") == 52 + assert "say what?" in usermessage + for e in self.doc.elements: + for p in e.properties: + assert f"properties.{p}: {e.properties[p]}" in usermessage + + def test_title_first(self, mocker): + llm = mocker.Mock(spec=LLM) + generate = mocker.patch.object(llm, "generate") + generate.return_value = "sum" + summarizer = OneStepDocumentSummarizer(llm, question="say what?", fields=["properties.title", EtCetera]) + d = summarizer.summarize(self.doc) + + assert d.properties["summary"] == "sum" + assert generate.call_count == 1 + prompt = generate.call_args.kwargs["prompt"] + usermessage = prompt.messages[-1].content + assert occurrences(usermessage, "subelement") == 52 + # Intro to each element is "Entry i:" + # other properties are "property name: property value" + # if last nonwhitespace before properties.title is ':', + # then properties.title was the first property + before_title = usermessage.split(sep="properties.title")[:-1] + assert all(b.strip().endswith(":") for b in before_title) + assert "say what?" in usermessage + for e in self.doc.elements: + for p in e.properties: + assert f"properties.{p}: {e.properties[p]}" in usermessage + + def test_only_title(self, mocker): + llm = mocker.Mock(spec=LLM) + generate = mocker.patch.object(llm, "generate") + generate.return_value = "sum" + summarizer = OneStepDocumentSummarizer(llm, question="say what?", fields=["properties.title"]) + d = summarizer.summarize(self.doc) + + assert d.properties["summary"] == "sum" + assert generate.call_count == 1 + prompt = generate.call_args.kwargs["prompt"] + usermessage = prompt.messages[-1].content + assert occurrences(usermessage, "subelement") == 52 + assert occurrences(usermessage, "properties.title") == 4 + assert "say what?" 
in usermessage + assert "properties.state" not in usermessage + + def test_too_many_tokens(self, mocker): + llm = mocker.Mock(spec=LLM) + generate = mocker.patch.object(llm, "generate") + generate.return_value = "sum" + summarizer = OneStepDocumentSummarizer(llm, question="say what?", fields=["properties.title"], token_limit=1000) + d = summarizer.summarize(self.doc) + + assert d.properties["summary"] == "sum" + assert generate.call_count == 1 + prompt = generate.call_args.kwargs["prompt"] + usermessage = prompt.messages[-1].content + assert occurrences(usermessage, "subelement") == 32 + assert occurrences(usermessage, "properties.title") == 4 + assert "say what?" in usermessage + assert "properties.state" not in usermessage + + def test_too_many_tokens_takes_evenly(self, mocker): + llm = mocker.Mock(spec=LLM) + generate = mocker.patch.object(llm, "generate") + generate.return_value = "sum" + summarizer = OneStepDocumentSummarizer(llm, question="say what?", fields=["properties.title"], token_limit=500) + d = summarizer.summarize(self.doc) + + assert d.properties["summary"] == "sum" + assert generate.call_count == 1 + prompt = generate.call_args.kwargs["prompt"] + usermessage = prompt.messages[-1].content + assert occurrences(usermessage, "subelement") == 2 + assert "subelement 2" not in usermessage + assert occurrences(usermessage, "properties.title") == 4 + assert "say what?" in usermessage + assert "properties.state" not in usermessage + + def filter_elements_on_length(element: Element) -> bool: return False if element.text_representation is None else len(element.text_representation) > 10 + + +def occurrences(superstring: str, substring: str) -> int: + return len(superstring.split(sep=substring)) - 1 diff --git a/lib/sycamore/sycamore/transforms/summarize.py b/lib/sycamore/sycamore/transforms/summarize.py index bd54452a5..ee2dfa444 100644 --- a/lib/sycamore/sycamore/transforms/summarize.py +++ b/lib/sycamore/sycamore/transforms/summarize.py @@ -1,17 +1,24 @@ -import logging -import time from abc import ABC, abstractmethod -from typing import Callable, Optional +from typing import Callable, Optional, Literal, Union, Type +import copy +import textwrap from sycamore.data import Element, Document -from sycamore.functions import Tokenizer, CharacterTokenizer -from sycamore.llms.prompts.default_prompts import SummarizeDataMessagesPrompt +from sycamore.functions.tokenizer import Tokenizer, CharacterTokenizer +from sycamore.llms.prompts.default_prompts import ( + TextSummarizerJinjaPrompt, +) +from sycamore.llms.prompts.prompts import ( + JinjaElementPrompt, + SycamorePrompt, + JinjaPrompt, +) from sycamore.plan_nodes import NonCPUUser, NonGPUUser, Node from sycamore.llms import LLM -from sycamore.llms.prompts.default_prompts import _TextSummarizerGuidancePrompt from sycamore.transforms.map import Map -from sycamore.utils.time_trace import timetrace +from sycamore.transforms.base import CompositeTransform +from sycamore.transforms.base_llm import LLMMapElements, LLMMap NUM_DOCS_GENERATE = 60 NUM_TEXT_CHARS_GENERATE = 2500 @@ -30,8 +37,13 @@ class Summarizer(ABC): - @abstractmethod def summarize(self, document: Document) -> Document: + map = self.as_llm_map(None) + assert hasattr(map, "_local_process") + return map._local_process([document])[0] + + @abstractmethod + def as_llm_map(self, child: Optional[Node], **kwargs) -> Node: pass @@ -61,133 +73,324 @@ def __init__(self, llm: LLM, element_operator: Optional[Callable[[Element], bool self._llm = llm self._element_operator = element_operator - def 
summarize(self, document: Document) -> Document: - elements = [] + def as_llm_map(self, child: Optional[Node], **kwargs) -> Node: if self._element_operator is not None: - for element in document.elements: - if self._element_operator(element): - elements.append(self._summarize_text_element(element)) - else: - elements.append(element) + return LLMMapElements( + child, TextSummarizerJinjaPrompt, output_field="summary", llm=self._llm, filter=self._element_operator + ) else: - elements = [self._summarize_text_element(element) for element in document.elements] - - document.elements = elements - return document - - @timetrace("SummText") - def _summarize_text_element(self, element: Element) -> Element: - prompt = _TextSummarizerGuidancePrompt() - - if element.text_representation: - response = self._llm.generate_old(prompt_kwargs={"prompt": prompt, "query": element.text_representation}) - element.properties["summary"] = response - return element + return LLMMapElements(child, TextSummarizerJinjaPrompt, output_field="summary", llm=self._llm) + + +MaxTokensHeirarchyPrompt = JinjaElementPrompt( + system=textwrap.dedent( + """ + {% if question is defined %}You are a helpful research assistant. You answer questions based on + text you are presented with. + {% else %}You are a helpful data summarizer. You concisely summarize text you are presented with, + including as much detail as possible. + {% endif %} + """ + ), + user=textwrap.dedent( + """ + {#- + get_text macro: returns text for an element. If this is the first + round of summarization: + If `fields` is provided to the template, add a list of key-value + pairs to the text (if fields is the string "*", use all properties). + Always include the text representation + If this is after the first round of summarization: + use only the element's summary field + -#} + {%- macro get_text(element, itvarname) %} + {%- if elt.properties[itvarname] == 0 -%} + {%- if fields is defined -%} + {%- if fields == "*" %}{% for p in element.properties %} + {%- if p.startswith('_') %}{% continue %}{% endif %} + {{ p }}: {{ element.properties[p] }} + {%- endfor -%} + {%- else %}{% for f in fields %} + {{ f }}: {{ element.field_to_value(f) }} + {%- endfor %}{% endif -%} + {%- endif -%} + Text: {{ element.text_representation }} + {%- else -%} + Summary: {{ element.properties[intermediate_summary_key] }} + {% endif -%} + {% endmacro -%} + + {%- macro get_data_description() -%} + {%- if data_description is defined -%} + {{ data_description }} + {%- else -%} + some documents + {%- endif -%} + {%- endmacro -%} + + {%- if elt.properties[skip_me_key] -%}{{ norender() }}{%- endif -%} + {%- if (batch_key in elt.properties and elt.properties[batch_key]|count == 1 + and intermediate_summary_key in elt.properties) -%}{{ norender() }}{%- endif -%} + {%- if batch_key not in elt.properties -%}{{ norender() }}{%- endif -%} + + {% if elt.properties[round_key] == 0 -%} + You are given {{ get_data_description() }}. Please use only the information found in these elements + to determine an answer to the question "{{ question }}". If you cannot answer the question based on + the data provided, instead respond with any data that might be relevant to the question. + Elements: + {% else %} + You are given a list of partial answers to the question "{{ question }}" based on {{ get_data_description() }}. + Please combine these partial answers into a coherent single answer to the question "{{ question }}". + Some answers may not be particularly relevent, so don't pay them too much mind. 
+ Answers: + {%- endif -%} + {%- for idx in elt.properties[batch_key] %} + {{ loop.index }}: {{ get_text(doc.elements[idx], round_key) }} + {% endfor %} + """ + ), + question="What is the summary of this data?", +) + + +class MultiStepDocumentSummarizer(Summarizer): + """ + Summarizes a document by constructing a tree of summaries. Each leaf contains as many consecutive + elements as possible within the token limit, and each vertex of the tree contains as many sub- + summaries as possible within the token limit. e.g. with max_tokens=10 + Elements: (3 tokens) - (3 tokens) - (5 tokens) - (8 tokens) + | | | | + (4 token summary) - (3 token summary) - (2 token summary) + \\ | / + (5 token summary) + Args: + llm: LLM to use for summarization + question: Optional question to use as context for the summarization. If set, the llm will + attempt to answer the question with the data provided + prompt: Prompt to use for each summarization. Caution: The default (MaxTokensHeirarchyPrompt) + has some fairly complicated logic encoded in it to make the tree construction work correctly. + fields: List of fields to include in each element's representation in the prompt. Specify + with dotted notation (e.g. properties.title), or use "*" to capture everything. If None, + will include no fields. + max_tokens: token limit for each summarization. Default is 10k (default tokenizer is by character). + tokenizer: tokenizer to use when computing how many tokens a prompt will take. Default is + CharacterTokenizer + rounds: number of rounds of hierarchical summarization to perform. The number of elements that can be + included in the summary is O(e^rounds), so rounds can be small. Default is 4. + """ -class QuestionAnsweringSummarizer: - def __init__(self, llm: LLM, question: str): + def __init__( + self, + llm: LLM, + question: Optional[str] = None, + data_description: Optional[str] = None, + prompt: SycamorePrompt = MaxTokensHeirarchyPrompt, + fields: Union[None, Literal["*"], list[str]] = None, + max_tokens: int = 10 * 1000, + tokenizer: Tokenizer = CharacterTokenizer(), + rounds: int = 4, + ): self.llm = llm + self.prompt = prompt.set(**self.get_const_vars()) + self.fields = fields self.question = question - - def __call__(self, text: str) -> str: - messages = SummarizeDataMessagesPrompt(question=self.question, text=text).as_messages() - prompt_kwargs = {"messages": messages} - - t0 = time.time() - # call to LLM - summary = self.llm.generate_old(prompt_kwargs=prompt_kwargs, llm_kwargs={"temperature": 0}) - t1 = time.time() - logging.info(f"Summarizer took {t1 - t0} seconds to generate summary.") - - return summary - - -def collapse(text: str, tokens_per_chunk: int, tokenizer: Tokenizer, summarizer_fn: Callable[[str], str]) -> str: + self.data_description = data_description + self.max_tokens = max_tokens + self.tokenizer = tokenizer + self.rounds = rounds + + @staticmethod + def get_const_vars() -> dict[str, str]: + return { + "skip_me_key": "_skip_me", + "batch_key": "_batch", + "round_key": "_round", + "intermediate_summary_key": "_summary", + } + + def prep_batches(self, doc: Document, round: int = 0) -> Document: + vars = self.get_const_vars() + for i, elt in enumerate(doc.elements): + elt.properties[vars["round_key"]] = round + if vars["skip_me_key"] not in elt.properties: + elt.properties[vars["skip_me_key"]] = False + if elt.properties[vars["skip_me_key"]]: + continue + this_batch = [i] + elt.properties[vars["batch_key"]] = this_batch + for j in range(i + 1, len(doc.elements)): + e2 = doc.elements[j]
+ if e2.properties.get(vars["skip_me_key"], False): + continue + this_batch.append(j) + tks = self.prompt.render_element(elt, doc).token_count(self.tokenizer) + if tks > self.max_tokens: + this_batch.pop() + break + e2.properties[vars["skip_me_key"]] = True + return doc + + def cleanup(self, doc: Document) -> Document: + if len(doc.elements) == 0: + return doc + vars = self.get_const_vars() + doc.properties["summary"] = doc.elements[0].properties[vars["intermediate_summary_key"]] + for e in doc.elements: + for v in vars: + if v in e.properties: + del e.properties[v] + return doc + + def as_llm_map(self, child: Optional[Node], **kwargs) -> Node: + vars = self.get_const_vars() + if self.fields is not None: + self.prompt = self.prompt.set(fields=self.fields) + if self.question is not None: + self.prompt = self.prompt.set(question=self.question) + if self.data_description is not None: + self.prompt = self.prompt.set(data_description=self.data_description) + nodes = [] + last = child + for round in range(self.rounds): + prep_round = Map(child=last, f=self.prep_batches, kwargs={"round": round}) + llm_round = LLMMapElements( + child=prep_round, + prompt=self.prompt, + output_field=vars["intermediate_summary_key"], + llm=self.llm, + ) + nodes.extend([prep_round, llm_round]) + last = llm_round + cleanup = Map(child=last, f=self.cleanup) + nodes.append(cleanup) + ct = CompositeTransform(child, []) # type: ignore + ct.nodes = nodes + return ct + + +OneStepSummarizerPrompt = JinjaPrompt( + system="You are a helpful text summarizer", + user=textwrap.dedent( + """ + You are given a series of database entries that answer the question "{{ question }}". + Generate a concise, conversational summary of the data to answer the question. + {%- for elt in doc.elements %} + Entry {{ loop.index }}: + {% for f in doc.properties[fields_key] %}{#{% if f.startswith("_") %}{% continue %}{% endif %}#} + {{ f }}: {{ elt.field_to_value(f) }} + {% endfor -%} + {%- if doc.properties[numel_key] is not none and doc.properties[numel_key] > 0 %} Text: + {% endif -%} + {%- for subel in elt.data.get("elements", [])[:doc.properties[numel_key]] -%} + {{ subel.text_representation }} + {% endfor %} + {% endfor %} + """ + ), +) + + +class EtCetera: + """Sentinel value to sit at the end of a list of fields, signifying 'add as + many additional properties as you can within the token limit'""" + + +class OneStepDocumentSummarizer(Summarizer): """ - Collapses text iteratively, summarizing the first chunk and incorporating it in the summary for the next chunk. + Summarizes a document in a single LLM call by taking as much data as possible + from every element, spread across them evenly. Intended for use with summarize_data, + where a summarizer is used to summarize an entire docset. Args: - text: Text to collapse. - chunk_size: Size of each chunk. - tokenizer: Tokenizer to use for counting against max_tokens. + llm: LLM to use for summarization + question: Question to use as context for the summary. The llm will attempt to + use the data provided to answer the question. + token_limit: Token limit for the prompt. Default is 10k (default tokenizer is + by character) + tokenizer: Tokenizer to use to count tokens (to not exceed the token limit). + Default is CharacterTokenizer + fields: List of fields to include from every element. To include any additional + fields (after the ones specified), end the list with `EtCetera`. 
Default is + empty list, which stands for 'as many fields as fit within the token limit' + and is equivalent to `[EtCetera]` - Returns: - List of chunks. """ - tokens = tokenizer.tokenize(text) - total = len(tokens) - if total <= tokens_per_chunk: - return text - done = False - i = 0 - additional = i + tokens_per_chunk - cur_summary = "" - while not done: - input = "" - if cur_summary: - input = f"{cur_summary}\n" - input += "".join([str(tk) for tk in tokens[i : i + additional]]) # make mypy happy - print(f"input size: {len(input)}") - cur_summary = summarizer_fn(input) - assert ( - len(cur_summary) <= tokens_per_chunk - ), f"Summarizer output is longer than input chunk {len(cur_summary)} > {tokens_per_chunk} !!!" - print(f"summary to chunk ratio: {len(cur_summary) / tokens_per_chunk}") - i += additional - remaining = tokens_per_chunk - len(cur_summary) - additional = min(remaining, total - i) - if additional == 0: - break - - return cur_summary - - -class DocumentSummarizer(Summarizer): + def __init__( self, llm: LLM, question: str, - chunk_size: int = 10 * 1000, + token_limit: int = 10 * 1000, tokenizer: Tokenizer = CharacterTokenizer(), - chunk_overlap: int = 0, - use_elements: bool = False, - num_elements: int = 5, + fields: list[Union[str, Type[EtCetera]]] = [], ): self.llm = llm self.question = question - self.chunk_size = chunk_size + self.token_limit = token_limit self.tokenizer = tokenizer - self.chunk_overlap = chunk_overlap - self.use_elements = use_elements - self.num_elements = num_elements - - def summarize(self, document: Document) -> Document: - text = self.get_text(document) - summary = collapse(text, self.chunk_size, self.tokenizer, QuestionAnsweringSummarizer(self.llm, self.question)) - document.properties["summary"] = summary - return document - - def get_text(self, doc: Document) -> str: - doc_text = "" - props_dict = doc.properties.get("entity", {}) - props_dict.update({p: doc.properties[p] for p in set(doc.properties) - set(BASE_PROPS)}) - for k, v in props_dict.items(): - doc_text += f"{k}: {v}\n" - - doc_text_representation = "" - if not self.use_elements: - if doc.text_representation is not None: - doc_text_representation += doc.text_representation[:NUM_TEXT_CHARS_GENERATE] - else: - for element in doc.elements[: self.num_elements]: - # Greedy fill doc level text length - if len(doc_text_representation) >= NUM_TEXT_CHARS_GENERATE: - break - doc_text_representation += (element.text_representation or "") + "\n" - doc_text += f"Text contents:\n{doc_text_representation}\n" - - return doc_text + assert EtCetera not in fields[:-1], "EtCetera must be at the end of the list of fields if provided" + self.fields = fields + self.prompt = OneStepSummarizerPrompt.set(**self.get_const_vars()) + + @staticmethod + def get_const_vars() -> dict[str, str]: + return { + "fields_key": "_fields", + "numel_key": "_num_elements", + } + + def preprocess(self, doc: Document) -> Document: + vars = self.get_const_vars() + fields = copy.deepcopy(self.fields) + etc = False + if len(fields) > 0 and fields[-1] is EtCetera: + etc = True + fields = fields[:-1] + all_element_property_names = {f"properties.{k}" for e in doc.elements for k in e.properties} + doc.properties[vars["fields_key"]] = fields + doc.properties[vars["numel_key"]] = 0 + last = self.prompt.render_document(doc) + if len(fields) == 0 or etc: + for p in all_element_property_names: + if p in fields: + continue + fields.append(p) + last = self.prompt.render_document(doc) + ntk = last.token_count(self.tokenizer) + if ntk > 
self.token_limit: + fields.pop() + return doc + doc.properties[vars["numel_key"]] += 1 + this = self.prompt.render_document(doc) + while last != this: + ntk = this.token_count(self.tokenizer) + if ntk > self.token_limit: + doc.properties[vars["numel_key"]] -= 1 + return doc + last = this + doc.properties[vars["numel_key"]] += 1 + this = self.prompt.render_document(doc) + return doc + + def cleanup(self, doc: Document) -> Document: + vars = self.get_const_vars() + if vars["fields_key"] in doc.properties: + del doc.properties[vars["fields_key"]] + if vars["numel_key"] in doc.properties: + del doc.properties[vars["numel_key"]] + return doc + + def as_llm_map(self, child: Optional[Node], **kwargs): + prompt = self.prompt + if self.question is not None: + prompt = prompt.set(question=self.question) + preprocess = Map(child, f=self.preprocess) + llm_map = LLMMap(preprocess, prompt=prompt, output_field="summary", llm=self.llm, **kwargs) + postprocess = Map(llm_map, f=self.cleanup) + comptransform = CompositeTransform(child, []) # type: ignore + comptransform.nodes = [preprocess, llm_map, postprocess] + return comptransform class Summarize(NonCPUUser, NonGPUUser, Map):
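
For reference, a minimal usage sketch of the summarizer API introduced in this patch (MultiStepDocumentSummarizer, OneStepDocumentSummarizer, and the reworked summarize_data). The model choice, file path, question, and field names are illustrative placeholders, and the calls assume the constructor signatures shown in the diff above rather than documenting a finalized API.

import sycamore
from sycamore.functions.tokenizer import OpenAITokenizer
from sycamore.llms import OpenAI, OpenAIModels
from sycamore.query.execution.operations import summarize_data
from sycamore.transforms.partition import UnstructuredPdfPartitioner
from sycamore.transforms.summarize import (
    EtCetera,
    MultiStepDocumentSummarizer,
    OneStepDocumentSummarizer,
)

llm = OpenAI(OpenAIModels.GPT_4O)  # placeholder model choice; any sycamore LLM should work
tokenizer = OpenAITokenizer("gpt-4o")

context = sycamore.init()
docset = context.read.binary("path/to/report.pdf", binary_format="pdf").partition(  # placeholder path
    partitioner=UnstructuredPdfPartitioner()
)

# Multi-step: batches elements under max_tokens per LLM call, then merges the partial summaries.
multi = MultiStepDocumentSummarizer(
    llm=llm,
    question="What is this document about?",
    fields="*",
    tokenizer=tokenizer,
    max_tokens=80_000,
)
summarized = docset.summarize(summarizer=multi).take_all()
print(summarized[0].properties["summary"])

# One-step: a single LLM call per document, spreading the token budget evenly across elements.
one = OneStepDocumentSummarizer(
    llm=llm,
    question="What is this document about?",
    fields=["properties.title", EtCetera],
    token_limit=80_000,
)

# summarize_data produces the docset-level answer used by the query execution layer.
answer = summarize_data(
    llm=llm,
    question="What is this document about?",
    result_description="A sample PDF corpus",
    result_data=[docset],
    docset_summarizer=one,
)
print(answer)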