5
5
import re
6
6
from copy import deepcopy
7
7
from pathlib import Path
8
- from typing import Any , Dict , List , Literal , Tuple
8
+ from typing import Any , Callable , Dict , List , Literal , Optional , Tuple
9
9
10
10
from haystack import Document , component , logging
11
11
from haystack .components .preprocessors .document_splitter import DocumentSplitter
12
+ from haystack .core .serialization import default_to_dict
12
13
from haystack .lazy_imports import LazyImport
14
+ from haystack .utils import serialize_callable
13
15
14
16
with LazyImport ("Run 'pip install nltk'" ) as nltk_imports :
15
17
import nltk
23
25
24
26
@component
25
27
class NLTKDocumentSplitter (DocumentSplitter ):
26
- def __init__ (
28
+ def __init__ ( # pylint: disable=too-many-positional-arguments
27
29
self ,
28
30
split_by : Literal ["word" , "sentence" , "page" , "passage" , "function" ] = "word" ,
29
31
split_length : int = 200 ,
@@ -33,6 +35,7 @@ def __init__(
33
35
language : Language = "en" ,
34
36
use_split_rules : bool = True ,
35
37
extend_abbreviations : bool = True ,
38
+ splitting_function : Optional [Callable [[str ], List [str ]]] = None ,
36
39
):
37
40
"""
38
41
Splits your documents using NLTK to respect sentence boundaries.
@@ -53,10 +56,17 @@ def __init__(
53
56
:param extend_abbreviations: Choose whether to extend NLTK's PunktTokenizer abbreviations with a list
54
57
of curated abbreviations, if available.
55
58
This is currently supported for English ("en") and German ("de").
59
+ :param splitting_function: Necessary when `split_by` is set to "function".
60
+ This is a function which must accept a single `str` as input and return a `list` of `str` as output,
61
+ representing the chunks after splitting.
56
62
"""
57
63
58
64
super (NLTKDocumentSplitter , self ).__init__ (
59
- split_by = split_by , split_length = split_length , split_overlap = split_overlap , split_threshold = split_threshold
65
+ split_by = split_by ,
66
+ split_length = split_length ,
67
+ split_overlap = split_overlap ,
68
+ split_threshold = split_threshold ,
69
+ splitting_function = splitting_function ,
60
70
)
61
71
nltk_imports .check ()
62
72
if respect_sentence_boundary and split_by != "word" :
@@ -66,6 +76,8 @@ def __init__(
66
76
)
67
77
respect_sentence_boundary = False
68
78
self .respect_sentence_boundary = respect_sentence_boundary
79
+ self .use_split_rules = use_split_rules
80
+ self .extend_abbreviations = extend_abbreviations
69
81
self .sentence_splitter = SentenceSplitter (
70
82
language = language ,
71
83
use_split_rules = use_split_rules ,
@@ -100,9 +112,11 @@ def _split_into_units(
100
112
elif split_by == "word" :
101
113
self .split_at = " "
102
114
units = text .split (self .split_at )
115
+ elif split_by == "function" and self .splitting_function is not None :
116
+ return self .splitting_function (text )
103
117
else :
104
118
raise NotImplementedError (
105
- "DocumentSplitter only supports 'word ', 'sentence ', 'page' or 'passage ' split_by options."
119
+ "DocumentSplitter only supports 'function ', 'page ', 'passage', 'sentence' or 'word ' split_by options."
106
120
)
107
121
108
122
# Add the delimiter back to all units except the last one
@@ -138,6 +152,9 @@ def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
138
152
raise ValueError (
139
153
f"DocumentSplitter only works with text documents but content for document ID { doc .id } is None."
140
154
)
155
+ if doc .content == "" :
156
+ logger .warning ("Document ID {doc_id} has an empty content. Skipping this document." , doc_id = doc .id )
157
+ continue
141
158
142
159
if self .respect_sentence_boundary :
143
160
units = self ._split_into_units (doc .content , "sentence" )
@@ -159,6 +176,25 @@ def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
159
176
)
160
177
return {"documents" : split_docs }
161
178
179
+ def to_dict (self ) -> Dict [str , Any ]:
180
+ """
181
+ Serializes the component to a dictionary.
182
+ """
183
+ serialized = default_to_dict (
184
+ self ,
185
+ split_by = self .split_by ,
186
+ split_length = self .split_length ,
187
+ split_overlap = self .split_overlap ,
188
+ split_threshold = self .split_threshold ,
189
+ respect_sentence_boundary = self .respect_sentence_boundary ,
190
+ language = self .language ,
191
+ use_split_rules = self .use_split_rules ,
192
+ extend_abbreviations = self .extend_abbreviations ,
193
+ )
194
+ if self .splitting_function :
195
+ serialized ["init_parameters" ]["splitting_function" ] = serialize_callable (self .splitting_function )
196
+ return serialized
197
+
162
198
@staticmethod
163
199
def _number_of_sentences_to_keep (sentences : List [str ], split_length : int , split_overlap : int ) -> int :
164
200
"""
@@ -175,7 +211,8 @@ def _number_of_sentences_to_keep(sentences: List[str], split_length: int, split_
175
211
176
212
num_sentences_to_keep = 0
177
213
num_words = 0
178
- for sent in reversed (sentences ):
214
+ # Next overlapping Document should not start exactly the same as the previous one, so we skip the first sentence
215
+ for sent in reversed (sentences [1 :]):
179
216
num_words += len (sent .split ())
180
217
# If the number of words is larger than the split_length then don't add any more sentences
181
218
if num_words > split_length :
0 commit comments