Commit b507f4d

Merge pull request #762 from asaintsever/asaintsever/feat/opensearch-semantic-vector-store
Add new OpenSearch Vector Store with embeddings support
2 parents 52da03d + 6b76dc4 commit b507f4d

File tree

4 files changed: +185 −3 lines changed

pyproject.toml
src/vanna/opensearch/__init__.py
src/vanna/opensearch/opensearch_vector_semantic.py
tests/test_imports.py

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -33,7 +33,7 @@ bigquery = ["google-cloud-bigquery"]
 snowflake = ["snowflake-connector-python"]
 duckdb = ["duckdb"]
 google = ["google-generativeai", "google-cloud-aiplatform"]
-all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client", "azure-search-documents", "azure-identity", "azure-common", "faiss-cpu", "boto", "boto3", "botocore", "langchain_core", "langchain_postgres", "xinference-client"]
+all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client", "azure-search-documents", "azure-identity", "azure-common", "faiss-cpu", "boto", "boto3", "botocore", "langchain_core", "langchain_postgres", "langchain-community", "langchain-huggingface", "xinference-client"]
 test = ["tox"]
 chromadb = ["chromadb"]
 openai = ["openai"]
@@ -47,7 +47,7 @@ ollama = ["ollama", "httpx"]
 qdrant = ["qdrant-client", "fastembed"]
 vllm = ["vllm"]
 pinecone = ["pinecone-client", "fastembed"]
-opensearch = ["opensearch-py", "opensearch-dsl"]
+opensearch = ["opensearch-py", "opensearch-dsl", "langchain-community", "langchain-huggingface"]
 hf = ["transformers"]
 milvus = ["pymilvus[model]"]
 bedrock = ["boto3", "botocore"]
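
The opensearch extra (and the all extra) now also pulls in langchain-community and langchain-huggingface, because the new semantic store wraps LangChain's OpenSearchVectorSearch and, when no embedding_function is configured, falls back to HuggingFace embeddings. A minimal sketch of that default embedding path, reusing the model name hard-coded in the new module (illustrative only):

# Sketch: the default embedding setup enabled by the updated "opensearch" extra.
from langchain_huggingface import HuggingFaceEmbeddings

embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
vector = embeddings.embed_query("example sentence")  # returns a list[float]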

src/vanna/opensearch/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1 +1,2 @@
 from .opensearch_vector import OpenSearch_VectorStore
+from .opensearch_vector_semantic import OpenSearch_Semantic_VectorStore
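
Because the package __init__ now re-exports the class, the package shortcut and the full module path resolve to the same object (the shortcut path is what tests/test_imports.py below exercises). A quick illustration:

# Both import paths point at the same class after this change.
from vanna.opensearch import OpenSearch_Semantic_VectorStore
from vanna.opensearch.opensearch_vector_semantic import (
    OpenSearch_Semantic_VectorStore as SameClass,
)

assert OpenSearch_Semantic_VectorStore is SameClass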

src/vanna/opensearch/opensearch_vector_semantic.py

Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
+import json
+
+import pandas as pd
+from langchain_community.vectorstores import OpenSearchVectorSearch
+
+from ..base import VannaBase
+from ..utils import deterministic_uuid
+
+
+class OpenSearch_Semantic_VectorStore(VannaBase):
+    def __init__(self, config=None):
+        VannaBase.__init__(self, config=config)
+        if config is None:
+            config = {}
+
+        if "embedding_function" in config:
+            self.embedding_function = config.get("embedding_function")
+        else:
+            from langchain_huggingface import HuggingFaceEmbeddings
+            self.embedding_function = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+
+        self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
+        self.n_results_documentation = config.get("n_results_documentation", config.get("n_results", 10))
+        self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))
+
+        self.document_index = config.get("es_document_index", "vanna_document_index")
+        self.ddl_index = config.get("es_ddl_index", "vanna_ddl_index")
+        self.question_sql_index = config.get("es_question_sql_index", "vanna_questions_sql_index")
+
+        self.log(f"OpenSearch_Semantic_VectorStore initialized with document_index: {self.document_index}, ddl_index: {self.ddl_index}, question_sql_index: {self.question_sql_index}")
+
+        es_urls = config.get("es_urls", "https://localhost:9200")
+        ssl = config.get("es_ssl", True)
+        verify_certs = config.get("es_verify_certs", True)
+
+        if "es_user" in config:
+            auth = (config["es_user"], config["es_password"])
+        else:
+            auth = None
+
+        headers = config.get("es_headers", None)
+        timeout = config.get("es_timeout", 60)
+        max_retries = config.get("es_max_retries", 10)
+
+        common_args = {
+            "opensearch_url": es_urls,
+            "embedding_function": self.embedding_function,
+            "engine": "faiss",
+            "http_auth": auth,
+            "use_ssl": ssl,
+            "verify_certs": verify_certs,
+            "timeout": timeout,
+            "max_retries": max_retries,
+            "retry_on_timeout": True,
+            "headers": headers,
+        }
+
+        self.documentation_store = OpenSearchVectorSearch(index_name=self.document_index, **common_args)
+        self.ddl_store = OpenSearchVectorSearch(index_name=self.ddl_index, **common_args)
+        self.sql_store = OpenSearchVectorSearch(index_name=self.question_sql_index, **common_args)
+
+    def add_ddl(self, ddl: str, **kwargs) -> str:
+        _id = deterministic_uuid(ddl) + "-ddl"
+        self.ddl_store.add_texts(texts=[ddl], ids=[_id], **kwargs)
+        return _id
+
+    def add_documentation(self, documentation: str, **kwargs) -> str:
+        _id = deterministic_uuid(documentation) + "-doc"
+        self.documentation_store.add_texts(texts=[documentation], ids=[_id], **kwargs)
+        return _id
+
+    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
+        question_sql_json = json.dumps(
+            {
+                "question": question,
+                "sql": sql,
+            },
+            ensure_ascii=False,
+        )
+
+        _id = deterministic_uuid(question_sql_json) + "-sql"
+        self.sql_store.add_texts(texts=[question_sql_json], ids=[_id], **kwargs)
+        return _id
+
+    def get_related_ddl(self, question: str, **kwargs) -> list:
+        documents = self.ddl_store.similarity_search(query=question, k=self.n_results_ddl)
+        return [document.page_content for document in documents]
+
+    def get_related_documentation(self, question: str, **kwargs) -> list:
+        documents = self.documentation_store.similarity_search(query=question, k=self.n_results_documentation)
+        return [document.page_content for document in documents]
+
+    def get_similar_question_sql(self, question: str, **kwargs) -> list:
+        documents = self.sql_store.similarity_search(query=question, k=self.n_results_sql)
+        return [json.loads(document.page_content) for document in documents]
+
+    def get_training_data(self, **kwargs) -> pd.DataFrame:
+        data = []
+        query = {
+            "query": {
+                "match_all": {}
+            }
+        }
+
+        indices = [
+            {"index": self.document_index, "type": "documentation"},
+            {"index": self.question_sql_index, "type": "sql"},
+            {"index": self.ddl_index, "type": "ddl"},
+        ]
+
+        # Use documentation_store.client consistently for search on all indices
+        opensearch_client = self.documentation_store.client
+
+        for index_info in indices:
+            index_name = index_info["index"]
+            training_data_type = index_info["type"]
+            scroll = '1m'  # keep scroll context for 1 minute
+            response = opensearch_client.search(
+                index=index_name,
+                ignore_unavailable=True,
+                body=query,
+                scroll=scroll,
+                size=1000
+            )
+
+            scroll_id = response.get('_scroll_id')
+
+            while scroll_id:
+                hits = response['hits']['hits']
+                if not hits:
+                    break  # No more hits, exit loop
+
+                for hit in hits:
+                    source = hit['_source']
+                    if training_data_type == "sql":
+                        try:
+                            doc_dict = json.loads(source['text'])
+                            content = doc_dict.get("sql")
+                            question = doc_dict.get("question")
+                        except json.JSONDecodeError as e:
+                            self.log(f"Skipping row with custom_id {hit['_id']} due to JSON parsing error: {e}", "Error")
+                            continue
+                    else:  # documentation or ddl
+                        content = source['text']
+                        question = None
+
+                    data.append({
+                        "id": hit["_id"],
+                        "training_data_type": training_data_type,
+                        "question": question,
+                        "content": content,
+                    })
+
+                # Get next batch of results, using documentation_store.client.scroll
+                response = opensearch_client.scroll(scroll_id=scroll_id, scroll=scroll)
+                scroll_id = response.get('_scroll_id')
+
+        return pd.DataFrame(data)
+
+    def remove_training_data(self, id: str, **kwargs) -> bool:
+        try:
+            if id.endswith("-sql"):
+                return self.sql_store.delete(ids=[id], **kwargs)
+            elif id.endswith("-ddl"):
+                return self.ddl_store.delete(ids=[id], **kwargs)
+            elif id.endswith("-doc"):
+                return self.documentation_store.delete(ids=[id], **kwargs)
+            else:
+                return False
+        except Exception as e:
+            self.log(f"Error deleting training data: {e}", "Error")
+            return False
+
+    def generate_embedding(self, data: str, **kwargs) -> list[float]:
+        pass
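
VannaBase supplies the rest of the text-to-SQL flow, so this store is meant to be combined with an LLM class in the usual Vanna mixin pattern. A hedged usage sketch, assuming an OpenSearch cluster on the default URL and OpenAI_Chat as the LLM half; the OpenAI key and model values are placeholders, and the es_* / n_results keys mirror the ones read in __init__ above:

# Sketch only, not part of this commit: wiring the semantic store to an LLM.
from vanna.openai import OpenAI_Chat
from vanna.opensearch import OpenSearch_Semantic_VectorStore


class MyVanna(OpenSearch_Semantic_VectorStore, OpenAI_Chat):
    def __init__(self, config=None):
        OpenSearch_Semantic_VectorStore.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)


vn = MyVanna(config={
    "api_key": "sk-...",                  # placeholder OpenAI credentials
    "model": "gpt-4o-mini",               # placeholder chat model name
    "es_urls": "https://localhost:9200",  # OpenSearch endpoint
    "es_user": "admin",                   # optional basic auth
    "es_password": "admin",
    "es_verify_certs": False,             # e.g. a local cluster with a self-signed cert
    "n_results": 10,                      # shared default for ddl/doc/sql retrieval
})

# Training data is embedded and written to the three OpenSearch indices.
vn.train(ddl="CREATE TABLE customers (id INT, name TEXT, country TEXT)")
vn.train(documentation="Customer country codes follow ISO 3166-1 alpha-2.")

# Retrieval runs similarity_search against the per-index stores defined above.
print(vn.generate_sql("How many customers are in France?"))

Passing an embedding_function in the config swaps out the default HuggingFace all-MiniLM-L6-v2 embeddings for any LangChain-compatible embedding object.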

tests/test_imports.py

Lines changed: 7 additions & 1 deletion
@@ -18,6 +18,9 @@ def test_regular_imports():
     from vanna.openai.openai_chat import OpenAI_Chat
     from vanna.openai.openai_embeddings import OpenAI_Embeddings
     from vanna.opensearch.opensearch_vector import OpenSearch_VectorStore
+    from vanna.opensearch.opensearch_vector_semantic import (
+        OpenSearch_Semantic_VectorStore,
+    )
     from vanna.pgvector.pgvector import PG_VectorStore
     from vanna.pinecone.pinecone_vector import PineconeDB_VectorStore
     from vanna.qdrant.qdrant import Qdrant_VectorStore
@@ -44,7 +47,10 @@ def test_shortcut_imports():
     from vanna.mistral import Mistral
     from vanna.ollama import Ollama
     from vanna.openai import OpenAI_Chat, OpenAI_Embeddings
-    from vanna.opensearch import OpenSearch_VectorStore
+    from vanna.opensearch import (
+        OpenSearch_Semantic_VectorStore,
+        OpenSearch_VectorStore,
+    )
     from vanna.pgvector import PG_VectorStore
     from vanna.pinecone import PineconeDB_VectorStore
     from vanna.qdrant import Qdrant_VectorStore
