
Commit 0c11c7b

fix: Bring in fix from custom nodes (#8539)
* Bring in fix from custom nodes
* Add to_dict function and test
* reno
* Fix pylint
1 parent f5683bc commit 0c11c7b

File tree

3 files changed: +104 −16 lines changed

haystack/components/preprocessors/nltk_document_splitter.py

+42-5
@@ -5,11 +5,13 @@
 import re
 from copy import deepcopy
 from pathlib import Path
-from typing import Any, Dict, List, Literal, Tuple
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple

 from haystack import Document, component, logging
 from haystack.components.preprocessors.document_splitter import DocumentSplitter
+from haystack.core.serialization import default_to_dict
 from haystack.lazy_imports import LazyImport
+from haystack.utils import serialize_callable

 with LazyImport("Run 'pip install nltk'") as nltk_imports:
     import nltk
@@ -23,7 +25,7 @@

 @component
 class NLTKDocumentSplitter(DocumentSplitter):
-    def __init__(
+    def __init__(  # pylint: disable=too-many-positional-arguments
         self,
         split_by: Literal["word", "sentence", "page", "passage", "function"] = "word",
         split_length: int = 200,
@@ -33,6 +35,7 @@ def __init__(
         language: Language = "en",
         use_split_rules: bool = True,
         extend_abbreviations: bool = True,
+        splitting_function: Optional[Callable[[str], List[str]]] = None,
     ):
         """
         Splits your documents using NLTK to respect sentence boundaries.
@@ -53,10 +56,17 @@ def __init__(
         :param extend_abbreviations: Choose whether to extend NLTK's PunktTokenizer abbreviations with a list
             of curated abbreviations, if available.
             This is currently supported for English ("en") and German ("de").
+        :param splitting_function: Necessary when `split_by` is set to "function".
+            This is a function which must accept a single `str` as input and return a `list` of `str` as output,
+            representing the chunks after splitting.
         """

         super(NLTKDocumentSplitter, self).__init__(
-            split_by=split_by, split_length=split_length, split_overlap=split_overlap, split_threshold=split_threshold
+            split_by=split_by,
+            split_length=split_length,
+            split_overlap=split_overlap,
+            split_threshold=split_threshold,
+            splitting_function=splitting_function,
         )
         nltk_imports.check()
         if respect_sentence_boundary and split_by != "word":
@@ -66,6 +76,8 @@ def __init__(
             )
             respect_sentence_boundary = False
         self.respect_sentence_boundary = respect_sentence_boundary
+        self.use_split_rules = use_split_rules
+        self.extend_abbreviations = extend_abbreviations
         self.sentence_splitter = SentenceSplitter(
             language=language,
             use_split_rules=use_split_rules,
@@ -100,9 +112,11 @@ def _split_into_units(
         elif split_by == "word":
             self.split_at = " "
             units = text.split(self.split_at)
+        elif split_by == "function" and self.splitting_function is not None:
+            return self.splitting_function(text)
         else:
             raise NotImplementedError(
-                "DocumentSplitter only supports 'word', 'sentence', 'page' or 'passage' split_by options."
+                "DocumentSplitter only supports 'function', 'page', 'passage', 'sentence' or 'word' split_by options."
             )

         # Add the delimiter back to all units except the last one
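For illustration only (not part of the commit): a minimal sketch of driving the new "function" mode. It assumes nltk is installed; custom_split and the sample text are made up, and the function's output still flows through the inherited split_length/split_overlap chunking.

from typing import List

from haystack import Document
from haystack.components.preprocessors.nltk_document_splitter import NLTKDocumentSplitter


def custom_split(text: str) -> List[str]:
    # Illustrative splitting function: one unit per period-separated segment.
    return text.split(".")


splitter = NLTKDocumentSplitter(split_by="function", splitting_function=custom_split)
result = splitter.run(documents=[Document(content="First part. Second part. Third part.")])
print([doc.content for doc in result["documents"]])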
@@ -138,6 +152,9 @@ def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
                 raise ValueError(
                     f"DocumentSplitter only works with text documents but content for document ID {doc.id} is None."
                 )
+            if doc.content == "":
+                logger.warning("Document ID {doc_id} has an empty content. Skipping this document.", doc_id=doc.id)
+                continue

             if self.respect_sentence_boundary:
                 units = self._split_into_units(doc.content, "sentence")
@@ -159,6 +176,25 @@ def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
             )
         return {"documents": split_docs}

+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serializes the component to a dictionary.
+        """
+        serialized = default_to_dict(
+            self,
+            split_by=self.split_by,
+            split_length=self.split_length,
+            split_overlap=self.split_overlap,
+            split_threshold=self.split_threshold,
+            respect_sentence_boundary=self.respect_sentence_boundary,
+            language=self.language,
+            use_split_rules=self.use_split_rules,
+            extend_abbreviations=self.extend_abbreviations,
+        )
+        if self.splitting_function:
+            serialized["init_parameters"]["splitting_function"] = serialize_callable(self.splitting_function)
+        return serialized
+
     @staticmethod
     def _number_of_sentences_to_keep(sentences: List[str], split_length: int, split_overlap: int) -> int:
         """
@@ -175,7 +211,8 @@ def _number_of_sentences_to_keep(sentences: List[str], split_length: int, split_

         num_sentences_to_keep = 0
         num_words = 0
-        for sent in reversed(sentences):
+        # Next overlapping Document should not start exactly the same as the previous one, so we skip the first sentence
+        for sent in reversed(sentences[1:]):
             num_words += len(sent.split())
             # If the number of words is larger than the split_length then don't add any more sentences
             if num_words > split_length:
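To make the effect of the [1:] slice concrete, here is a simplified standalone sketch of the counting loop. The break conditions after the `if` shown above are assumed from the surrounding implementation and are not part of this hunk.

from typing import List


def sentences_to_keep(sentences: List[str], split_length: int, split_overlap: int) -> int:
    # Walk the previous chunk backwards, skipping its first sentence so the next
    # chunk can never start exactly where the previous one did.
    num_sentences_to_keep = 0
    num_words = 0
    for sent in reversed(sentences[1:]):
        num_words += len(sent.split())
        if num_words > split_length:  # the overlap must still fit into the next chunk
            break
        num_sentences_to_keep += 1
        if num_words > split_overlap:  # enough words to cover the requested overlap
            break
    return num_sentences_to_keep


# "The sun set." is never carried over, only "The moon was full." (4 words):
print(sentences_to_keep(["The sun set.", "The moon was full."], split_length=10, split_overlap=2))  # 1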
Release note (new file)

+10-0

@@ -0,0 +1,10 @@
+---
+fixes:
+  - |
+    For the NLTKDocumentSplitter we are updating how chunks are made when splitting by word with respect_sentence_boundary enabled.
+    Namely, to avoid fully subsuming the previous chunk into the next one, we ignore the first sentence from that chunk when calculating sentence overlap,
+    i.e. we want to avoid cases of Doc1 = [s1, s2], Doc2 = [s1, s2, s3].
+
+    Finished adding function support for this component by updating the _split_into_units function and adding the splitting_function init parameter.
+
+    Added a specific to_dict method to override the underlying one from DocumentSplitter. This is needed to properly save the settings of the component to YAML.
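As a rough end-to-end illustration of the overlap behaviour described above (sample text and parameters are made up; nltk must be installed):

from haystack import Document
from haystack.components.preprocessors.nltk_document_splitter import NLTKDocumentSplitter

splitter = NLTKDocumentSplitter(
    split_by="word",
    split_length=8,
    split_overlap=1,
    respect_sentence_boundary=True,
)
text = "Sentence one is here. Sentence two is here. Sentence three is here."
docs = splitter.run(documents=[Document(content=text)])["documents"]
for doc in docs:
    print(repr(doc.content))
# Each chunk may start with the last sentence of the previous chunk (the overlap),
# but no chunk is ever a strict superset of the previous one.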

test/components/preprocessors/test_nltk_document_splitter.py

+52-11
@@ -5,13 +5,18 @@
 from pytest import LogCaptureFixture

 from haystack.components.preprocessors.nltk_document_splitter import NLTKDocumentSplitter, SentenceSplitter
+from haystack.utils import deserialize_callable


 def test_init_warning_message(caplog: LogCaptureFixture) -> None:
     _ = NLTKDocumentSplitter(split_by="page", respect_sentence_boundary=True)
     assert "The 'respect_sentence_boundary' option is only supported for" in caplog.text


+def custom_split(text):
+    return text.split(".")
+
+
 class TestNLTKDocumentSplitterSplitIntoUnits:
     def test_document_splitter_split_into_units_word(self) -> None:
         document_splitter = NLTKDocumentSplitter(
@@ -87,9 +92,11 @@ class TestNLTKDocumentSplitterNumberOfSentencesToKeep:
     @pytest.mark.parametrize(
         "sentences, expected_num_sentences",
         [
-            (["Moonlight shimmered softly, wolves howled nearby, night enveloped everything."], 0),
-            ([" It was a dark night ..."], 0),
-            ([" The moon was full."], 1),
+            (["The sun set.", "Moonlight shimmered softly, wolves howled nearby, night enveloped everything."], 0),
+            (["The sun set.", "It was a dark night ..."], 0),
+            (["The sun set.", " The moon was full."], 1),
+            (["The sun.", " The moon."], 1),  # Ignores the first sentence
+            (["Sun", "Moon"], 1),  # Ignores the first sentence even if its inclusion would be < split_overlap
         ],
     )
     def test_number_of_sentences_to_keep(self, sentences: List[str], expected_num_sentences: int) -> None:
@@ -304,7 +311,7 @@ def test_run_split_by_word_respect_sentence_boundary_no_repeats(self) -> None:
     def test_run_split_by_word_respect_sentence_boundary_with_split_overlap_and_page_breaks(self) -> None:
         document_splitter = NLTKDocumentSplitter(
             split_by="word",
-            split_length=5,
+            split_length=8,
             split_overlap=1,
             split_threshold=0,
             language="en",
@@ -313,26 +320,60 @@ def test_run_split_by_word_respect_sentence_boundary_with_split_overlap_and_page
             respect_sentence_boundary=True,
         )

-        text = "Sentence on page 1.\fSentence on page 2. \fSentence on page 3. \f\f Sentence on page 5."
+        text = (
+            "Sentence on page 1. Another on page 1.\fSentence on page 2. Another on page 2.\f"
+            "Sentence on page 3. Another on page 3.\f\f Sentence on page 5."
+        )
         documents = document_splitter.run(documents=[Document(content=text)])["documents"]

-        assert len(documents) == 4
-        assert documents[0].content == "Sentence on page 1.\f"
+        assert len(documents) == 6
+        assert documents[0].content == "Sentence on page 1. Another on page 1.\f"
         assert documents[0].meta["page_number"] == 1
         assert documents[0].meta["split_id"] == 0
         assert documents[0].meta["split_idx_start"] == text.index(documents[0].content)
-        assert documents[1].content == "Sentence on page 1.\fSentence on page 2. \f"
+        assert documents[1].content == "Another on page 1.\fSentence on page 2. "
         assert documents[1].meta["page_number"] == 1
         assert documents[1].meta["split_id"] == 1
         assert documents[1].meta["split_idx_start"] == text.index(documents[1].content)
-        assert documents[2].content == "Sentence on page 2. \fSentence on page 3. \f\f "
+        assert documents[2].content == "Sentence on page 2. Another on page 2.\f"
         assert documents[2].meta["page_number"] == 2
         assert documents[2].meta["split_id"] == 2
         assert documents[2].meta["split_idx_start"] == text.index(documents[2].content)
-        assert documents[3].content == "Sentence on page 3. \f\f Sentence on page 5."
-        assert documents[3].meta["page_number"] == 3
+        assert documents[3].content == "Another on page 2.\fSentence on page 3. "
+        assert documents[3].meta["page_number"] == 2
         assert documents[3].meta["split_id"] == 3
         assert documents[3].meta["split_idx_start"] == text.index(documents[3].content)
+        assert documents[4].content == "Sentence on page 3. Another on page 3.\f\f "
+        assert documents[4].meta["page_number"] == 3
+        assert documents[4].meta["split_id"] == 4
+        assert documents[4].meta["split_idx_start"] == text.index(documents[4].content)
+        assert documents[5].content == "Another on page 3.\f\f Sentence on page 5."
+        assert documents[5].meta["page_number"] == 3
+        assert documents[5].meta["split_id"] == 5
+        assert documents[5].meta["split_idx_start"] == text.index(documents[5].content)
+
+    def test_to_dict(self):
+        splitter = NLTKDocumentSplitter(split_by="word", split_length=10, split_overlap=2, split_threshold=5)
+        serialized = splitter.to_dict()
+
+        assert serialized["type"] == "haystack.components.preprocessors.nltk_document_splitter.NLTKDocumentSplitter"
+        assert serialized["init_parameters"]["split_by"] == "word"
+        assert serialized["init_parameters"]["split_length"] == 10
+        assert serialized["init_parameters"]["split_overlap"] == 2
+        assert serialized["init_parameters"]["split_threshold"] == 5
+        assert serialized["init_parameters"]["language"] == "en"
+        assert serialized["init_parameters"]["use_split_rules"] is True
+        assert serialized["init_parameters"]["extend_abbreviations"] is True
+        assert "splitting_function" not in serialized["init_parameters"]
+
+    def test_to_dict_with_splitting_function(self):
+        splitter = NLTKDocumentSplitter(split_by="function", splitting_function=custom_split)
+        serialized = splitter.to_dict()
+
+        assert serialized["type"] == "haystack.components.preprocessors.nltk_document_splitter.NLTKDocumentSplitter"
+        assert serialized["init_parameters"]["split_by"] == "function"
+        assert "splitting_function" in serialized["init_parameters"]
+        assert callable(deserialize_callable(serialized["init_parameters"]["splitting_function"]))


 class TestSentenceSplitter:
