
Commit 3f77d3a

feat: unify NLTKDocumentSplitter and DocumentSplitter (#8617)
* wip: initial import
* wip: refactoring
* wip: refactoring tests
* wip: refactoring tests
* making all NLTKSplitter related tests work
* refactoring
* docstrings
* refactoring and removing NLTKDocumentSplitter
* fixing tests for custom sentence tokenizer
* fixing tests for custom sentence tokenizer
* cleaning up
* adding release notes
* reverting some changes
* cleaning up tests
* fixing serialisation and adding tests
* cleaning up
* wip
* renaming and cleaning
* adding NLTK files
* updating docstring
* adding import to init
* Update haystack/components/preprocessors/document_splitter.py
  Co-authored-by: Stefano Fiorucci <[email protected]>
* updating tests
* wip
* adding sentence/period change warning
* fixing LICENSE header
* Update haystack/components/preprocessors/document_splitter.py
  Co-authored-by: Stefano Fiorucci <[email protected]>

---------

Co-authored-by: Stefano Fiorucci <[email protected]>
1 parent 6cceaac commit 3f77d3a
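
For context, a minimal usage sketch of the unified component after this change. This is not part of the diff; it assumes a Haystack 2.x install with nltk available, since both `split_by="sentence"` and `respect_sentence_boundary=True` trigger `nltk_imports.check()` in the new `__init__`.

from haystack import Document
from haystack.components.preprocessors import DocumentSplitter

# New behaviour: split_by="sentence" uses the NLTK-based SentenceSplitter
sentence_splitter = DocumentSplitter(split_by="sentence", split_length=3, split_overlap=1)
result = sentence_splitter.run(
    documents=[Document(content="Moonlight shimmered. An owl hooted. Crickets chirped nearby.")]
)
print(len(result["documents"]))

# Word-based splitting that never cuts through a sentence
word_splitter = DocumentSplitter(
    split_by="word",
    split_length=200,
    split_overlap=20,
    respect_sentence_boundary=True,
    language="en",
)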


6 files changed: +635 −53 lines changed


Diff for: haystack/components/preprocessors/document_splitter.py

+239 −35
@@ -2,29 +2,29 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+import warnings
 from copy import deepcopy
 from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
 
 from more_itertools import windowed
 
 from haystack import Document, component, logging
+from haystack.components.preprocessors.sentence_tokenizer import Language, SentenceSplitter, nltk_imports
 from haystack.core.serialization import default_from_dict, default_to_dict
 from haystack.utils import deserialize_callable, serialize_callable
 
 logger = logging.getLogger(__name__)
 
-# Maps the 'split_by' argument to the actual char used to split the Documents.
-# 'function' is not in the mapping cause it doesn't split on chars.
-_SPLIT_BY_MAPPING = {"page": "\f", "passage": "\n\n", "sentence": ".", "word": " ", "line": "\n"}
+# mapping of split by character, 'function' and 'sentence' don't split by character
+_CHARACTER_SPLIT_BY_MAPPING = {"page": "\f", "passage": "\n\n", "period": ".", "word": " ", "line": "\n"}
 
 
 @component
 class DocumentSplitter:
     """
     Splits long documents into smaller chunks.
 
-    This is a common preprocessing step during indexing.
-    It helps Embedders create meaningful semantic representations
+    This is a common preprocessing step during indexing. It helps Embedders create meaningful semantic representations
     and prevents exceeding language model context limits.
 
     The DocumentSplitter is compatible with the following DocumentStores:
@@ -54,40 +54,115 @@ class DocumentSplitter:
 
     def __init__(  # pylint: disable=too-many-positional-arguments
         self,
-        split_by: Literal["function", "page", "passage", "sentence", "word", "line"] = "word",
+        split_by: Literal["function", "page", "passage", "period", "word", "line", "sentence"] = "word",
         split_length: int = 200,
         split_overlap: int = 0,
         split_threshold: int = 0,
         splitting_function: Optional[Callable[[str], List[str]]] = None,
+        respect_sentence_boundary: bool = False,
+        language: Language = "en",
+        use_split_rules: bool = True,
+        extend_abbreviations: bool = True,
     ):
         """
         Initialize DocumentSplitter.
 
-        :param split_by: The unit for splitting your documents. Choose from `word` for splitting by spaces (" "),
-            `sentence` for splitting by periods ("."), `page` for splitting by form feed ("\\f"),
-            `passage` for splitting by double line breaks ("\\n\\n") or `line` for splitting each line ("\\n").
+        :param split_by: The unit for splitting your documents. Choose from:
+            - `word` for splitting by spaces (" ")
+            - `period` for splitting by periods (".")
+            - `page` for splitting by form feed ("\\f")
+            - `passage` for splitting by double line breaks ("\\n\\n")
+            - `line` for splitting each line ("\\n")
+            - `sentence` for splitting by NLTK sentence tokenizer
+
         :param split_length: The maximum number of units in each split.
         :param split_overlap: The number of overlapping units for each split.
         :param split_threshold: The minimum number of units per split. If a split has fewer units
            than the threshold, it's attached to the previous split.
         :param splitting_function: Necessary when `split_by` is set to "function".
            This is a function which must accept a single `str` as input and return a `list` of `str` as output,
            representing the chunks after splitting.
+        :param respect_sentence_boundary: Choose whether to respect sentence boundaries when splitting by "word".
+            If True, uses NLTK to detect sentence boundaries, ensuring splits occur only between sentences.
+        :param language: Choose the language for the NLTK tokenizer. The default is English ("en").
+        :param use_split_rules: Choose whether to use additional split rules when splitting by `sentence`.
+        :param extend_abbreviations: Choose whether to extend NLTK's PunktTokenizer abbreviations with a list
+            of curated abbreviations, if available. This is currently supported for English ("en") and German ("de").
         """
 
         self.split_by = split_by
-        if split_by not in ["function", "page", "passage", "sentence", "word", "line"]:
-            raise ValueError("split_by must be one of 'function', 'word', 'sentence', 'page', 'passage' or 'line'.")
+        self.split_length = split_length
+        self.split_overlap = split_overlap
+        self.split_threshold = split_threshold
+        self.splitting_function = splitting_function
+        self.respect_sentence_boundary = respect_sentence_boundary
+        self.language = language
+        self.use_split_rules = use_split_rules
+        self.extend_abbreviations = extend_abbreviations
+
+        self._init_checks(
+            split_by=split_by,
+            split_length=split_length,
+            split_overlap=split_overlap,
+            splitting_function=splitting_function,
+            respect_sentence_boundary=respect_sentence_boundary,
+        )
+
+        if split_by == "sentence" or (respect_sentence_boundary and split_by == "word"):
+            nltk_imports.check()
+            self.sentence_splitter = SentenceSplitter(
+                language=language,
+                use_split_rules=use_split_rules,
+                extend_abbreviations=extend_abbreviations,
+                keep_white_spaces=True,
+            )
+
+        if split_by == "sentence":
+            # ToDo: remove this warning in the next major release
+            msg = (
+                "The `split_by='sentence'` no longer splits by '.' and now relies on custom sentence tokenizer "
+                "based on NLTK. To achieve the previous behaviour use `split_by='period'."
+            )
+            warnings.warn(msg)
+
+    def _init_checks(
+        self,
+        *,
+        split_by: str,
+        split_length: int,
+        split_overlap: int,
+        splitting_function: Optional[Callable],
+        respect_sentence_boundary: bool,
+    ) -> None:
+        """
+        Validates initialization parameters for DocumentSplitter.
+
+        :param split_by: The unit for splitting documents
+        :param split_length: The maximum number of units in each split
+        :param split_overlap: The number of overlapping units for each split
+        :param splitting_function: Custom function for splitting when split_by="function"
+        :param respect_sentence_boundary: Whether to respect sentence boundaries when splitting
+        :raises ValueError: If any parameter is invalid
+        """
+        valid_split_by = ["function", "page", "passage", "period", "word", "line", "sentence"]
+        if split_by not in valid_split_by:
+            raise ValueError(f"split_by must be one of {', '.join(valid_split_by)}.")
+
         if split_by == "function" and splitting_function is None:
             raise ValueError("When 'split_by' is set to 'function', a valid 'splitting_function' must be provided.")
+
         if split_length <= 0:
             raise ValueError("split_length must be greater than 0.")
-        self.split_length = split_length
+
         if split_overlap < 0:
             raise ValueError("split_overlap must be greater than or equal to 0.")
-        self.split_overlap = split_overlap
-        self.split_threshold = split_threshold
-        self.splitting_function = splitting_function
+
+        if respect_sentence_boundary and split_by != "word":
+            logger.warning(
+                "The 'respect_sentence_boundary' option is only supported for `split_by='word'`. "
+                "The option `respect_sentence_boundary` will be set to `False`."
+            )
+            self.respect_sentence_boundary = False
 
     @component.output_types(documents=List[Document])
     def run(self, documents: List[Document]):
@@ -98,7 +173,6 @@ def run(self, documents: List[Document]):
         and an overlap of `split_overlap`.
 
         :param documents: The documents to split.
-
         :returns: A dictionary with the following key:
             - `documents`: List of documents with the split texts. Each document includes:
                 - A metadata field `source_id` to track the original document.
@@ -121,39 +195,69 @@ def run(self, documents: List[Document]):
             if doc.content == "":
                 logger.warning("Document ID {doc_id} has an empty content. Skipping this document.", doc_id=doc.id)
                 continue
-            split_docs += self._split(doc)
+
+            split_docs += self._split_document(doc)
         return {"documents": split_docs}
 
-    def _split(self, to_split: Document) -> List[Document]:
-        # We already check this before calling _split but
-        # we need to make linters happy
-        if to_split.content is None:
-            return []
+    def _split_document(self, doc: Document) -> List[Document]:
+        if self.split_by == "sentence" or self.respect_sentence_boundary:
+            return self._split_by_nltk_sentence(doc)
 
         if self.split_by == "function" and self.splitting_function is not None:
-            splits = self.splitting_function(to_split.content)
-            docs: List[Document] = []
-            for s in splits:
-                meta = deepcopy(to_split.meta)
-                meta["source_id"] = to_split.id
-                docs.append(Document(content=s, meta=meta))
-            return docs
-
-        split_at = _SPLIT_BY_MAPPING[self.split_by]
-        units = to_split.content.split(split_at)
+            return self._split_by_function(doc)
+
+        return self._split_by_character(doc)
+
+    def _split_by_nltk_sentence(self, doc: Document) -> List[Document]:
+        split_docs = []
+
+        result = self.sentence_splitter.split_sentences(doc.content)  # type: ignore # None check is done in run()
+        units = [sentence["sentence"] for sentence in result]
+
+        if self.respect_sentence_boundary:
+            text_splits, splits_pages, splits_start_idxs = self._concatenate_sentences_based_on_word_amount(
+                sentences=units, split_length=self.split_length, split_overlap=self.split_overlap
+            )
+        else:
+            text_splits, splits_pages, splits_start_idxs = self._concatenate_units(
+                elements=units,
+                split_length=self.split_length,
+                split_overlap=self.split_overlap,
+                split_threshold=self.split_threshold,
+            )
+        metadata = deepcopy(doc.meta)
+        metadata["source_id"] = doc.id
+        split_docs += self._create_docs_from_splits(
+            text_splits=text_splits, splits_pages=splits_pages, splits_start_idxs=splits_start_idxs, meta=metadata
+        )
+
+        return split_docs
+
+    def _split_by_character(self, doc) -> List[Document]:
+        split_at = _CHARACTER_SPLIT_BY_MAPPING[self.split_by]
+        units = doc.content.split(split_at)
         # Add the delimiter back to all units except the last one
         for i in range(len(units) - 1):
             units[i] += split_at
-
         text_splits, splits_pages, splits_start_idxs = self._concatenate_units(
             units, self.split_length, self.split_overlap, self.split_threshold
         )
-        metadata = deepcopy(to_split.meta)
-        metadata["source_id"] = to_split.id
+        metadata = deepcopy(doc.meta)
+        metadata["source_id"] = doc.id
         return self._create_docs_from_splits(
             text_splits=text_splits, splits_pages=splits_pages, splits_start_idxs=splits_start_idxs, meta=metadata
         )
 
+    def _split_by_function(self, doc) -> List[Document]:
+        # the check for None is done already in the run method
+        splits = self.splitting_function(doc.content)  # type: ignore
+        docs: List[Document] = []
+        for s in splits:
+            meta = deepcopy(doc.meta)
+            meta["source_id"] = doc.id
+            docs.append(Document(content=s, meta=meta))
+        return docs
+
     def _concatenate_units(
         self, elements: List[str], split_length: int, split_overlap: int, split_threshold: int
     ) -> Tuple[List[str], List[int], List[int]]:
@@ -265,6 +369,10 @@ def to_dict(self) -> Dict[str, Any]:
             split_length=self.split_length,
             split_overlap=self.split_overlap,
             split_threshold=self.split_threshold,
+            respect_sentence_boundary=self.respect_sentence_boundary,
+            language=self.language,
+            use_split_rules=self.use_split_rules,
+            extend_abbreviations=self.extend_abbreviations,
         )
         if self.splitting_function:
             serialized["init_parameters"]["splitting_function"] = serialize_callable(self.splitting_function)
@@ -282,3 +390,99 @@ def from_dict(cls, data: Dict[str, Any]) -> "DocumentSplitter":
             init_params["splitting_function"] = deserialize_callable(splitting_function)
 
         return default_from_dict(cls, data)
+
+    @staticmethod
+    def _concatenate_sentences_based_on_word_amount(
+        sentences: List[str], split_length: int, split_overlap: int
+    ) -> Tuple[List[str], List[int], List[int]]:
+        """
+        Groups the sentences into chunks of `split_length` words while respecting sentence boundaries.
+
+        This function is only used when splitting by `word` and `respect_sentence_boundary` is set to `True`, i.e.:
+        with NLTK sentence tokenizer.
+
+        :param sentences: The list of sentences to split.
+        :param split_length: The maximum number of words in each split.
+        :param split_overlap: The number of overlapping words in each split.
+        :returns: A tuple containing the concatenated sentences, the start page numbers, and the start indices.
+        """
+        # chunk information
+        chunk_word_count = 0
+        chunk_starting_page_number = 1
+        chunk_start_idx = 0
+        current_chunk: List[str] = []
+        # output lists
+        split_start_page_numbers = []
+        list_of_splits: List[List[str]] = []
+        split_start_indices = []
+
+        for sentence_idx, sentence in enumerate(sentences):
+            current_chunk.append(sentence)
+            chunk_word_count += len(sentence.split())
+            next_sentence_word_count = (
+                len(sentences[sentence_idx + 1].split()) if sentence_idx < len(sentences) - 1 else 0
+            )
+
+            # Number of words in the current chunk plus the next sentence is larger than the split_length,
+            # or we reached the last sentence
+            if (chunk_word_count + next_sentence_word_count) > split_length or sentence_idx == len(sentences) - 1:
+                # Save current chunk and start a new one
+                list_of_splits.append(current_chunk)
+                split_start_page_numbers.append(chunk_starting_page_number)
+                split_start_indices.append(chunk_start_idx)
+
+                # Get the number of sentences that overlap with the next chunk
+                num_sentences_to_keep = DocumentSplitter._number_of_sentences_to_keep(
+                    sentences=current_chunk, split_length=split_length, split_overlap=split_overlap
+                )
+                # Set up information for the new chunk
+                if num_sentences_to_keep > 0:
+                    # Processed sentences are the ones that are not overlapping with the next chunk
+                    processed_sentences = current_chunk[:-num_sentences_to_keep]
+                    chunk_starting_page_number += sum(sent.count("\f") for sent in processed_sentences)
+                    chunk_start_idx += len("".join(processed_sentences))
+                    # Next chunk starts with the sentences that were overlapping with the previous chunk
+                    current_chunk = current_chunk[-num_sentences_to_keep:]
+                    chunk_word_count = sum(len(s.split()) for s in current_chunk)
+                else:
+                    # Here processed_sentences is the same as current_chunk since there is no overlap
+                    chunk_starting_page_number += sum(sent.count("\f") for sent in current_chunk)
+                    chunk_start_idx += len("".join(current_chunk))
+                    current_chunk = []
+                    chunk_word_count = 0
+
+        # Concatenate the sentences together within each split
+        text_splits = []
+        for split in list_of_splits:
+            text = "".join(split)
+            if len(text) > 0:
+                text_splits.append(text)
+
+        return text_splits, split_start_page_numbers, split_start_indices
+
+    @staticmethod
+    def _number_of_sentences_to_keep(sentences: List[str], split_length: int, split_overlap: int) -> int:
+        """
+        Returns the number of sentences to keep in the next chunk based on the `split_overlap` and `split_length`.
+
+        :param sentences: The list of sentences to split.
+        :param split_length: The maximum number of words in each split.
+        :param split_overlap: The number of overlapping words in each split.
+        :returns: The number of sentences to keep in the next chunk.
+        """
+        # If the split_overlap is 0, we don't need to keep any sentences
+        if split_overlap == 0:
+            return 0
+
+        num_sentences_to_keep = 0
+        num_words = 0
+        # Next overlapping Document should not start exactly the same as the previous one, so we skip the first sentence
+        for sent in reversed(sentences[1:]):
+            num_words += len(sent.split())
+            # If the number of words is larger than the split_length then don't add any more sentences
+            if num_words > split_length:
+                break
+            num_sentences_to_keep += 1
+            if num_words > split_overlap:
+                break
+        return num_sentences_to_keep
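
Migration note implied by the warning added in `__init__` above: `split_by="sentence"` now means NLTK sentence tokenization and emits a `UserWarning`, while the old period-based behaviour is available under the new `"period"` unit. An illustrative sketch, not part of the diff:

from haystack.components.preprocessors import DocumentSplitter

# Emits a warning and tokenizes with the NLTK-based SentenceSplitter
# (before this change, split_by="sentence" split on "." characters)
nltk_based = DocumentSplitter(split_by="sentence")

# Keeps the pre-#8617 behaviour of splitting on periods
period_based = DocumentSplitter(split_by="period")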
