Skip to content

Commit 973da35

Browse files
authored
Fix query to align with Qdrant mixin usage (#115)
* fix: query in text_embedding_base to work with both Iterable and str as users might supply both * Fix Qdrant query to align with future usage * * refactor(text_embedding_base.py): change query parameter type from str to Union[str, Iterable[str]] in query_embed method * Update return type of query_embed method * Update return type in TextEmbeddingBase
1 parent 4696818 commit 973da35

File tree

3 files changed

+40
-31
lines changed

3 files changed

+40
-31
lines changed

docs/examples/Retrieval_with_FastEmbed.ipynb

+11-11
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
},
2222
{
2323
"cell_type": "code",
24-
"execution_count": 8,
24+
"execution_count": 1,
2525
"metadata": {},
2626
"outputs": [],
2727
"source": [
@@ -37,7 +37,7 @@
3737
},
3838
{
3939
"cell_type": "code",
40-
"execution_count": 9,
40+
"execution_count": 2,
4141
"metadata": {},
4242
"outputs": [],
4343
"source": [
@@ -58,7 +58,7 @@
5858
},
5959
{
6060
"cell_type": "code",
61-
"execution_count": 10,
61+
"execution_count": 3,
6262
"metadata": {},
6363
"outputs": [
6464
{
@@ -105,7 +105,7 @@
105105
},
106106
{
107107
"cell_type": "code",
108-
"execution_count": 11,
108+
"execution_count": 4,
109109
"metadata": {},
110110
"outputs": [],
111111
"source": [
@@ -138,7 +138,7 @@
138138
},
139139
{
140140
"cell_type": "code",
141-
"execution_count": 12,
141+
"execution_count": 5,
142142
"metadata": {},
143143
"outputs": [
144144
{
@@ -148,7 +148,7 @@
148148
"Rank 1: Maharana Pratap was a Rajput warrior king from Mewar\n",
149149
"Rank 2: Maharana Pratap is considered a symbol of Rajput resistance against foreign rule\n",
150150
"Rank 3: His legacy is celebrated in Rajasthan through festivals and monuments\n",
151-
"Rank 4: His capital was Chittorgarh, which he lost to the Mughals\n",
151+
"Rank 4: He had 11 wives and 17 sons, including Amar Singh I who succeeded him as ruler of Mewar\n",
152152
"Rank 5: He fought against the Mughal Empire led by Akbar\n"
153153
]
154154
}
@@ -166,16 +166,16 @@
166166
},
167167
{
168168
"cell_type": "code",
169-
"execution_count": 13,
169+
"execution_count": 6,
170170
"metadata": {},
171171
"outputs": [
172172
{
173173
"name": "stdout",
174174
"output_type": "stream",
175175
"text": [
176-
"Rank 1: He died in 1597 at the age of 57\n",
177-
"Rank 2: His life has been depicted in various films, TV shows, and books\n",
178-
"Rank 3: Maharana Pratap was a Rajput warrior king from Mewar\n",
176+
"Rank 1: Maharana Pratap was a Rajput warrior king from Mewar\n",
177+
"Rank 2: Maharana Pratap is considered a symbol of Rajput resistance against foreign rule\n",
178+
"Rank 3: His legacy is celebrated in Rajasthan through festivals and monuments\n",
179179
"Rank 4: He had 11 wives and 17 sons, including Amar Singh I who succeeded him as ruler of Mewar\n",
180180
"Rank 5: He fought against the Mughal Empire led by Akbar\n"
181181
]
@@ -213,7 +213,7 @@
213213
"name": "python",
214214
"nbconvert_exporter": "python",
215215
"pygments_lexer": "ipython3",
216-
"version": "3.9.17"
216+
"version": "3.11.5"
217217
},
218218
"orig_nbformat": 4
219219
},

docs/examples/Usage_With_Qdrant.ipynb

+20-13
Original file line numberDiff line numberDiff line change
@@ -102,19 +102,26 @@
102102
"execution_count": 4,
103103
"metadata": {},
104104
"outputs": [
105+
{
106+
"name": "stderr",
107+
"output_type": "stream",
108+
"text": [
109+
"100%|██████████| 77.7M/77.7M [00:05<00:00, 14.6MiB/s]\n"
110+
]
111+
},
105112
{
106113
"data": {
107114
"text/plain": [
108-
"['6e8fcf7e0ecc407b9b6bb011d169f629',\n",
109-
" 'c9d26e7e0ea741b2b1082d097796b28b',\n",
110-
" 'cf05747e7eb34d2490b1df1f8be94049',\n",
111-
" '208c197266d547a880dfb65e46738b19',\n",
112-
" '27bd985c5d6f49d68fc2cf73dac74199',\n",
113-
" 'c5e929c8837f4370818c97f63996f8ef',\n",
114-
" 'c12213c6cdac470aa2471f2d30dc4041',\n",
115-
" '974e64a7d8624f6e9824fa7b9c94f99d',\n",
116-
" '0129fae193c740eba092512d8e53ab4a',\n",
117-
" '492cad6e741e4aeebb196bd818a97d17']"
115+
"['4fa8b10c78da4b18ba0830ba8a57367a',\n",
116+
" '2eae04b515ee4e9185a9a0e6be812bba',\n",
117+
" 'c6039f88486f47f1835ae3b069c5823c',\n",
118+
" 'c2c8c51e305144d1917b373125fb4d95',\n",
119+
" '79fd23b9ec0648cdab38d1947c6b933e',\n",
120+
" '036aa200d8c3492b8a438e4f825f5e7f',\n",
121+
" 'c35c77f3ea37460a9a13723fb77b7367',\n",
122+
" '6ebccbca571b40d0ab6e83e5e0f2f562',\n",
123+
" '38048c2ccc1d4962a4f8f1bd89c8357a',\n",
124+
" 'c6b09308360140c7b4f106af3658a31e']"
118125
]
119126
},
120127
"execution_count": 4,
@@ -187,12 +194,12 @@
187194
"name": "stdout",
188195
"output_type": "stream",
189196
"text": [
190-
"[QueryResponse(id='42', embedding=None, metadata={'document': 'Qdrant has Langchain integrations', 'source': 'Langchain-docs'}, document='Qdrant has Langchain integrations', score=0.8496814051311954), QueryResponse(id='2', embedding=None, metadata={'document': 'Qdrant also has Llama Index integrations', 'source': 'Linkedin-docs'}, document='Qdrant also has Llama Index integrations', score=0.8478494193031256)]\n"
197+
"[QueryResponse(id=42, embedding=None, metadata={'document': 'Qdrant has Langchain integrations', 'source': 'Langchain-docs'}, document='Qdrant has Langchain integrations', score=0.8276550115796268), QueryResponse(id=2, embedding=None, metadata={'document': 'Qdrant also has Llama Index integrations', 'source': 'Linkedin-docs'}, document='Qdrant also has Llama Index integrations', score=0.8265536935180283)]\n"
191198
]
192199
}
193200
],
194201
"source": [
195-
"search_result = client.query(collection_name=\"demo_collection\", query_text=[\"This is a query document\"])\n",
202+
"search_result = client.query(collection_name=\"demo_collection\", query_text=\"This is a query document\")\n",
196203
"print(search_result)"
197204
]
198205
},
@@ -226,7 +233,7 @@
226233
"name": "python",
227234
"nbconvert_exporter": "python",
228235
"pygments_lexer": "ipython3",
229-
"version": "3.9.17"
236+
"version": "3.11.5"
230237
},
231238
"orig_nbformat": 4
232239
},

fastembed/text/text_embedding_base.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, Union, Iterable, List, Dict, Any
1+
from typing import Any, Dict, Iterable, List, Optional, Union
22

33
import numpy as np
44

@@ -39,17 +39,19 @@ def passage_embed(self, texts: Iterable[str], **kwargs) -> Iterable[np.ndarray]:
3939
# This is model-specific, so that different models can have specialized implementations
4040
yield from self.embed(texts, **kwargs)
4141

42-
def query_embed(self, query: str, **kwargs) -> np.ndarray:
42+
def query_embed(self, query: Union[str, Iterable[str]], **kwargs) -> Iterable[np.ndarray]:
4343
"""
44-
Embeds a query
44+
Embeds queries
4545
4646
Args:
47-
query (str): The query to search for.
47+
query (Union[str, Iterable[str]]): The query to embed, or an iterable e.g. list of queries.
4848
4949
Returns:
50-
np.ndarray: The embeddings.
50+
Iterable[np.ndarray]: The embeddings.
5151
"""
5252

5353
# This is model-specific, so that different models can have specialized implementations
54-
query_embedding = list(self.embed([query], **kwargs))[0]
55-
return query_embedding
54+
if isinstance(query, str):
55+
yield from self.embed([query], **kwargs)
56+
if isinstance(query, Iterable):
57+
yield from self.embed(query, **kwargs)

0 commit comments

Comments
 (0)