-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtools.py
112 lines (92 loc) · 4.37 KB
/
tools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from autogen_core.tools import FunctionTool
from azure.search.documents.models import QueryType, VectorizableTextQuery
from azure.search.documents import SearchClient
from azure.core.credentials import AzureKeyCredential
import os
import logging
from dotenv import load_dotenv, find_dotenv
import json
import base64
# Load settings from a discovered .env file (if any) into the process
# environment at import time; SearchTool reads the Azure Search endpoint and
# key from os.environ.
load_dotenv(find_dotenv())
class SearchTool:
    """Azure AI Search retrieval helpers exposed as ``FunctionTool`` objects.

    Every search runs a hybrid (text + vector) semantic query against the
    index named by ``INDEX_NAME``. Accepted hits are returned to the caller as
    a JSON string, and the figures attached to those hits are base64-decoded
    and stored in ``figure_and_chunk_pairs`` for later use.

    Attributes:
        figure_and_chunk_pairs: The dict passed to the constructor. It is
            mutated in place: ``{chunk_id: {figure_id: image_bytes}}``.
    """

    # Index and semantic configuration queried by every search.
    INDEX_NAME = "image-processing-index"
    SEMANTIC_CONFIGURATION_NAME = "image-processing-semantic-config"
    # Fields requested back from the index for each hit.
    RETRIEVAL_FIELDS = ("ChunkId", "Title", "Chunk", "ChunkFigures")
    # Hits whose semantic reranker score is below this value are discarded.
    MIN_RERANKER_SCORE = 2.5

    def __init__(self, figure_and_chunk_pairs: dict):
        """Keep a reference to the shared figure store.

        Args:
            figure_and_chunk_pairs: Dict that receives the decoded figures of
                every accepted chunk. The same object is stored (not copied),
                so the caller sees the updates.
        """
        self.figure_and_chunk_pairs = figure_and_chunk_pairs

    def search_index(self, queries: list[str], top: int) -> str:
        """Run each query against the index and collect the accepted chunks.

        Args:
            queries: Search terms; each one is sent as both the text query and
                the text to vectorize against the ``ChunkEmbedding`` field.
            top: Number of hits requested per query. The vector query asks for
                ``top * 5`` nearest neighbours.

        Returns:
            A JSON string encoding ``{chunk_id: {"Title": ..., "Chunk": ...}}``
            for every hit that passed the reranker threshold, de-duplicated by
            chunk id across all queries. ``"{}"`` when nothing was accepted or
            ``queries`` is empty.

        Raises:
            KeyError: If ``AIService__AzureSearchOptions__Key`` or
                ``AIService__AzureSearchOptions__Endpoint`` is not set in the
                environment (only looked up when there is at least one query).
        """
        final_results: dict = {}
        if not queries:
            # Nothing to search: skip credential lookup and client creation.
            return json.dumps(final_results)

        # The credential and client do not depend on the query, so build them
        # once and reuse the same connection for every query.
        credential = AzureKeyCredential(
            os.environ["AIService__AzureSearchOptions__Key"]
        )
        with SearchClient(
            endpoint=os.environ["AIService__AzureSearchOptions__Endpoint"],
            index_name=self.INDEX_NAME,
            credential=credential,
        ) as search_client:
            for query in queries:
                vector_query = [
                    VectorizableTextQuery(
                        text=query,
                        k_nearest_neighbors=top * 5,
                        fields="ChunkEmbedding",
                    )
                ]
                # Materialise the paged iterator so the hits can be both
                # logged and processed.
                results = list(
                    search_client.search(
                        top=top,
                        semantic_configuration_name=self.SEMANTIC_CONFIGURATION_NAME,
                        search_text=query,
                        select=",".join(self.RETRIEVAL_FIELDS),
                        vector_queries=vector_query,
                        query_type=QueryType.SEMANTIC,
                        query_language="en-GB",
                    )
                )
                logging.info("Results: %s", results)

                for result in results:
                    self._record_result(result, final_results)

        return json.dumps(final_results)

    def _record_result(self, result, final_results: dict) -> None:
        """Store one hit in ``final_results`` and decode its figures.

        A hit is accepted only if its chunk id has not been accepted already
        and its ``@search.reranker_score`` is at least ``MIN_RERANKER_SCORE``.
        Rejected hits leave both dicts untouched.
        """
        chunk_id = result["ChunkId"]
        if (
            chunk_id not in final_results
            and result["@search.reranker_score"] >= self.MIN_RERANKER_SCORE
        ):
            final_results[chunk_id] = {
                "Title": result["Title"],
                "Chunk": result["Chunk"],
            }

            # Store the figures for later, keyed by figure id. A single pass
            # is enough: each figure is decoded and written exactly once.
            figures = self.figure_and_chunk_pairs.setdefault(chunk_id, {})
            for figure in result["ChunkFigures"]:
                # Convert the base64 image to a bytes object.
                figures[figure["FigureId"]] = base64.b64decode(figure["Data"])

    def rag_search_index(self, search_term: str) -> str:
        """Search the Azure Search index for the given query.

        Sends the single ``search_term`` with ``top=4`` and returns the JSON
        string produced by ``search_index``.
        """
        return self.search_index([search_term], top=4)

    def rat_search_index_breadth_first(self, search_terms: list[str]) -> str:
        """Search the Azure Search index for the given set of queries.

        Breadth-first variant: every term is searched with ``top=1``. Returns
        the JSON string produced by ``search_index``.
        """
        return self.search_index(search_terms, top=1)

    def rat_search_index_depth_first(self, search_terms: list[str]) -> str:
        """Search the Azure Search index for the given set of queries.

        Depth-first variant: every term is searched with ``top=3``. Returns
        the JSON string produced by ``search_index``.
        """
        return self.search_index(search_terms, top=3)

    @property
    def rat_breadth_first_tool(self):
        """A new ``FunctionTool`` wrapping ``rat_search_index_breadth_first``."""
        return FunctionTool(
            self.rat_search_index_breadth_first,
            description="Search the Azure Search index for the given set of queries. Send a minimum of 3 different search terms to the search index to retrieve the relevant information.",
        )

    @property
    def rat_depth_first_tool(self):
        """A new ``FunctionTool`` wrapping ``rat_search_index_depth_first``."""
        return FunctionTool(
            self.rat_search_index_depth_first,
            description="Search the Azure Search index for the given set of queries. Send a minimum of 3 different search terms to the search index to retrieve the relevant information.",
        )

    @property
    def rag_search_tool(self):
        """A new ``FunctionTool`` wrapping ``rag_search_index``."""
        return FunctionTool(
            self.rag_search_index,
            description="Search the Azure Search index for the given query.",
        )