Skip to content

Commit 41433e6

Browse files
feat: add in-memory and hnswlib vectorstore
Signed-off-by: anna-charlotte <[email protected]>
1 parent 0cf934c commit 41433e6

File tree

6 files changed

+712
-87
lines changed

6 files changed

+712
-87
lines changed

langchain/vectorstores/hnsw_lib.py

Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
"""Wrapper around in-memory DocArray store."""
2+
from __future__ import annotations
3+
4+
from operator import itemgetter
5+
from typing import List, Optional, Any, Tuple, Iterable, Type, Callable, Sequence, TYPE_CHECKING
6+
7+
from langchain.embeddings.base import Embeddings
8+
from langchain.schema import Document
9+
from langchain.vectorstores import VectorStore
10+
from langchain.vectorstores.base import VST
11+
from langchain.vectorstores.utils import maximal_marginal_relevance
12+
13+
from docarray import BaseDoc
14+
from docarray.typing import NdArray
15+
16+
17+
class HnswLib(VectorStore):
    """Wrapper around HnswLib storage.

    To use it, you should have the ``docarray`` package with version >=0.30.0 installed.
    """

    def __init__(
        self,
        work_dir: str,
        n_dim: int,
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]],
        sim_metric: str = 'cosine',
        kwargs: dict = None
    ) -> None:
        """Initialize HnswLib store.

        Args:
            work_dir: Directory where the HNSW index files are persisted.
            n_dim: Dimensionality of the embedding vectors.
            texts: Initial texts to embed and index.
            embedding: Embedding function to use for texts and queries.
            metadatas: Optional per-text metadata dicts (defaults to empty dicts).
            sim_metric: Similarity space configured on the embedding column
                (e.g. 'cosine').
            kwargs: Unused; accepted for interface compatibility with
                ``from_texts``.

        Raises:
            ImportError: If ``docarray`` or ``protobuf`` is not installed.
            ValueError: If the installed ``docarray`` is older than 0.30.0.
        """
        try:
            import docarray

            da_version = docarray.__version__.split('.')
            # The document-index API (DocList, HnswDocumentIndex) only exists
            # from docarray 0.30.0 onward, so reject any 0.x with minor < 30.
            if int(da_version[0]) == 0 and int(da_version[1]) < 30:
                raise ValueError(
                    f'To use the HnswLib VectorStore the docarray version >=0.30.0 is expected, '
                    f'received: {docarray.__version__}.'
                    f'To upgrade, please run: `pip install -U docarray`.'
                )
            else:
                from docarray import DocList
                from docarray.index import HnswDocumentIndex
        except ImportError:
            raise ImportError(
                "Could not import docarray python package. "
                "Please install it with `pip install -U docarray`."
            )
        try:
            # docarray's HNSW index needs protobuf for (de)serialization.
            import google.protobuf  # noqa: F401
        except ImportError:
            raise ImportError(
                "Could not import protobuf python package. "
                "Please install it with `pip install -U protobuf`."
            )

        if metadatas is None:
            metadatas = [{} for _ in range(len(texts))]

        self.embedding = embedding

        self.doc_cls = self._get_doc_cls(n_dim, sim_metric)
        self.doc_index = HnswDocumentIndex[self.doc_cls](work_dir=work_dir)
        embeddings = self.embedding.embed_documents(texts)
        docs = DocList[self.doc_cls](
            [
                self.doc_cls(
                    text=t,
                    embedding=e,
                    metadata=m,
                ) for t, m, e in zip(texts, metadatas, embeddings)
            ]
        )
        self.doc_index.index(docs)

    @staticmethod
    def _get_doc_cls(n_dim: int, sim_metric: str):
        """Build the pydantic document schema used by the index.

        Args:
            n_dim: Dimensionality declared on the ``embedding`` column.
            sim_metric: Similarity space declared on the ``embedding`` column.

        Returns:
            A ``BaseDoc`` subclass with ``text``, ``embedding`` and
            ``metadata`` fields.
        """
        from pydantic import Field

        class DocArrayDoc(BaseDoc):
            text: Optional[str]
            embedding: Optional[NdArray] = Field(dim=n_dim, space=sim_metric)
            metadata: Optional[dict]

        return DocArrayDoc

    @classmethod
    def from_texts(
        cls: Type[VST],
        texts: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        work_dir: str = None,
        n_dim: int = None,
        **kwargs: Any
    ) -> HnswLib:
        """Create an HnswLib store from raw texts.

        Args:
            texts: Texts to embed and index.
            embedding: Embedding function to use.
            metadatas: Optional list of metadatas associated with the texts.
            work_dir: Directory for the persisted index; required.
            n_dim: Dimensionality of the embeddings; required.

        Returns:
            A newly constructed ``HnswLib`` store.

        Raises:
            ValueError: If ``work_dir`` or ``n_dim`` is not provided.
        """
        if work_dir is None:
            raise ValueError('`work_dir` parameter has not been set.')
        if n_dim is None:
            raise ValueError('`n_dim` parameter has not been set.')

        return cls(
            work_dir=work_dir,
            n_dim=n_dim,
            texts=texts,
            embedding=embedding,
            metadatas=metadatas,
            kwargs=kwargs
        )

    def add_texts(
        self,
        texts: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        **kwargs: Any
    ) -> List[str]:
        """Run more texts through the embeddings and add to the vectorstore.

        Args:
            texts: Iterable of strings to add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.

        Returns:
            List of ids from adding the texts into the vectorstore.
        """
        # Materialize once: `texts` may be a one-shot iterator, and it is
        # consumed below by len(), embed_documents() and zip().
        texts = list(texts)
        if metadatas is None:
            metadatas = [{} for _ in range(len(texts))]

        ids = []
        embeddings = self.embedding.embed_documents(texts)
        for t, m, e in zip(texts, metadatas, embeddings):
            doc = self.doc_cls(
                text=t,
                embedding=e,
                metadata=m
            )
            self.doc_index.index(doc)
            ids.append(doc.id)  # TODO return index of self.docs ?

        return ids

    def similarity_search_with_score(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Tuple[Document, float]]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.

        Returns:
            List of Documents most similar to the query and score for each.
        """
        # NOTE: removed leftover debug code that overwrote the query embedding
        # with a hard-coded vector and printed it — it broke every search.
        query_embedding = self.embedding.embed_query(query)
        query_doc = self.doc_cls(embedding=query_embedding)
        docs, scores = self.doc_index.find(query_doc, search_field='embedding', limit=k)

        result = [(Document(page_content=doc.text), score) for doc, score in zip(docs, scores)]
        return result

    def similarity_search(
        self, query: str, k: int = 4, **kwargs: Any
    ) -> List[Document]:
        """Return docs most similar to query.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.

        Returns:
            List of Documents most similar to the query.
        """
        results = self.similarity_search_with_score(query, k)
        return list(map(itemgetter(0), results))

    def _similarity_search_with_relevance_scores(
        self,
        query: str,
        k: int = 4,
        **kwargs: Any,
    ) -> List[Tuple[Document, float]]:
        """Return docs and relevance scores, normalized on a scale from 0 to 1.

        0 is dissimilar, 1 is most similar.
        """
        raise NotImplementedError

    def similarity_search_by_vector(self, embedding: List[float], k: int = 4, **kwargs: Any) -> List[Document]:
        """Return docs most similar to embedding vector.

        Args:
            embedding: Embedding to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.

        Returns:
            List of Documents most similar to the query vector.
        """
        query_doc = self.doc_cls(embedding=embedding)
        docs = self.doc_index.find(query_doc, search_field='embedding', limit=k).documents

        result = [Document(page_content=doc.text) for doc in docs]
        return result

    def max_marginal_relevance_search(
        self, query: str, k: int = 4, fetch_k: int = 20, **kwargs: Any
    ) -> List[Document]:
        """Return docs selected using the maximal marginal relevance.

        Maximal marginal relevance optimizes for similarity to query AND diversity
        among selected documents.

        Args:
            query: Text to look up documents similar to.
            k: Number of Documents to return. Defaults to 4.
            fetch_k: Number of Documents to fetch to pass to MMR algorithm.

        Returns:
            List of Documents selected by maximal marginal relevance.
        """
        query_embedding = self.embedding.embed_query(query)
        query_doc = self.doc_cls(embedding=query_embedding)

        docs = self.doc_index.find(
            query_doc, search_field='embedding', limit=fetch_k
        ).documents

        # The schema field is `embedding` (not `emb`); collect the candidate
        # vectors in the same order as `docs` so MMR indices map back to them.
        embeddings = [doc.embedding for doc in docs]

        mmr_selected = maximal_marginal_relevance(query_embedding, embeddings, k=k)
        # mmr_selected holds positions within `docs`, so index `docs`, not the
        # document index (which is keyed by doc id, not position).
        results = [Document(page_content=docs[idx].text) for idx in mmr_selected]
        return results

0 commit comments

Comments
 (0)