Skip to content

Commit 2282c26

Browse files
feat!: SentenceWindowRetriever returns List[Document] with docs ordered by split_idx_start (#8590)
* initial import * adding a few pylint disable * adding tests * fixing integration tests * adding release notes * fixing types and docstrings
1 parent f0638b2 commit 2282c26

File tree

5 files changed

+48
-10
lines changed

5 files changed

+48
-10
lines changed

Diff for: haystack/components/retrievers/in_memory/bm25_retriever.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class InMemoryBM25Retriever:
3838
```
3939
"""
4040

41-
def __init__(
41+
def __init__( # pylint: disable=too-many-positional-arguments
4242
self,
4343
document_store: InMemoryDocumentStore,
4444
filters: Optional[Dict[str, Any]] = None,

Diff for: haystack/components/retrievers/in_memory/embedding_retriever.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class InMemoryEmbeddingRetriever:
5050
```
5151
"""
5252

53-
def __init__(
53+
def __init__( # pylint: disable=too-many-positional-arguments
5454
self,
5555
document_store: InMemoryDocumentStore,
5656
filters: Optional[Dict[str, Any]] = None,
@@ -143,7 +143,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "InMemoryEmbeddingRetriever":
143143
return default_from_dict(cls, data)
144144

145145
@component.output_types(documents=List[Document])
146-
def run(
146+
def run( # pylint: disable=too-many-positional-arguments
147147
self,
148148
query_embedding: List[float],
149149
filters: Optional[Dict[str, Any]] = None,

Diff for: haystack/components/retrievers/sentence_window_retriever.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "SentenceWindowRetriever":
151151
# deserialize the component
152152
return default_from_dict(cls, data)
153153

154-
@component.output_types(context_windows=List[str], context_documents=List[List[Document]])
154+
@component.output_types(context_windows=List[str], context_documents=List[Document])
155155
def run(self, retrieved_documents: List[Document], window_size: Optional[int] = None):
156156
"""
157157
Based on the `source_id` and on the `doc.meta['split_id']` get surrounding documents from the document store.
@@ -166,9 +166,9 @@ def run(self, retrieved_documents: List[Document], window_size: Optional[int] =
166166
A dictionary with the following keys:
167167
- `context_windows`: A list of strings, where each string represents the concatenated text from the
168168
context window of the corresponding document in `retrieved_documents`.
169-
- `context_documents`: A list of lists of `Document` objects, where each inner list contains the
170-
documents that come from the context window for the corresponding document in
171-
`retrieved_documents`.
169+
- `context_documents`: A list of `Document` objects, containing the retrieved documents plus the context
170+
documents surrounding them. The documents are sorted by the `split_idx_start`
171+
meta field.
172172
173173
"""
174174
window_size = window_size or self.window_size
@@ -200,6 +200,7 @@ def run(self, retrieved_documents: List[Document], window_size: Optional[int] =
200200
}
201201
)
202202
context_text.append(self.merge_documents_text(context_docs))
203-
context_documents.append(context_docs)
203+
context_docs_sorted = sorted(context_docs, key=lambda doc: doc.meta["split_idx_start"])
204+
context_documents.extend(context_docs_sorted)
204205

205206
return {"context_windows": context_text, "context_documents": context_documents}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
upgrade:
3+
- |
4+
The SentenceWindowRetriever output key `context_documents` is now a flat `List[Document]` (previously `List[List[Document]]`) containing the retrieved documents together with their surrounding context documents, ordered by `split_idx_start` within each context window.

Diff for: test/components/retrievers/test_sentence_window_retriever.py

+35-2
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,39 @@ def test_constructor_parameter_does_not_change(self):
141141
retriever.run(retrieved_documents=[Document.from_dict(doc)], window_size=1)
142142
assert retriever.window_size == 5
143143

144+
def test_context_documents_returned_are_ordered_by_split_idx_start(self):
145+
docs = []
146+
accumulated_length = 0
147+
for sent in range(10):
148+
content = f"Sentence {sent}."
149+
docs.append(
150+
Document(
151+
content=content,
152+
meta={
153+
"id": f"doc_{sent}",
154+
"split_idx_start": accumulated_length,
155+
"source_id": "source1",
156+
"split_id": sent,
157+
},
158+
)
159+
)
160+
accumulated_length += len(content)
161+
162+
import random
163+
164+
random.shuffle(docs)
165+
166+
doc_store = InMemoryDocumentStore()
167+
doc_store.write_documents(docs)
168+
retriever = SentenceWindowRetriever(document_store=doc_store, window_size=3)
169+
170+
# run the retriever with a document whose content = "Sentence 4."
171+
result = retriever.run(retrieved_documents=[doc for doc in docs if doc.content == "Sentence 4."])
172+
173+
# assert that the context documents are in the correct order
174+
assert len(result["context_documents"]) == 7
175+
assert [doc.meta["split_idx_start"] for doc in result["context_documents"]] == [11, 22, 33, 44, 55, 66, 77]
176+
144177
@pytest.mark.integration
145178
def test_run_with_pipeline(self):
146179
splitter = DocumentSplitter(split_length=1, split_overlap=0, split_by="sentence")
@@ -165,13 +198,13 @@ def test_run_with_pipeline(self):
165198
"This is a text with some words. There is a second sentence. And there is also a third sentence. "
166199
"It also contains a fourth sentence. And a fifth sentence."
167200
]
168-
assert len(result["sentence_window_retriever"]["context_documents"][0]) == 5
201+
assert len(result["sentence_window_retriever"]["context_documents"]) == 5
169202

170203
result = pipe.run({"bm25_retriever": {"query": "third"}, "sentence_window_retriever": {"window_size": 1}})
171204
assert result["sentence_window_retriever"]["context_windows"] == [
172205
" There is a second sentence. And there is also a third sentence. It also contains a fourth sentence."
173206
]
174-
assert len(result["sentence_window_retriever"]["context_documents"][0]) == 3
207+
assert len(result["sentence_window_retriever"]["context_documents"]) == 3
175208

176209
@pytest.mark.integration
177210
def test_serialization_deserialization_in_pipeline(self):

0 commit comments

Comments
 (0)