Skip to content

Commit 4b46818

Browse files
authored
[xh]Qdrant integration for data loader and exporter (mage-ai#4081)
* add qdrant * use text for query and export * address comments
1 parent 041a624 commit 4b46818

File tree

11 files changed

+310
-0
lines changed

11 files changed

+310
-0
lines changed
+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
---
2+
title: "Qdrant"
3+
sidebarTitle: "Qdrant"
4+
---
5+
6+
## Credentials
7+
8+
Open the file named `io_config.yaml` at the root of your Mage project and enter qdrant required fields:
9+
10+
```yaml
11+
version: 0.1.1
12+
default:
13+
QDRANT_COLLECTION: collection_name
14+
QDRANT_PATH: path of the qdrant persisitant storage
15+
```
16+
17+
## Using Python block
18+
19+
1. Create a new pipeline or open an existing pipeline.
20+
2. Add a data loader or data exporter using the Qdrant template under the "Databases" category.
21+
Both the data loader and exporter use SentenceTransformer 'all-MiniLM-L6-v2' as the default embedding function.
22+
3. Add your customized code into the loader, exporter or add extra transformer blocks.
23+
4. Run the block.
24+
25+
## Available functions
26+
27+
- Qdrant data loader arguments:
28+
- limit_results (int): Number of results to return.
29+
- query_vector (List): vector lit used for query.
30+
- collection_name (str): name of the collection. Default to use the name defined in io_config.yaml.
31+
32+
- Qdrant data exporter arguments:
33+
- df (DataFrame): Data to export.
34+
- document_column (str): Column name containinng documents to export.
35+
- id_column (str): Column name of the id. Default will use index in df.
36+
- vector_column (str): Column name of the vector. Will use default encoder to auto generate query vector.
37+
- collection_name (str): name of the collection. Deafult to use the name defined in io_config.yaml.
38+
- vector_size (int): dimension size of vector.
39+
- distance (models.Distance): distance metric to use.
40+
41+
At the same time there is `create_collection` function can be used in your block to create new collection.

docs/mint.json

+1
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,7 @@
356356
"integrations/databases/MySQL",
357357
"integrations/databases/Pinot",
358358
"integrations/databases/PostgreSQL",
359+
"integrations/databases/Qdrant",
359360
"integrations/databases/Redshift",
360361
"integrations/databases/S3",
361362
"integrations/databases/Snowflake",

mage_ai/data_preparation/templates/constants.py

+14
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,13 @@
227227
name='PostgreSQL',
228228
path='data_loaders/postgres.py',
229229
),
230+
dict(
231+
block_type=BlockType.DATA_LOADER,
232+
groups=[GROUP_DATABASES],
233+
language=BlockLanguage.PYTHON,
234+
name='Qdrant',
235+
path='data_loaders/qdrant.py',
236+
),
230237
dict(
231238
block_type=BlockType.DATA_LOADER,
232239
description='Fetch data from an API request.',
@@ -587,6 +594,13 @@
587594
name='PostgreSQL',
588595
path='data_exporters/postgres.py',
589596
),
597+
dict(
598+
block_type=BlockType.DATA_EXPORTER,
599+
groups=[GROUP_DATABASES],
600+
language=BlockLanguage.PYTHON,
601+
name='Qdrant',
602+
path='data_exporters/qdrant.py',
603+
),
590604
# Sensors
591605
dict(
592606
block_type=BlockType.SENSOR,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from os import path
2+
3+
from pandas import DataFrame
4+
5+
from mage_ai.settings.repo import get_repo_path
6+
from mage_ai.io.config import ConfigFileLoader
7+
from mage_ai.io.qdrant import Qdrant
8+
from sentence_transformers import SentenceTransformer
9+
10+
if 'data_exporter' not in globals():
11+
from mage_ai.data_preparation.decorators import data_exporter
12+
13+
14+
@data_exporter
15+
def export_data_to_qdrant(df: DataFrame, **kwargs) -> None:
16+
"""
17+
Template to write data into Qdrant.
18+
"""
19+
config_path = path.join(get_repo_path(), 'io_config.yaml')
20+
config_profile = 'default'
21+
# Update following collection name or the default value in io_config will be used.
22+
collection_name = 'new_colletion'
23+
# Column contains the document to save.
24+
document_column = 'payload'
25+
26+
Qdrant.with_config(ConfigFileLoader(config_path, config_profile)).export(
27+
df,
28+
collection_name=collection_name,
29+
document_column=document_column,
30+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
{% extends "data_loaders/default.jinja" %}
2+
{% block imports %}
3+
from mage_ai.settings.repo import get_repo_path
4+
from mage_ai.io.config import ConfigFileLoader
5+
from mage_ai.io.qdrant import Qdrant
6+
from sentence_transformers import SentenceTransformer
7+
from os import path
8+
{{ super() -}}
9+
{% endblock %}
10+
11+
DEFAULT_MODEL = 'all-MiniLM-L6-v2'
12+
13+
{% block content %}
14+
@data_loader
15+
def load_data_from_qdrant(*args, **kwargs):
16+
"""
17+
Template to load data from Qdrant.
18+
"""
19+
# Use all-MiniLM-L6-v2 embedding model as default.
20+
encoder = SentenceTransformer(DEFAULT_MODEL)
21+
# Generate vector for query.
22+
query_vector = encoder.encode('Test query').tolist()
23+
# number of results to return.
24+
limit_results = 3
25+
config_path = path.join(get_repo_path(), 'io_config.yaml')
26+
config_profile = 'default'
27+
collection_name = 'test_collection'
28+
29+
return Qdrant.with_config(ConfigFileLoader(config_path, config_profile)).load(
30+
limit_results=limit_results,
31+
query_vector=query_vector,
32+
collection_name=collection_name)
33+
{% endblock %}

mage_ai/data_preparation/templates/repo/io_config.yaml

+3
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,9 @@ default:
8383
POSTGRES_PASSWORD: password
8484
POSTGRES_HOST: hostname
8585
POSTGRES_PORT: 5432
86+
# Qdrant
87+
QDRANT_COLLECTION: collection
88+
QDRANT_PATH: path
8689
# Redshift
8790
REDSHIFT_SCHEMA: public # Optional
8891
REDSHIFT_DBNAME: redshift_db_name

mage_ai/io/base.py

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class DataSource(str, Enum):
2929
OPENSEARCH = 'opensearch'
3030
PINOT = 'pinot'
3131
POSTGRES = 'postgres'
32+
QDRANT = 'qdrant'
3233
REDSHIFT = 'redshift'
3334
S3 = 's3'
3435
SNOWFLAKE = 'snowflake'

mage_ai/io/config.py

+6
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ class ConfigKey(str, Enum):
8787
PINOT_SCHEME = 'PINOT_SCHEME'
8888
PINOT_USER = 'PINOT_USER'
8989

90+
QDRANT_COLLECTION = 'QDRANT_COLLECTION'
91+
QDRANT_PATH = 'QDRANT_PATH'
92+
9093
POSTGRES_CONNECTION_METHOD = 'POSTGRES_CONNECTION_METHOD'
9194
POSTGRES_CONNECT_TIMEOUT = 'POSTGRES_CONNECT_TIMEOUT'
9295
POSTGRES_DBNAME = 'POSTGRES_DBNAME'
@@ -335,6 +338,7 @@ class VerboseConfigKey(str, Enum):
335338
REDSHIFT = 'Redshift'
336339
SNOWFLAKE = 'Snowflake'
337340
SPARK = 'Spark'
341+
QDRANT = 'Qdrant'
338342

339343

340344
class ConfigFileLoader(BaseConfigLoader):
@@ -410,6 +414,8 @@ class ConfigFileLoader(BaseConfigLoader):
410414
ConfigKey.POSTGRES_PORT: (VerboseConfigKey.POSTGRES, 'port'),
411415
ConfigKey.POSTGRES_SCHEMA: (VerboseConfigKey.POSTGRES, 'schema'),
412416
ConfigKey.POSTGRES_USER: (VerboseConfigKey.POSTGRES, 'user'),
417+
ConfigKey.QDRANT_COLLECTION: (VerboseConfigKey.QDRANT, 'collection'),
418+
ConfigKey.QDRANT_PATH: (VerboseConfigKey.QDRANT, 'path'),
413419
ConfigKey.SNOWFLAKE_ACCOUNT: (VerboseConfigKey.SNOWFLAKE, 'account'),
414420
ConfigKey.SNOWFLAKE_DEFAULT_DB: (VerboseConfigKey.SNOWFLAKE, 'database'),
415421
ConfigKey.SNOWFLAKE_DEFAULT_SCHEMA: (VerboseConfigKey.SNOWFLAKE, 'schema'),

mage_ai/io/qdrant.py

+173
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
from typing import List
2+
3+
from pandas import DataFrame
4+
from qdrant_client import QdrantClient
5+
from qdrant_client.http import models
6+
from sentence_transformers import SentenceTransformer
7+
8+
from mage_ai.io.base import BaseIO
9+
from mage_ai.io.config import BaseConfigLoader, ConfigKey
10+
11+
DEFAULT_EMBEDDING_MODEL = 'all-MiniLM-L6-v2'
12+
13+
14+
class Qdrant(BaseIO):
15+
def __init__(
16+
self,
17+
collection: str,
18+
path: str = None,
19+
verbose: bool = True,
20+
**kwargs,) -> None:
21+
"""
22+
Initializes connection to qdrant db.
23+
"""
24+
super().__init__(verbose=verbose)
25+
self.collection = collection
26+
self.path = path
27+
self.open()
28+
29+
@classmethod
30+
def with_config(cls, config: BaseConfigLoader) -> 'Qdrant':
31+
return cls(
32+
collection=config[ConfigKey.QDRANT_COLLECTION],
33+
path=config[ConfigKey.QDRANT_PATH],
34+
)
35+
36+
def create_collection(
37+
self,
38+
vector_size: int,
39+
distance: models.Distance = None,
40+
collection_name: str = None):
41+
"""
42+
Create collection in qdrant db.
43+
Args:
44+
vector_size (int): dimension size of the vector.
45+
distance (models.Distance): distance metric to use.
46+
collection_name (str): name of the collection.
47+
Defaults to the name defined in io_config.yaml.
48+
Returns:
49+
collection created.
50+
"""
51+
collection_name = collection_name or self.collection
52+
distance = distance or models.Distance.COSINE
53+
return self.client.create_collection(
54+
collection_name=collection_name,
55+
vectors_config=models.VectorParams(
56+
size=vector_size,
57+
distance=distance),
58+
)
59+
60+
def load(
61+
self,
62+
limit_results: int,
63+
query_vector: List,
64+
collection_name: str = None,
65+
**kwargs,
66+
) -> DataFrame:
67+
"""
68+
Loads the data from Qdrant with query_vector.
69+
Args:
70+
limit_results (int): Number of results to return.
71+
query_vector (List): vector list used to query.
72+
collection_name (str): name of the collection.
73+
Defaults to the name defined in io_config.yaml.
74+
Returns:
75+
DataFrame: Data frame object loaded with data from qdrant
76+
"""
77+
# Assume collection is already created and exists.
78+
collection_name = collection_name or self.collection
79+
80+
hitted_results = self.client.search(
81+
collection_name=collection_name,
82+
query_vector=query_vector,
83+
limit=limit_results,
84+
with_vectors=True,
85+
)
86+
87+
output_df = {}
88+
output_df['id'] = [hit.id for hit in hitted_results]
89+
output_df['payload'] = [hit.payload for hit in hitted_results]
90+
output_df['score'] = [hit.score for hit in hitted_results]
91+
output_df['vector'] = [hit.vector for hit in hitted_results]
92+
93+
return DataFrame.from_dict(output_df)
94+
95+
def export(
96+
self,
97+
df: DataFrame,
98+
document_column: str,
99+
id_column: str = None,
100+
vector_column: str = None,
101+
collection_name: str = None,
102+
vector_size: int = None,
103+
distance: models.Distance = None,
104+
**kwargs,
105+
) -> None:
106+
"""
107+
Save data into Qdrant.
108+
Args:
109+
df (DataFrame): Data to export.
110+
document_column (str): Column name containinng documents to export.
111+
id_column (str): Column name of the id. Default will use index in df.
112+
vector_column (str): Column name of the vector. Will use default
113+
encoder to auto generate query vector to auto generate query vector.
114+
collection_name (str): name of the collection.
115+
vector_size (int): dimension size of vector.
116+
distance (models.Distance): distance metric to use.
117+
"""
118+
collection_name = collection_name or self.collection
119+
encoder = SentenceTransformer(DEFAULT_EMBEDDING_MODEL)
120+
121+
try:
122+
self.client.get_collection(collection_name)
123+
except ValueError:
124+
print(f'Creating collection: {collection_name}')
125+
self.create_collection(
126+
vector_size=vector_size or encoder.get_sentence_embedding_dimension(),
127+
distance=distance,
128+
collection_name=collection_name,
129+
)
130+
131+
payloads = df[document_column].tolist()
132+
if id_column is None:
133+
ids = [x for x in df.index.tolist()]
134+
else:
135+
ids = df[id_column].tolist()
136+
if vector_column is None:
137+
vectors = [encoder.encode(str(x)).tolist() for x in payloads]
138+
else:
139+
vectors = df[vector_column].tolist()
140+
141+
self.client.upsert(
142+
collection_name=collection_name,
143+
points=models.Batch(
144+
ids=ids,
145+
payloads=payloads,
146+
vectors=vectors,
147+
),
148+
)
149+
150+
def __del__(self):
151+
self.close()
152+
153+
def __enter__(self):
154+
self.open()
155+
return self
156+
157+
def __exit__(self, *args):
158+
self.close()
159+
160+
def open(self) -> None:
161+
"""
162+
Opens an underlying connection to Qdrannt.
163+
"""
164+
if self.path is None:
165+
self.client = QdrantClient(':memory:')
166+
else:
167+
self.client = QdrantClient(path=self.path)
168+
169+
def close(self) -> None:
170+
"""
171+
Close the underlying connection to Qdrant.
172+
"""
173+
self.client.close()

requirements.txt

+2
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,9 @@ psycopg2==2.9.3
7676
psycopg2-binary==2.9.3
7777
pydruid==0.6.5
7878
pyodbc==4.0.35
79+
qdrant-client>=1.6.9
7980
redshift-connector==2.0.909
81+
sentence-transformers>=2.2.2
8082
snowflake-connector-python==3.2.1
8183
sshtunnel==0.4.0
8284
tables==3.7.0

setup.py

+6
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ def readme():
9292
'psycopg2-binary==2.9.3',
9393
'sshtunnel==0.4.0',
9494
],
95+
'qdrant': [
96+
'qdrant-client>=1.6.9',
97+
'sentence-transformers>=2.2.2',
98+
],
9599
'redshift': [
96100
'boto3==1.26.60',
97101
'redshift-connector==2.0.909',
@@ -177,8 +181,10 @@ def readme():
177181
'pydruid==0.6.5',
178182
'pymongo==4.3.3',
179183
'pyodbc==4.0.35',
184+
'qdrant-client>=1.6.9',
180185
'redshift-connector==2.0.909',
181186
'requests_aws4auth==1.1.2',
187+
'sentence-transformers>=2.2.2',
182188
'snowflake-connector-python==3.2.1',
183189
'sshtunnel==0.4.0',
184190
'stomp.py==8.1.0',

0 commit comments

Comments
 (0)