Skip to content

Commit

Permalink
mypy + tests
Browse files Browse the repository at this point in the history
Signed-off-by: Henry Lindeman <[email protected]>
  • Loading branch information
HenryL27 committed Feb 27, 2025
1 parent 5b267a7 commit d156b03
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 8 deletions.
9 changes: 5 additions & 4 deletions lib/sycamore/sycamore/query/execution/operations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import math
from typing import Any, List, Union, Optional
from typing import Any, List, Union, Optional, Type

import structlog

Expand All @@ -21,14 +21,15 @@

log = structlog.get_logger(__name__)
# multistep
DEFAULT_DOCSET_SUMMARIZER_CLS = MultiStepDocumentSummarizer
DEFAULT_DOCSET_SUMMARIZER_CLS = MultiStepDocumentSummarizer # type: ignore

DEFAULT_SUMMARIZER_KWARGS: dict[str, Any] = {
"fields": "*",
"tokenizer": OpenAITokenizer("gpt-4o"),
"max_tokens": 80_000,
}
# onestep
DEFAULT_DOCSET_SUMMARIZER_CLS = OneStepDocumentSummarizer
DEFAULT_DOCSET_SUMMARIZER_CLS = OneStepDocumentSummarizer # type: ignore
DEFAULT_SUMMARIZER_KWARGS = {"fields": [EtCetera], "tokenizer": OpenAITokenizer("gpt-4o"), "token_limit": 80_000}


Expand Down Expand Up @@ -92,7 +93,7 @@ def summarize_data(
Conversational response to question.
"""
if docset_summarizer is None:
docset_summarizer = DEFAULT_DOCSET_SUMMARIZER_CLS(llm=llm, question=question, **DEFAULT_SUMMARIZER_KWARGS)
docset_summarizer = DEFAULT_DOCSET_SUMMARIZER_CLS(llm=llm, question=question, **DEFAULT_SUMMARIZER_KWARGS) # type: ignore

if all(isinstance(d, DocSet) for d in result_data):
return summarize_data_docsets(
Expand Down
15 changes: 12 additions & 3 deletions lib/sycamore/sycamore/tests/unit/query/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
summarize_data,
math_operation,
)
from sycamore.transforms.summarize import NUM_DOCS_GENERATE
from sycamore.transforms.summarize import NUM_DOCS_GENERATE, MultiStepDocumentSummarizer


class MockLLM(LLM):
Expand Down Expand Up @@ -143,7 +143,13 @@ def test_summarize_data(self, words_and_ids_docset):
def test_get_text_for_summarize_data_docset(self, words_and_ids_docset):
llm = MockLLM()
summarize_data(
llm=llm, question=None, result_description="List of unique cities", result_data=[words_and_ids_docset]
llm=llm,
question=None,
result_description="List of unique cities",
result_data=[words_and_ids_docset],
docset_summarizer=MultiStepDocumentSummarizer(
llm=llm, question=None, data_description="List of unique cities"
),
)
captured = llm.capture[-1]
mcontent = captured.messages[-1].content
Expand All @@ -160,9 +166,12 @@ def test_get_text_for_summarize_data_docset_with_elements(self, big_words_and_id
result_description="List of unique cities",
result_data=[big_words_and_ids_docset],
summaries_as_text=True,
docset_summarizer=MultiStepDocumentSummarizer(
llm=llm, question=None, data_description="List of unique cities", max_tokens=1000
),
)
captured = llm.capture
assert len(captured) == 45 # 45 llm calls
assert len(captured) == 44 # 44 llm calls
assert response == "merged summary"

def test_get_text_for_summarize_data_non_docset(self, words_and_ids_docset):
Expand Down
5 changes: 4 additions & 1 deletion lib/sycamore/sycamore/transforms/summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def __init__(
self,
llm: LLM,
question: Optional[str] = None,
data_description: Optional[str] = None,
prompt: SycamorePrompt = MaxTokensHeirarchyPrompt,
fields: Union[None, Literal["*"], list[str]] = None,
max_tokens: int = 10 * 1000,
Expand All @@ -194,6 +195,7 @@ def __init__(
self.prompt = prompt.set(**self.get_const_vars())
self.fields = fields
self.question = question
self.data_description = data_description
self.max_tokens = max_tokens
self.tokenizer = tokenizer
self.rounds = 4
Expand Down Expand Up @@ -246,6 +248,8 @@ def as_llm_map(self, child: Optional[Node], **kwargs) -> Node:
self.prompt = self.prompt.set(fields=self.fields)
if self.question is not None:
self.prompt = self.prompt.set(question=self.question)
if self.data_description is not None:
self.prompt = self.prompt.set(data_description=self.data_description)
nodes = []
last = child
for round in range(self.rounds):
Expand Down Expand Up @@ -361,7 +365,6 @@ def preprocess(self, doc: Document) -> Document:
this = self.prompt.render_document(doc)
while last != this:
ntk = this.token_count(self.tokenizer)
print(ntk)
if ntk > self.token_limit:
doc.properties[vars["numel_key"]] -= 1
return doc
Expand Down

0 comments on commit d156b03

Please sign in to comment.