Skip to content

Commit 9446714

Browse files
authored
fix: fix serialization of DocumentRecallEvaluator (#7662)
* fix serialization of DocumentRecallEvaluator * add requested tests
1 parent f14bc53 commit 9446714

File tree

3 files changed

+46
-1
lines changed

3 files changed

+46
-1
lines changed

Diff for: haystack/components/evaluators/document_recall.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from enum import Enum
22
from typing import Any, Dict, List, Union
33

4-
from haystack.core.component import component
4+
from haystack import component, default_to_dict
55
from haystack.dataclasses import Document
66

77

@@ -74,6 +74,7 @@ def __init__(self, mode: Union[str, RecallMode] = RecallMode.SINGLE_HIT):
7474

7575
mode_functions = {RecallMode.SINGLE_HIT: self._recall_single_hit, RecallMode.MULTI_HIT: self._recall_multi_hit}
7676
self.mode_function = mode_functions[mode]
77+
self.mode = mode
7778

7879
def _recall_single_hit(self, ground_truth_documents: List[Document], retrieved_documents: List[Document]) -> float:
7980
unique_truths = {g.content for g in ground_truth_documents}
@@ -117,3 +118,12 @@ def run(
117118
scores.append(score)
118119

119120
return {"score": sum(scores) / len(retrieved_documents), "individual_scores": scores}
121+
122+
def to_dict(self) -> Dict[str, Any]:
123+
"""
124+
Serializes the component to a dictionary.
125+
126+
:returns:
127+
Dictionary with serialized data.
128+
"""
129+
return default_to_dict(self, mode=str(self.mode))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
fixes:
3+
- |
4+
Add `to_dict` method to `DocumentRecallEvaluator` to allow proper serialization of the component.

Diff for: test/components/evaluators/test_document_recall.py

+31
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from haystack.components.evaluators.document_recall import DocumentRecallEvaluator, RecallMode
44
from haystack.dataclasses import Document
5+
from haystack import default_from_dict
56

67

78
def test_init_with_unknown_mode_string():
@@ -78,6 +79,21 @@ def test_run_with_different_lengths(self, evaluator):
7879
retrieved_documents=[[Document(content="Berlin")]],
7980
)
8081

82+
def test_to_dict(self, evaluator):
83+
data = evaluator.to_dict()
84+
assert data == {
85+
"type": "haystack.components.evaluators.document_recall.DocumentRecallEvaluator",
86+
"init_parameters": {"mode": "single_hit"},
87+
}
88+
89+
def test_from_dict(self):
90+
data = {
91+
"type": "haystack.components.evaluators.document_recall.DocumentRecallEvaluator",
92+
"init_parameters": {"mode": "single_hit"},
93+
}
94+
new_evaluator = default_from_dict(DocumentRecallEvaluator, data)
95+
assert new_evaluator.mode == RecallMode.SINGLE_HIT
96+
8197

8298
class TestDocumentRecallEvaluatorMultiHit:
8399
@pytest.fixture
@@ -152,3 +168,18 @@ def test_run_with_different_lengths(self, evaluator):
152168
ground_truth_documents=[[Document(content="Berlin")], [Document(content="Paris")]],
153169
retrieved_documents=[[Document(content="Berlin")]],
154170
)
171+
172+
def test_to_dict(self, evaluator):
173+
data = evaluator.to_dict()
174+
assert data == {
175+
"type": "haystack.components.evaluators.document_recall.DocumentRecallEvaluator",
176+
"init_parameters": {"mode": "multi_hit"},
177+
}
178+
179+
def test_from_dict(self):
180+
data = {
181+
"type": "haystack.components.evaluators.document_recall.DocumentRecallEvaluator",
182+
"init_parameters": {"mode": "multi_hit"},
183+
}
184+
new_evaluator = default_from_dict(DocumentRecallEvaluator, data)
185+
assert new_evaluator.mode == RecallMode.MULTI_HIT

0 commit comments

Comments
 (0)