 """Wrapper around in-memory DocArray store."""
 from __future__ import annotations
 
-from operator import itemgetter
 from typing import List, Optional, Any, Tuple, Iterable, Type, Callable, Sequence, TYPE_CHECKING
+from docarray.typing import NdArray
 
 from langchain.embeddings.base import Embeddings
-from langchain.schema import Document
-from langchain.vectorstores import VectorStore
 from langchain.vectorstores.base import VST
-from langchain.vectorstores.utils import maximal_marginal_relevance
-
-from docarray import BaseDoc
-from docarray.typing import NdArray
+from langchain.vectorstores.vector_store_from_doc_index import VecStoreFromDocIndex, _check_docarray_import
 
 
-class HnswLib(VectorStore):
+class HnswLib(VecStoreFromDocIndex):
     """Wrapper around HnswLib storage.
 
-    To use it, you should have the ``docarray`` package with version >=0.30.0 installed.
+    To use it, you should have the ``docarray`` package with version >=0.31.0 installed.
     """
     def __init__(
         self,
-        work_dir: str,
-        n_dim: int,
         texts: List[str],
         embedding: Embeddings,
+        work_dir: str,
+        n_dim: int,
         metadatas: Optional[List[dict]],
-        sim_metric: str = 'cosine',
-        kwargs: dict = None
+        dist_metric: str = 'cosine',
+        **kwargs,
     ) -> None:
-        """Initialize HnswLib store."""
-        try:
-            import docarray
-            da_version = docarray.__version__.split('.')
-            if int(da_version[0]) == 0 and int(da_version[1]) <= 21:
-                raise ValueError(
-                    f'To use the HnswLib VectorStore the docarray version >=0.30.0 is expected, '
-                    f'received: {docarray.__version__}.'
-                    f'To upgrade, please run: `pip install -U docarray`.'
-                )
-            else:
-                from docarray import DocList
-                from docarray.index import HnswDocumentIndex
-        except ImportError:
-            raise ImportError(
-                "Could not import docarray python package. "
-                "Please install it with `pip install -U docarray`."
-            )
+        """Initialize HnswLib store.
+
+        Args:
+            texts (List[str]): Text data.
+            embedding (Embeddings): Embedding function.
+            metadatas (Optional[List[dict]]): Metadata for each text, if any.
+                Defaults to None.
+            work_dir (str): Path to the location where all the data will be stored.
+            n_dim (int): Dimension of an embedding.
+            dist_metric (str): Distance metric for HnswLib; can be one of 'cosine',
+                'ip', and 'l2'. Defaults to 'cosine'.
+        """
+        _check_docarray_import()
+        from docarray.index import HnswDocumentIndex
+
         try:
             import google.protobuf
         except ImportError:
@@ -55,27 +47,13 @@ def __init__(
                 "Please install it with `pip install -U protobuf`."
             )
 
-        if metadatas is None:
-            metadatas = [{} for _ in range(len(texts))]
-
-        self.embedding = embedding
-
-        self.doc_cls = self._get_doc_cls(n_dim, sim_metric)
-        self.doc_index = HnswDocumentIndex[self.doc_cls](work_dir=work_dir)
-        embeddings = self.embedding.embed_documents(texts)
-        docs = DocList[self.doc_cls](
-            [
-                self.doc_cls(
-                    text=t,
-                    embedding=e,
-                    metadata=m,
-                ) for t, m, e in zip(texts, metadatas, embeddings)
-            ]
-        )
-        self.doc_index.index(docs)
+        doc_cls = self._get_doc_cls(n_dim, dist_metric)
+        doc_index = HnswDocumentIndex[doc_cls](work_dir=work_dir)
+        super().__init__(doc_index, texts, embedding, metadatas)
 
     @staticmethod
     def _get_doc_cls(n_dim: int, sim_metric: str):
+        from docarray import BaseDoc
         from pydantic import Field
 
         class DocArrayDoc(BaseDoc):
@@ -93,6 +71,7 @@ def from_texts(
         metadatas: Optional[List[dict]] = None,
         work_dir: str = None,
         n_dim: int = None,
+        dist_metric: str = 'cosine',
         **kwargs: Any
     ) -> HnswLib:
@@ -107,129 +86,6 @@ def from_texts(
             texts=texts,
             embedding=embedding,
             metadatas=metadatas,
-            kwargs=kwargs
+            dist_metric=dist_metric,
+            kwargs=kwargs,
         )
-
-    def add_texts(
-        self,
-        texts: Iterable[str],
-        metadatas: Optional[List[dict]] = None,
-        **kwargs: Any
-    ) -> List[str]:
-        """Run more texts through the embeddings and add to the vectorstore.
-
-        Args:
-            texts: Iterable of strings to add to the vectorstore.
-            metadatas: Optional list of metadatas associated with the texts.
-
-        Returns:
-            List of ids from adding the texts into the vectorstore.
-        """
-        if metadatas is None:
-            metadatas = [{} for _ in range(len(list(texts)))]
-
-        ids = []
-        embeddings = self.embedding.embed_documents(texts)
-        for t, m, e in zip(texts, metadatas, embeddings):
-            doc = self.doc_cls(
-                text=t,
-                embedding=e,
-                metadata=m
-            )
-            self.doc_index.index(doc)
-            ids.append(doc.id)  # TODO return index of self.docs ?
-
-        return ids
-
-    def similarity_search_with_score(
-        self, query: str, k: int = 4, **kwargs: Any
-    ) -> List[Tuple[Document, float]]:
-        """Return docs most similar to query.
-
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-
-        Returns:
-            List of Documents most similar to the query and score for each.
-        """
-        query_embedding = self.embedding.embed_query(query)
-        query_embedding = [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.]
-        print(f"query_embedding = {query_embedding}")
-        query_doc = self.doc_cls(embedding=query_embedding)
-        docs, scores = self.doc_index.find(query_doc, search_field='embedding', limit=k)
-
-        result = [(Document(page_content=doc.text), score) for doc, score in zip(docs, scores)]
-        return result
-
-    def similarity_search(
-        self, query: str, k: int = 4, **kwargs: Any
-    ) -> List[Document]:
-        """Return docs most similar to query.
-
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-
-        Returns:
-            List of Documents most similar to the query.
-        """
-        results = self.similarity_search_with_score(query, k)
-        return list(map(itemgetter(0), results))
-
-    def _similarity_search_with_relevance_scores(
-        self,
-        query: str,
-        k: int = 4,
-        **kwargs: Any,
-    ) -> List[Tuple[Document, float]]:
-        """Return docs and relevance scores, normalized on a scale from 0 to 1.
-
-        0 is dissimilar, 1 is most similar.
-        """
-        raise NotImplementedError
-
-    def similarity_search_by_vector(self, embedding: List[float], k: int = 4, **kwargs: Any) -> List[Document]:
-        """Return docs most similar to embedding vector.
-
-        Args:
-            embedding: Embedding to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-
-        Returns:
-            List of Documents most similar to the query vector.
-        """
-
-        query_doc = self.doc_cls(embedding=embedding)
-        docs = self.doc_index.find(query_doc, search_field='embedding', limit=k).documents
-
-        result = [Document(page_content=doc.text) for doc in docs]
-        return result
-
-    def max_marginal_relevance_search(
-        self, query: str, k: int = 4, fetch_k: int = 20, **kwargs: Any
-    ) -> List[Document]:
-        """Return docs selected using the maximal marginal relevance.
-
-        Maximal marginal relevance optimizes for similarity to query AND diversity
-        among selected documents.
-
-        Args:
-            query: Text to look up documents similar to.
-            k: Number of Documents to return. Defaults to 4.
-            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
-
-        Returns:
-            List of Documents selected by maximal marginal relevance.
-        """
-        query_embedding = self.embedding.embed_query(query)
-        query_doc = self.doc_cls(embedding=query_embedding)
-
-        docs, scores = self.doc_index.find(query_doc, search_field='embedding', limit=fetch_k)
-
-        embeddings = [emb for emb in docs.emb]
-
-        mmr_selected = maximal_marginal_relevance(query_embedding, embeddings, k=k)
-        results = [Document(page_content=self.doc_index[idx].text) for idx in mmr_selected]
-        return results
-
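
For context, here is a rough usage sketch of the wrapper after this refactor. It is a sketch under assumptions, not part of the change: the module path `langchain.vectorstores.hnsw_lib`, the `StubEmbeddings` helper, and the inherited `similarity_search` (expected to come from the shared `VecStoreFromDocIndex` base, which is not shown in this diff) are all illustrative.

```python
# Hedged usage sketch. Assumptions (not in this diff): the module path below,
# the StubEmbeddings helper, and that similarity_search is inherited from
# VecStoreFromDocIndex.
from typing import List

from langchain.embeddings.base import Embeddings
from langchain.vectorstores.hnsw_lib import HnswLib  # hypothetical import path


class StubEmbeddings(Embeddings):
    """Tiny deterministic embeddings, for illustration only."""

    def __init__(self, n_dim: int = 10) -> None:
        self.n_dim = n_dim

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(t) for t in texts]

    def embed_query(self, text: str) -> List[float]:
        # Bag-of-characters folded into a fixed-size vector.
        vec = [0.0] * self.n_dim
        for ch in text.lower():
            vec[ord(ch) % self.n_dim] += 1.0
        return vec


embedding = StubEmbeddings(n_dim=10)
store = HnswLib.from_texts(
    texts=["foo", "bar", "baz"],
    embedding=embedding,
    metadatas=[{"source": i} for i in range(3)],
    work_dir="./hnswlib_store",  # on-disk location for the HnswDocumentIndex
    n_dim=10,                    # must match the embedding dimensionality
    dist_metric="cosine",        # 'cosine', 'ip', or 'l2' per the new docstring
)

docs = store.similarity_search("foo", k=2)  # assumed to be provided by the base class
print([d.page_content for d in docs])
```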
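The removed methods also document the underlying docarray calls that the new `VecStoreFromDocIndex` base is presumably meant to centralize: embed the texts, build a `DocList`, `index()` it, then `find()` by vector. A minimal, self-contained sketch of that flow against `HnswDocumentIndex` directly, with `DemoDoc` standing in for the dynamically generated `DocArrayDoc`:

```python
# Minimal sketch of the docarray flow seen in the removed code. DemoDoc is a
# simplified stand-in for the DocArrayDoc class that _get_doc_cls builds from
# n_dim and dist_metric.
from docarray import BaseDoc, DocList
from docarray.index import HnswDocumentIndex
from docarray.typing import NdArray
from pydantic import Field


class DemoDoc(BaseDoc):
    text: str = ""
    embedding: NdArray[10]  # fixed dimensionality, analogous to n_dim
    metadata: dict = Field(default_factory=dict)


texts = ["foo", "bar", "baz"]
vectors = [[float(i == j) for i in range(10)] for j in range(3)]  # toy embeddings

index = HnswDocumentIndex[DemoDoc](work_dir="./demo_hnsw")
index.index(
    DocList[DemoDoc](
        DemoDoc(text=t, embedding=v, metadata={"pos": j})
        for j, (t, v) in enumerate(zip(texts, vectors))
    )
)

# Vector search, as the removed similarity_search_with_score did.
query = DemoDoc(embedding=[1.0] + [0.0] * 9)
docs, scores = index.find(query, search_field="embedding", limit=2)
print([(d.text, s) for d, s in zip(docs, scores)])
```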