Skip to content

Commit 4e7092d

Browse files
feat: Allow the scores returned by AI Search to be populated in the Document.meta (#1907)
* feat: Allow the search scores to be populated in the Document.meta - exposing these scores (as they are very critical information) to users of this integration.
* feat: Allow the search scores to be populated in the Document.meta - exposing these scores (as they are very critical information) to users of this integration.
* chore: running linter
* Add a new param for search scores
* Fix linting
* feat: put the @search.score, if it exists, into the Document.score
* fix: put back accidentally removed # noqa: B008
* fix: running linter
* Update tests
* Update document_store
* Fix linting
* PR comments and test updates
* Fixes
---------
Co-authored-by: Amna Mubashar <[email protected]>
1 parent e49c7d0 commit 4e7092d

File tree

6 files changed

+80
-24
lines changed

6 files changed

+80
-24
lines changed

integrations/azure_ai_search/src/haystack_integrations/document_stores/azure_ai_search/document_store.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
from typing import Any, Dict, List, Optional, Type, Union
77

88
from azure.core.credentials import AzureKeyCredential
9-
from azure.core.exceptions import ClientAuthenticationError, HttpResponseError, ResourceNotFoundError
9+
from azure.core.exceptions import (
10+
ClientAuthenticationError,
11+
HttpResponseError,
12+
ResourceNotFoundError,
13+
)
1014
from azure.core.pipeline.policies import UserAgentPolicy
1115
from azure.identity import DefaultAzureCredential
1216
from azure.search.documents import SearchClient
@@ -67,7 +71,10 @@
6771

6872
DEFAULT_VECTOR_SEARCH = VectorSearch(
6973
profiles=[
70-
VectorSearchProfile(name="default-vector-config", algorithm_configuration_name="cosine-algorithm-config")
74+
VectorSearchProfile(
75+
name="default-vector-config",
76+
algorithm_configuration_name="cosine-algorithm-config",
77+
)
7178
],
7279
algorithms=[
7380
HnswAlgorithmConfiguration(
@@ -94,6 +101,7 @@ def __init__(
94101
embedding_dimension: int = 768,
95102
metadata_fields: Optional[Dict[str, Union[SearchField, type]]] = None,
96103
vector_search_configuration: Optional[VectorSearch] = None,
104+
include_search_metadata: bool = False,
97105
**index_creation_kwargs: Any,
98106
):
99107
"""
@@ -123,6 +131,10 @@ def __init__(
123131
:param vector_search_configuration: Configuration option related to vector search.
124132
Default configuration uses the HNSW algorithm with cosine similarity to handle vector searches.
125133
134+
:param include_search_metadata: Whether to include Azure AI Search metadata fields
135+
in the returned documents. When set to True, the `meta` field of the returned
136+
documents will contain the @search.score, @search.reranker_score, @search.highlights,
137+
@search.captions, and other fields returned by Azure AI Search.
126138
:param index_creation_kwargs: Optional keyword parameters to be passed to `SearchIndex` class
127139
during index creation. Some of the supported parameters:
128140
- `semantic_search`: Defines semantic configuration of the search index. This parameter is needed
@@ -143,6 +155,7 @@ def __init__(
143155
self._dummy_vector = [-10.0] * self._embedding_dimension
144156
self._metadata_fields = self._normalize_metadata_index_fields(metadata_fields)
145157
self._vector_search_configuration = vector_search_configuration or DEFAULT_VECTOR_SEARCH
158+
self._include_search_metadata = include_search_metadata
146159
self._index_creation_kwargs = index_creation_kwargs
147160

148161
@property
@@ -256,7 +269,9 @@ def _create_index(self) -> None:
256269
self._index_client.create_index(index)
257270

258271
@staticmethod
259-
def _serialize_index_creation_kwargs(index_creation_kwargs: Dict[str, Any]) -> Dict[str, Any]:
272+
def _serialize_index_creation_kwargs(
273+
index_creation_kwargs: Dict[str, Any],
274+
) -> Dict[str, Any]:
260275
"""
261276
Serializes the index creation kwargs to a dictionary.
262277
This is needed to handle serialization of Azure AI Search classes
@@ -300,7 +315,7 @@ def to_dict(self) -> Dict[str, Any]:
300315
"""
301316
return default_to_dict(
302317
self,
303-
azure_endpoint=self._azure_endpoint.to_dict() if self._azure_endpoint else None,
318+
azure_endpoint=(self._azure_endpoint.to_dict() if self._azure_endpoint else None),
304319
api_key=self._api_key.to_dict() if self._api_key else None,
305320
index_name=self._index_name,
306321
embedding_dimension=self._embedding_dimension,
@@ -423,19 +438,28 @@ def _convert_search_result_to_documents(self, azure_docs: List[Dict[str, Any]])
423438

424439
for azure_doc in azure_docs:
425440
embedding = azure_doc.get("embedding")
441+
score = azure_doc.get("@search.score", None)
426442
if embedding == self._dummy_vector:
427443
embedding = None
444+
meta = {}
428445

429446
# Anything besides default fields (id, content, and embedding) is considered metadata
430-
meta = {
431-
key: value
432-
for key, value in azure_doc.items()
433-
if key not in ["id", "content", "embedding"] and key in self._index_fields and value is not None
434-
}
447+
if self._include_search_metadata:
448+
meta = {key: value for key, value in azure_doc.items() if key not in ["id", "content", "embedding"]}
449+
else:
450+
meta = {
451+
key: value
452+
for key, value in azure_doc.items()
453+
if key not in ["id", "content", "embedding"] and key in self._index_fields and value is not None
454+
}
435455

436456
# Create the document; meta is always passed as a dict (empty when there is no metadata)
437457
doc = Document(
438-
id=azure_doc["id"], content=azure_doc["content"], embedding=embedding, meta=meta if meta else {}
458+
id=azure_doc["id"],
459+
content=azure_doc["content"],
460+
embedding=embedding,
461+
meta=meta,
462+
score=score,
439463
)
440464

441465
documents.append(doc)

integrations/azure_ai_search/tests/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def document_store(request):
2828
"""
2929
index_name = f"haystack_test_{uuid.uuid4().hex}"
3030
metadata_fields = getattr(request, "param", {}).get("metadata_fields", None)
31+
include_search_metadata = getattr(request, "param", {}).get("include_search_metadata", False)
3132

3233
azure_endpoint = os.environ["AZURE_AI_SEARCH_ENDPOINT"]
3334
api_key = os.environ["AZURE_AI_SEARCH_API_KEY"]
@@ -41,6 +42,7 @@ def document_store(request):
4142
create_index=True,
4243
embedding_dimension=768,
4344
metadata_fields=metadata_fields,
45+
include_search_metadata=include_search_metadata,
4446
)
4547

4648
# Override some methods to wait for the documents to be available

integrations/azure_ai_search/tests/test_bm25_retriever.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,23 @@ def test_run(self, document_store: AzureAISearchDocumentStore):
159159
document_store.write_documents(docs)
160160
retriever = AzureAISearchBM25Retriever(document_store=document_store)
161161
res = retriever.run(query="Test document")
162-
assert res["documents"] == docs
162+
assert res["documents"][0].content == docs[0].content
163+
assert res["documents"][0].score is not None
164+
assert res["documents"][0].id == docs[0].id
165+
166+
@pytest.mark.parametrize(
167+
"document_store",
168+
[
169+
{"include_search_metadata": True},
170+
],
171+
indirect=True,
172+
)
173+
def test_run_with_search_metadata(self, document_store: AzureAISearchDocumentStore):
174+
docs = [Document(id="1", content="Test document")]
175+
document_store.write_documents(docs)
176+
retriever = AzureAISearchBM25Retriever(document_store=document_store)
177+
res = retriever.run(query="Test document")
178+
assert all(key.startswith("@search") for key in res["documents"][0].meta.keys())
163179

164180
def test_document_retrieval(self, document_store: AzureAISearchDocumentStore):
165181
docs = [

integrations/azure_ai_search/tests/test_document_store.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -231,12 +231,36 @@ def test_init(_mock_azure_search_client):
231231
assert document_store._vector_search_configuration == DEFAULT_VECTOR_SEARCH
232232

233233

234+
def _assert_documents_are_equal(received: List[Document], expected: List[Document]):
235+
"""
236+
Assert that two lists of Documents are equal, ignoring their order and score.
237+
238+
This helper backs the `assert_documents_are_equal` overrides in this module. If a Document Store
239+
implementation behaves differently, it should override that method. This can happen, for example,
240+
when the Document Store sets a score on returned Documents. Since we can't know what the score will be,
241+
we can't compare whole Documents reliably, so only id, content, embedding, and meta are compared.
242+
"""
243+
sorted_received = sorted(received, key=lambda doc: doc.id)
244+
sorted_expected = sorted(expected, key=lambda doc: doc.id)
245+
assert len(sorted_received) == len(sorted_expected)
246+
247+
for received_doc, expected_doc in zip(sorted_received, sorted_expected):
248+
# Compare all attributes except score
249+
assert received_doc.id == expected_doc.id
250+
assert received_doc.content == expected_doc.content
251+
assert received_doc.embedding == expected_doc.embedding
252+
assert received_doc.meta == expected_doc.meta
253+
254+
234255
@pytest.mark.integration
235256
@pytest.mark.skipif(
236257
not os.environ.get("AZURE_AI_SEARCH_ENDPOINT", None) and not os.environ.get("AZURE_AI_SEARCH_API_KEY", None),
237258
reason="Missing AZURE_AI_SEARCH_ENDPOINT or AZURE_AI_SEARCH_API_KEY.",
238259
)
239260
class TestDocumentStore(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest):
261+
def assert_documents_are_equal(self, received: List[Document], expected: List[Document]):
262+
_assert_documents_are_equal(received, expected)
263+
240264
def test_write_documents(self, document_store: AzureAISearchDocumentStore):
241265
docs = [Document(id="1")]
242266
assert document_store.write_documents(docs) == 1
@@ -345,17 +369,7 @@ def filterable_docs(self) -> List[Document]:
345369

346370
# Overriding to compare the documents with the same order
347371
def assert_documents_are_equal(self, received: List[Document], expected: List[Document]):
348-
"""
349-
Assert that two lists of Documents are equal.
350-
351-
This is used in every test, if a Document Store implementation has a different behaviour
352-
it should override this method. This can happen for example when the Document Store sets
353-
a score to returned Documents. Since we can't know what the score will be, we can't compare
354-
the Documents reliably.
355-
"""
356-
sorted_recieved = sorted(received, key=lambda doc: doc.id)
357-
sorted_expected = sorted(expected, key=lambda doc: doc.id)
358-
assert sorted_recieved == sorted_expected
372+
_assert_documents_are_equal(received, expected)
359373

360374
# Azure search index supports UTC datetime in ISO 8601 format
361375
def test_comparison_greater_than_with_iso_date(self, document_store, filterable_docs):

integrations/azure_ai_search/tests/test_embedding_retriever.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def test_run(self, document_store: AzureAISearchDocumentStore):
174174
document_store.write_documents(docs)
175175
retriever = AzureAISearchEmbeddingRetriever(document_store=document_store)
176176
res = retriever.run(query_embedding=[0.1] * 768)
177-
assert res["documents"] == docs
177+
assert res["documents"][0].id == docs[0].id
178178

179179
def test_embedding_retrieval(self, document_store: AzureAISearchDocumentStore):
180180
query_embedding = [0.1] * 768

integrations/azure_ai_search/tests/test_hybrid_retriever.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def test_run(self, document_store: AzureAISearchDocumentStore):
180180
document_store.write_documents(docs)
181181
retriever = AzureAISearchHybridRetriever(document_store=document_store)
182182
res = retriever.run(query="Test document", query_embedding=[0.1] * 768)
183-
assert res["documents"] == docs
183+
assert res["documents"][0].id == docs[0].id
184184

185185
def test_hybrid_retrieval(self, document_store: AzureAISearchDocumentStore):
186186
query_embedding = [0.1] * 768

0 commit comments

Comments
 (0)