Skip to content

Commit b687fd4

Browse files
refactor: use abstract VecStoreFromDocIndex for in-memory and hnswlib implementation
Signed-off-by: anna-charlotte <[email protected]>
1 parent 41433e6 commit b687fd4

File tree

5 files changed

+448
-351
lines changed

5 files changed

+448
-351
lines changed

langchain/vectorstores/hnsw_lib.py

Lines changed: 30 additions & 174 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,44 @@
11
"""Wrapper around in-memory DocArray store."""
22
from __future__ import annotations
33

4-
from operator import itemgetter
54
from typing import List, Optional, Any, Tuple, Iterable, Type, Callable, Sequence, TYPE_CHECKING
5+
from docarray.typing import NdArray
66

77
from langchain.embeddings.base import Embeddings
8-
from langchain.schema import Document
9-
from langchain.vectorstores import VectorStore
108
from langchain.vectorstores.base import VST
11-
from langchain.vectorstores.utils import maximal_marginal_relevance
12-
13-
from docarray import BaseDoc
14-
from docarray.typing import NdArray
9+
from langchain.vectorstores.vector_store_from_doc_index import VecStoreFromDocIndex, _check_docarray_import
1510

1611

17-
class HnswLib(VectorStore):
12+
class HnswLib(VecStoreFromDocIndex):
1813
"""Wrapper around HnswLib storage.
1914
20-
To use it, you should have the ``docarray`` package with version >=0.30.0 installed.
15+
To use it, you should have the ``docarray`` package with version >=0.31.0 installed.
2116
"""
2217
def __init__(
2318
self,
24-
work_dir: str,
25-
n_dim: int,
2619
texts: List[str],
2720
embedding: Embeddings,
21+
work_dir: str,
22+
n_dim: int,
2823
metadatas: Optional[List[dict]],
29-
sim_metric: str = 'cosine',
30-
kwargs: dict = None
24+
dist_metric: str = 'cosine',
25+
**kwargs,
3126
) -> None:
32-
"""Initialize HnswLib store."""
33-
try:
34-
import docarray
35-
da_version = docarray.__version__.split('.')
36-
if int(da_version[0]) == 0 and int(da_version[1]) <= 21:
37-
raise ValueError(
38-
f'To use the HnswLib VectorStore the docarray version >=0.30.0 is expected, '
39-
f'received: {docarray.__version__}.'
40-
f'To upgrade, please run: `pip install -U docarray`.'
41-
)
42-
else:
43-
from docarray import DocList
44-
from docarray.index import HnswDocumentIndex
45-
except ImportError:
46-
raise ImportError(
47-
"Could not import docarray python package. "
48-
"Please install it with `pip install -U docarray`."
49-
)
27+
"""Initialize HnswLib store.
28+
29+
Args:
30+
texts (List[str]): Text data.
31+
embedding (Embeddings): Embedding function.
32+
metadatas (Optional[List[dict]]): Metadata for each text if it exists.
33+
Defaults to None.
34+
work_dir (str): path to the location where all the data will be stored.
35+
n_dim (int): dimension of an embedding.
36+
dist_metric (str): Distance metric for HnswLib can be one of: 'cosine',
37+
'ip', and 'l2'. Defaults to 'cosine'.
38+
"""
39+
_check_docarray_import()
40+
from docarray.index import HnswDocumentIndex
41+
5042
try:
5143
import google.protobuf
5244
except ImportError:
@@ -55,27 +47,13 @@ def __init__(
5547
"Please install it with `pip install -U protobuf`."
5648
)
5749

58-
if metadatas is None:
59-
metadatas = [{} for _ in range(len(texts))]
60-
61-
self.embedding = embedding
62-
63-
self.doc_cls = self._get_doc_cls(n_dim, sim_metric)
64-
self.doc_index = HnswDocumentIndex[self.doc_cls](work_dir=work_dir)
65-
embeddings = self.embedding.embed_documents(texts)
66-
docs = DocList[self.doc_cls](
67-
[
68-
self.doc_cls(
69-
text=t,
70-
embedding=e,
71-
metadata=m,
72-
) for t, m, e in zip(texts, metadatas, embeddings)
73-
]
74-
)
75-
self.doc_index.index(docs)
50+
doc_cls = self._get_doc_cls(n_dim, dist_metric)
51+
doc_index = HnswDocumentIndex[doc_cls](work_dir=work_dir)
52+
super().__init__(doc_index, texts, embedding, metadatas)
7653

7754
@staticmethod
7855
def _get_doc_cls(n_dim: int, sim_metric: str):
56+
from docarray import BaseDoc
7957
from pydantic import Field
8058

8159
class DocArrayDoc(BaseDoc):
@@ -93,6 +71,7 @@ def from_texts(
9371
metadatas: Optional[List[dict]] = None,
9472
work_dir: str = None,
9573
n_dim: int = None,
74+
dist_metric: str = 'cosine',
9675
**kwargs: Any
9776
) -> HnswLib:
9877

@@ -107,129 +86,6 @@ def from_texts(
10786
texts=texts,
10887
embedding=embedding,
10988
metadatas=metadatas,
110-
kwargs=kwargs
89+
dist_metric=dist_metric,
90+
kwargs=kwargs,
11191
)
112-
113-
def add_texts(
114-
self,
115-
texts: Iterable[str],
116-
metadatas: Optional[List[dict]] = None,
117-
**kwargs: Any
118-
) -> List[str]:
119-
"""Run more texts through the embeddings and add to the vectorstore.
120-
121-
Args:
122-
texts: Iterable of strings to add to the vectorstore.
123-
metadatas: Optional list of metadatas associated with the texts.
124-
125-
Returns:
126-
List of ids from adding the texts into the vectorstore.
127-
"""
128-
if metadatas is None:
129-
metadatas = [{} for _ in range(len(list(texts)))]
130-
131-
ids = []
132-
embeddings = self.embedding.embed_documents(texts)
133-
for t, m, e in zip(texts, metadatas, embeddings):
134-
doc = self.doc_cls(
135-
text=t,
136-
embedding=e,
137-
metadata=m
138-
)
139-
self.doc_index.index(doc)
140-
ids.append(doc.id) # TODO return index of self.docs ?
141-
142-
return ids
143-
144-
def similarity_search_with_score(
145-
self, query: str, k: int = 4, **kwargs: Any
146-
) -> List[Tuple[Document, float]]:
147-
"""Return docs most similar to query.
148-
149-
Args:
150-
query: Text to look up documents similar to.
151-
k: Number of Documents to return. Defaults to 4.
152-
153-
Returns:
154-
List of Documents most similar to the query and score for each.
155-
"""
156-
query_embedding = self.embedding.embed_query(query)
157-
query_embedding = [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.]
158-
print(f"query_embedding = {query_embedding}")
159-
query_doc = self.doc_cls(embedding=query_embedding)
160-
docs, scores = self.doc_index.find(query_doc, search_field='embedding', limit=k)
161-
162-
result = [(Document(page_content=doc.text), score) for doc, score in zip(docs, scores)]
163-
return result
164-
165-
def similarity_search(
166-
self, query: str, k: int = 4, **kwargs: Any
167-
) -> List[Document]:
168-
"""Return docs most similar to query.
169-
170-
Args:
171-
query: Text to look up documents similar to.
172-
k: Number of Documents to return. Defaults to 4.
173-
174-
Returns:
175-
List of Documents most similar to the query.
176-
"""
177-
results = self.similarity_search_with_score(query, k)
178-
return list(map(itemgetter(0), results))
179-
180-
def _similarity_search_with_relevance_scores(
181-
self,
182-
query: str,
183-
k: int = 4,
184-
**kwargs: Any,
185-
) -> List[Tuple[Document, float]]:
186-
"""Return docs and relevance scores, normalized on a scale from 0 to 1.
187-
188-
0 is dissimilar, 1 is most similar.
189-
"""
190-
raise NotImplementedError
191-
192-
def similarity_search_by_vector(self, embedding: List[float], k: int = 4, **kwargs: Any) -> List[Document]:
193-
"""Return docs most similar to embedding vector.
194-
195-
Args:
196-
embedding: Embedding to look up documents similar to.
197-
k: Number of Documents to return. Defaults to 4.
198-
199-
Returns:
200-
List of Documents most similar to the query vector.
201-
"""
202-
203-
query_doc = self.doc_cls(embedding=embedding)
204-
docs = self.doc_index.find(query_doc, search_field='embedding', limit=k).documents
205-
206-
result = [Document(page_content=doc.text) for doc in docs]
207-
return result
208-
209-
def max_marginal_relevance_search(
210-
self, query: str, k: int = 4, fetch_k: int = 20, **kwargs: Any
211-
) -> List[Document]:
212-
"""Return docs selected using the maximal marginal relevance.
213-
214-
Maximal marginal relevance optimizes for similarity to query AND diversity
215-
among selected documents.
216-
217-
Args:
218-
query: Text to look up documents similar to.
219-
k: Number of Documents to return. Defaults to 4.
220-
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
221-
222-
Returns:
223-
List of Documents selected by maximal marginal relevance.
224-
"""
225-
query_embedding = self.embedding.embed_query(query)
226-
query_doc = self.doc_cls(embedding=query_embedding)
227-
228-
docs, scores = self.doc_index.find(query_doc, search_field='embedding', limit=fetch_k)
229-
230-
embeddings = [emb for emb in docs.emb]
231-
232-
mmr_selected = maximal_marginal_relevance(query_embedding, embeddings, k=k)
233-
results = [Document(page_content=self.doc_index[idx].text) for idx in mmr_selected]
234-
return results
235-

0 commit comments

Comments
 (0)