Skip to content

Commit 30456bc

Browse files
Add more configurations for hnswlib
Signed-off-by: anna-charlotte <[email protected]>
1 parent de262f9 commit 30456bc

File tree

2 files changed

+101
-4
lines changed

2 files changed

+101
-4
lines changed

langchain/vectorstores/hnsw_lib.py

+53-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@ def __init__(
2424
work_dir: str,
2525
n_dim: int,
2626
dist_metric: str = "cosine",
27+
max_elements: int = 1024,
28+
index: bool = True,
29+
ef_construction: int = 200,
30+
ef: int = 10,
31+
M: int = 16,
32+
allow_replace_deleted: bool = True,
33+
num_threads: int = 1,
2734
) -> None:
2835
"""Initialize HnswLib store.
2936
@@ -33,6 +40,19 @@ def __init__(
3340
n_dim (int): dimension of an embedding.
3441
dist_metric (str): Distance metric for HnswLib can be one of: "cosine",
3542
"ip", and "l2". Defaults to "cosine".
43+
max_elements (int): Maximum number of vectors that can be stored.
44+
Defaults to 1024.
45+
index (bool): Whether an index should be built for this field.
46+
Defaults to True.
47+
ef_construction (int): defines a construction time/accuracy trade-off.
48+
Defaults to 200.
49+
ef (int): parameter controlling query time/accuracy trade-off.
50+
Defaults to 10.
51+
M (int): parameter that defines the maximum number of outgoing
52+
connections in the graph. Defaults to 16.
53+
allow_replace_deleted (bool): Enables replacing of deleted elements
54+
with newly added ones. Defaults to True.
55+
num_threads (int): Sets the number of cpu threads to use. Defaults to 1.
3656
"""
3757
_check_docarray_import()
3858
from docarray.index import HnswDocumentIndex
@@ -45,7 +65,19 @@ def __init__(
4565
"Please install it with `pip install \"langchain[hnswlib]\"`."
4666
)
4767

48-
doc_cls = self._get_doc_cls({"dim": n_dim, "space": dist_metric})
68+
doc_cls = self._get_doc_cls(
69+
{
70+
"dim": n_dim,
71+
"space": dist_metric,
72+
"max_elements": max_elements,
73+
"index": index,
74+
"ef_construction": ef_construction,
75+
"ef": ef,
76+
"M": M,
77+
"allow_replace_deleted": allow_replace_deleted,
78+
"num_threads": num_threads,
79+
}
80+
)
4981
doc_index = HnswDocumentIndex[doc_cls](work_dir=work_dir)
5082
super().__init__(doc_index, embedding)
5183

@@ -58,6 +90,13 @@ def from_texts(
5890
work_dir: str = None,
5991
n_dim: int = None,
6092
dist_metric: str = "cosine",
93+
max_elements: int = 1024,
94+
index: bool = True,
95+
ef_construction: int = 200,
96+
ef: int = 10,
97+
M: int = 16,
98+
allow_replace_deleted: bool = True,
99+
num_threads: int = 1,
61100
) -> HnswLib:
62101
"""Create an HnswLib store and insert data.
63102
@@ -70,6 +109,19 @@ def from_texts(
70109
n_dim (int): dimension of an embedding.
71110
dist_metric (str): Distance metric for HnswLib can be one of: "cosine",
72111
"ip", and "l2". Defaults to "cosine".
112+
max_elements (int): Maximum number of vectors that can be stored.
113+
Defaults to 1024.
114+
index (bool): Whether an index should be built for this field.
115+
Defaults to True.
116+
ef_construction (int): defines a construction time/accuracy trade-off.
117+
Defaults to 200.
118+
ef (int): parameter controlling query time/accuracy trade-off.
119+
Defaults to 10.
120+
M (int): parameter that defines the maximum number of outgoing
121+
connections in the graph. Defaults to 16.
122+
allow_replace_deleted (bool): Enables replacing of deleted elements
123+
with newly added ones. Defaults to True.
124+
num_threads (int): Sets the number of cpu threads to use. Defaults to 1.
73125
74126
Returns:
75127
HnswLib Vector Store

tests/integration_tests/vectorstores/test_hnsw_lib.py

+48-3
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def test_hnswlib_vec_store_add_texts(tmp_path) -> None:
3636
assert docsearch.doc_index.num_docs() == 3
3737

3838

39-
@pytest.mark.parametrize('metric', ['cosine', 'ip', 'l2'])
39+
@pytest.mark.parametrize('metric', ['cosine', 'l2'])
4040
def test_sim_search(metric, tmp_path) -> None:
4141
"""Test end to end construction and simple similarity search."""
4242
texts = ["foo", "bar", "baz"]
@@ -45,12 +45,35 @@ def test_sim_search(metric, tmp_path) -> None:
4545
FakeEmbeddings(),
4646
work_dir=str(tmp_path),
4747
n_dim=10,
48+
dist_metric=metric,
49+
)
50+
output = hnswlib_vec_store.similarity_search("foo", k=1)
51+
assert output == [Document(page_content="foo")]
52+
53+
54+
@pytest.mark.parametrize('metric', ['cosine', 'l2'])
55+
def test_sim_search_all_configurations(metric, tmp_path) -> None:
56+
"""Test end to end construction and simple similarity search with all configuration options set explicitly."""
57+
texts = ["foo", "bar", "baz"]
58+
hnswlib_vec_store = HnswLib.from_texts(
59+
texts,
60+
FakeEmbeddings(),
61+
work_dir=str(tmp_path),
62+
dist_metric=metric,
63+
n_dim=10,
64+
max_elements=8,
65+
index=False,
66+
ef_construction=300,
67+
ef=20,
68+
M=8,
69+
allow_replace_deleted=False,
70+
num_threads=2,
4871
)
4972
output = hnswlib_vec_store.similarity_search("foo", k=1)
5073
assert output == [Document(page_content="foo")]
5174

5275

53-
@pytest.mark.parametrize('metric', ['cosine', 'ip', 'l2'])
76+
@pytest.mark.parametrize('metric', ['cosine', 'l2'])
5477
def test_sim_search_by_vector(metric, tmp_path) -> None:
5578
"""Test end to end construction and similarity search by vector."""
5679
texts = ["foo", "bar", "baz"]
@@ -59,14 +82,15 @@ def test_sim_search_by_vector(metric, tmp_path) -> None:
5982
FakeEmbeddings(),
6083
work_dir=str(tmp_path),
6184
n_dim=10,
85+
dist_metric=metric,
6286
)
6387
embedding = [1.0] * 10
6488
output = hnswlib_vec_store.similarity_search_by_vector(embedding, k=1)
6589

6690
assert output == [Document(page_content="bar")]
6791

6892

69-
@pytest.mark.parametrize('metric', ['cosine', 'ip', 'l2'])
93+
@pytest.mark.parametrize('metric', ['cosine', 'l2'])
7094
def test_sim_search_with_score(metric, tmp_path) -> None:
7195
"""Test end to end construction and similarity search with score."""
7296
texts = ["foo", "bar", "baz"]
@@ -75,6 +99,7 @@ def test_sim_search_with_score(metric, tmp_path) -> None:
7599
FakeEmbeddings(),
76100
work_dir=str(tmp_path),
77101
n_dim=10,
102+
dist_metric=metric,
78103
)
79104
output = hnswlib_vec_store.similarity_search_with_score("foo", k=1)
80105
assert len(output) == 1
@@ -84,6 +109,26 @@ def test_sim_search_with_score(metric, tmp_path) -> None:
84109
assert np.isclose(out_score, 0.0, atol=1.e-6)
85110

86111

112+
def test_sim_search_with_score_for_ip_metric(tmp_path) -> None:
113+
"""
114+
Test end to end construction and similarity search with score for ip
115+
(inner-product) metric.
116+
"""
117+
texts = ["foo", "bar", "baz"]
118+
hnswlib_vec_store = HnswLib.from_texts(
119+
texts,
120+
FakeEmbeddings(),
121+
work_dir=str(tmp_path),
122+
n_dim=10,
123+
dist_metric='ip',
124+
)
125+
output = hnswlib_vec_store.similarity_search_with_score("foo", k=3)
126+
assert len(output) == 3
127+
128+
for result in output:
129+
assert result[1] == -8.0
130+
131+
87132
@pytest.mark.parametrize('metric', ['cosine', 'l2'])
88133
def test_max_marginal_relevance_search(metric, tmp_path) -> None:
89134
"""Test MMR (max marginal relevance) search."""

0 commit comments

Comments
 (0)