Skip to content

Commit bd77120

Browse files
authored
Fix DocumentSplitter not splitting by function (#8549)
* Fix DocumentSplitter not splitting by function
* Make the split_by mapping a constant
1 parent cea1e3f commit bd77120

File tree

4 files changed

+67
-51
lines changed

4 files changed

+67
-51
lines changed

haystack/components/preprocessors/document_splitter.py

+37-37
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313

1414
logger = logging.getLogger(__name__)
1515

16+
# Maps the 'split_by' argument to the actual char used to split the Documents.
17+
# 'function' is not in the mapping because it doesn't split on chars.
18+
_SPLIT_BY_MAPPING = {"page": "\f", "passage": "\n\n", "sentence": ".", "word": " ", "line": "\n"}
19+
1620

1721
@component
1822
class DocumentSplitter:
@@ -73,7 +77,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
7377

7478
self.split_by = split_by
7579
if split_by not in ["function", "page", "passage", "sentence", "word", "line"]:
76-
raise ValueError("split_by must be one of 'word', 'sentence', 'page', 'passage' or 'line'.")
80+
raise ValueError("split_by must be one of 'function', 'word', 'sentence', 'page', 'passage' or 'line'.")
7781
if split_by == "function" and splitting_function is None:
7882
raise ValueError("When 'split_by' is set to 'function', a valid 'splitting_function' must be provided.")
7983
if split_length <= 0:
@@ -108,7 +112,7 @@ def run(self, documents: List[Document]):
108112
if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
109113
raise TypeError("DocumentSplitter expects a List of Documents as input.")
110114

111-
split_docs = []
115+
split_docs: List[Document] = []
112116
for doc in documents:
113117
if doc.content is None:
114118
raise ValueError(
@@ -117,42 +121,38 @@ def run(self, documents: List[Document]):
117121
if doc.content == "":
118122
logger.warning("Document ID {doc_id} has an empty content. Skipping this document.", doc_id=doc.id)
119123
continue
120-
units = self._split_into_units(doc.content, self.split_by)
121-
text_splits, splits_pages, splits_start_idxs = self._concatenate_units(
122-
units, self.split_length, self.split_overlap, self.split_threshold
123-
)
124-
metadata = deepcopy(doc.meta)
125-
metadata["source_id"] = doc.id
126-
split_docs += self._create_docs_from_splits(
127-
text_splits=text_splits, splits_pages=splits_pages, splits_start_idxs=splits_start_idxs, meta=metadata
128-
)
124+
split_docs += self._split(doc)
129125
return {"documents": split_docs}
130126

131-
def _split_into_units(
132-
self, text: str, split_by: Literal["function", "page", "passage", "sentence", "word", "line"]
133-
) -> List[str]:
134-
if split_by == "page":
135-
self.split_at = "\f"
136-
elif split_by == "passage":
137-
self.split_at = "\n\n"
138-
elif split_by == "sentence":
139-
self.split_at = "."
140-
elif split_by == "word":
141-
self.split_at = " "
142-
elif split_by == "line":
143-
self.split_at = "\n"
144-
elif split_by == "function" and self.splitting_function is not None:
145-
return self.splitting_function(text)
146-
else:
147-
raise NotImplementedError(
148-
"""DocumentSplitter only supports 'function', 'line', 'page',
149-
'passage', 'sentence' or 'word' split_by options."""
150-
)
151-
units = text.split(self.split_at)
127+
def _split(self, to_split: Document) -> List[Document]:
128+
# We already check this before calling _split but
129+
# we need to make linters happy
130+
if to_split.content is None:
131+
return []
132+
133+
if self.split_by == "function" and self.splitting_function is not None:
134+
splits = self.splitting_function(to_split.content)
135+
docs: List[Document] = []
136+
for s in splits:
137+
meta = deepcopy(to_split.meta)
138+
meta["source_id"] = to_split.id
139+
docs.append(Document(content=s, meta=meta))
140+
return docs
141+
142+
split_at = _SPLIT_BY_MAPPING[self.split_by]
143+
units = to_split.content.split(split_at)
152144
# Add the delimiter back to all units except the last one
153145
for i in range(len(units) - 1):
154-
units[i] += self.split_at
155-
return units
146+
units[i] += split_at
147+
148+
text_splits, splits_pages, splits_start_idxs = self._concatenate_units(
149+
units, self.split_length, self.split_overlap, self.split_threshold
150+
)
151+
metadata = deepcopy(to_split.meta)
152+
metadata["source_id"] = to_split.id
153+
return self._create_docs_from_splits(
154+
text_splits=text_splits, splits_pages=splits_pages, splits_start_idxs=splits_start_idxs, meta=metadata
155+
)
156156

157157
def _concatenate_units(
158158
self, elements: List[str], split_length: int, split_overlap: int, split_threshold: int
@@ -166,8 +166,8 @@ def _concatenate_units(
166166
"""
167167

168168
text_splits: List[str] = []
169-
splits_pages = []
170-
splits_start_idxs = []
169+
splits_pages: List[int] = []
170+
splits_start_idxs: List[int] = []
171171
cur_start_idx = 0
172172
cur_page = 1
173173
segments = windowed(elements, n=split_length, step=split_length - split_overlap)
@@ -200,7 +200,7 @@ def _concatenate_units(
200200
return text_splits, splits_pages, splits_start_idxs
201201

202202
def _create_docs_from_splits(
203-
self, text_splits: List[str], splits_pages: List[int], splits_start_idxs: List[int], meta: Dict
203+
self, text_splits: List[str], splits_pages: List[int], splits_start_idxs: List[int], meta: Dict[str, Any]
204204
) -> List[Document]:
205205
"""
206206
Creates Document objects from splits enriching them with page number and the metadata of the original document.

haystack/components/preprocessors/nltk_document_splitter.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
8787
self.language = language
8888

8989
def _split_into_units(
90-
self, text: str, split_by: Literal["word", "sentence", "passage", "page", "function"]
90+
self, text: str, split_by: Literal["function", "page", "passage", "sentence", "word", "line"]
9191
) -> List[str]:
9292
"""
9393
Splits the text into units based on the specified split_by parameter.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
fixes:
3+
- |
4+
Fix `DocumentSplitter` to handle custom `splitting_function` without requiring `split_length`.
5+
Previously, the provided `splitting_function` did not override the other settings.

test/components/preprocessors/test_document_splitter.py

+24-13
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def test_empty_list(self):
5757

5858
def test_unsupported_split_by(self):
5959
with pytest.raises(
60-
ValueError, match="split_by must be one of 'word', 'sentence', 'page', 'passage' or 'line'."
60+
ValueError, match="split_by must be one of 'function', 'word', 'sentence', 'page', 'passage' or 'line'."
6161
):
6262
DocumentSplitter(split_by="unsupported")
6363

@@ -177,25 +177,36 @@ def test_split_by_page(self):
177177
assert docs[2].meta["page_number"] == 3
178178

179179
def test_split_by_function(self):
180-
splitting_function = lambda input_str: input_str.split(".")
181-
splitter = DocumentSplitter(split_by="function", splitting_function=splitting_function, split_length=1)
180+
splitting_function = lambda s: s.split(".")
181+
splitter = DocumentSplitter(split_by="function", splitting_function=splitting_function)
182182
text = "This.Is.A.Test"
183-
result = splitter.run(documents=[Document(content=text)])
183+
result = splitter.run(documents=[Document(id="1", content=text, meta={"key": "value"})])
184184
docs = result["documents"]
185185

186-
word_list = ["This", "Is", "A", "Test"]
187186
assert len(docs) == 4
188-
for w_target, w_split in zip(word_list, docs):
189-
assert w_split.content == w_target
190-
191-
splitting_function = lambda input_str: re.split("[\s]{2,}", input_str)
192-
splitter = DocumentSplitter(split_by="function", splitting_function=splitting_function, split_length=1)
187+
assert docs[0].content == "This"
188+
assert docs[0].meta == {"key": "value", "source_id": "1"}
189+
assert docs[1].content == "Is"
190+
assert docs[1].meta == {"key": "value", "source_id": "1"}
191+
assert docs[2].content == "A"
192+
assert docs[2].meta == {"key": "value", "source_id": "1"}
193+
assert docs[3].content == "Test"
194+
assert docs[3].meta == {"key": "value", "source_id": "1"}
195+
196+
splitting_function = lambda s: re.split(r"[\s]{2,}", s)
197+
splitter = DocumentSplitter(split_by="function", splitting_function=splitting_function)
193198
text = "This Is\n A Test"
194-
result = splitter.run(documents=[Document(content=text)])
199+
result = splitter.run(documents=[Document(id="1", content=text, meta={"key": "value"})])
195200
docs = result["documents"]
196201
assert len(docs) == 4
197-
for w_target, w_split in zip(word_list, docs):
198-
assert w_split.content == w_target
202+
assert docs[0].content == "This"
203+
assert docs[0].meta == {"key": "value", "source_id": "1"}
204+
assert docs[1].content == "Is"
205+
assert docs[1].meta == {"key": "value", "source_id": "1"}
206+
assert docs[2].content == "A"
207+
assert docs[2].meta == {"key": "value", "source_id": "1"}
208+
assert docs[3].content == "Test"
209+
assert docs[3].meta == {"key": "value", "source_id": "1"}
199210

200211
def test_split_by_word_with_overlap(self):
201212
splitter = DocumentSplitter(split_by="word", split_length=10, split_overlap=2)

0 commit comments

Comments
 (0)