-
Notifications
You must be signed in to change notification settings - Fork 552
/
Copy path06_retrieval.py
66 lines (51 loc) · 5.7 KB
/
06_retrieval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from typing import Dict, List, Union
import numpy as np
from elasticsearch import Elasticsearch, exceptions
# Sample embedding vectors kept as a fallback. `search` uses the first vector
# (SAMPLE__EMBEDDINGS[0]) as the query embedding when it is called without a
# query embedding, or with None / an empty list; the second vector is not read
# by the code shown here.
SAMPLE__EMBEDDINGS = [
[-0.1465761959552765, -0.4822517931461334, 0.07130702584981918, -0.25872930884361267, -0.1563894897699356, 0.16641047596931458, 0.24484659731388092, 0.2410498708486557, 0.008032954297959805, 0.17045290768146515, -0.009397129528224468, 0.09619587659835815, -0.22729521989822388, 0.10254761576652527, 0.016890447586774826, -0.13290464878082275, 0.11240798979997635, -0.11204371601343155, -0.057132963091135025, -0.011206787079572678, -0.007982085458934307, 0.279083788394928, 0.20115645229816437, -0.1427406221628189, -0.19398854672908783, -0.035979654639959335, 0.20723149180412292, 0.29891034960746765, 0.21407313644886017, 0.09746530652046204, 0.1671638935804367, 0.08161208778619766, 0.3090828061103821, -0.20648667216300964, 0.48498260974884033, -0.12691514194011688, 0.518856406211853, -0.26291757822036743, -0.0949832871556282, 0.09556109458208084, -0.20844918489456177, 0.2685297429561615, 0.053442806005477905, 0.05103180184960365, 0.1029752567410469, 0.04935301095247269, -0.11679927259683609, -0.012528933584690094, -0.08489680290222168, 0.013589601963758469, -0.32059246301651, 0.10357264429330826, -0.09533575177192688, 0.02984568662941456, 0.2793693542480469, -0.2653750777244568, -0.24152781069278717, -0.3563413619995117, 0.09674381464719772, -0.26155123114585876, -0.1397126317024231, -0.009133181534707546, 0.05972130224108696, -0.10438819974660873, 0.21889159083366394, 0.0694752112030983, -0.1312003880739212, -0.31072548031806946, -0.002836169209331274, 0.2468366175889969, 0.09420009702444077, 0.1284026801586151, -0.03227006644010544, -0.012532072141766548, 0.6650756597518921, -0.14863784611225128, 0.005239118821918964, -0.3317912817001343, 0.16372767090797424, -0.20166568458080292, 0.029721004888415337, -0.18536655604839325, -0.3608534038066864, -0.18234892189502716, 0.019248824566602707, 0.25257956981658936, 0.09671413153409958, 0.15569280087947845, -0.38228726387023926, 0.37017977237701416, 0.03356296569108963, -0.21182948350906372, 0.48848846554756165, 
0.18350018560886383, -0.23519110679626465, -0.17464864253997803], [-0.18246106803417206, -0.36036479473114014, 0.3282334506511688, -0.230922132730484, 0.09600532799959183, 0.6859422326087952, 0.0581890344619751, 0.4913463294506073, 0.1536773443222046, -0.2965141832828522, 0.08466599136590958, 0.319297194480896, -0.15651769936084747, -0.043428342789411545, 0.014402368105947971, 0.16681505739688873, 0.22521673142910004, -0.2715776264667511, -0.11033261567354202, -0.04398636147379875, 0.3480629622936249, 0.11897992342710495, 0.8724615573883057, 0.10258488357067108, -0.5719427466392517, -0.03029855526983738, 0.23351268470287323, 0.20660561323165894, 0.575685441493988, -0.12116186320781708, 0.18459142744541168, -0.12865227460861206, 0.3948173522949219, -0.34464019536972046, 0.6699116230010986, -0.45167359709739685, 1.1505522727966309, -0.4498964548110962, -0.3248189687728882, -0.29674994945526123, -0.3570491075515747, 0.5436431765556335, 0.49576905369758606, -0.11180296540260315, -0.02045607566833496, -0.22768598794937134, -0.37912657856941223, -0.30414703488349915, -0.48289090394973755, -0.04158346354961395, -0.3547952473163605, 0.0687602087855339, 0.041512664407491684, 0.33524179458618164, 0.21826978027820587, -0.443082332611084, -0.5049593448638916, -0.5298929810523987, -0.02618088759481907, -0.2748631536960602, -0.1986193209886551, 0.35475826263427734, 0.22456413507461548, -0.29532068967819214, 0.25150877237319946, 0.243370920419693, -0.29938358068466187, -0.2128247618675232, -0.15292000770568848, -0.14813245832920074, -0.06183856353163719, -0.1251668632030487, 0.14256533980369568, -0.22781267762184143, 0.8101184964179993, 0.19796361029148102, 0.09104947745800018, -0.4860817790031433, 0.3078012764453888, -0.27373194694519043, 0.11800770461559296, -0.45869407057762146, 0.09508189558982849, -0.23971715569496155, -0.27427223324775696, 0.5139415264129639, 0.1871502846479416, 0.06647063046693802, -0.4054469168186188, 0.4751380681991577, 0.17067894339561462, 
0.12443914264440536, 0.3577817678451538, 0.10574143379926682, -0.3181760311126709, -0.23804502189159393]
]
@data_loader
def search(*args, **kwargs) -> List[Dict]:
    """Run a ``script_score`` vector search against an Elasticsearch index.

    Positional args:
        args[0] (optional): the query embedding, ``Union[List[float], np.ndarray]``.
            If it is omitted, None, or empty, ``SAMPLE__EMBEDDINGS[0]`` is used.

    Keyword args:
        connection_string: Elasticsearch URL (default ``'http://localhost:9200'``).
        index_name: index to search (default ``'documents'``).
        source: Painless script that scores each document (default: cosine
            similarity between ``params.query_vector`` and the ``'embedding'``
            field, plus 1.0).
        top_k: maximum number of hits to request (default 5).
        chunk_column: document field returned for each hit (default ``'content'``).

    Returns:
        The ``chunk_column`` value of each hit, in the order Elasticsearch
        returned them. On any error during the search, the error is printed
        and ``[]`` is returned.
    """
    connection_string = kwargs.get('connection_string', 'http://localhost:9200')
    index_name = kwargs.get('index_name', 'documents')
    # The +1.0 keeps scores non-negative: cosineSimilarity ranges over [-1, 1]
    # and script_score does not accept negative scores.
    source = kwargs.get('source', "cosineSimilarity(params.query_vector, 'embedding') + 1.0")
    top_k = kwargs.get('top_k', 5)
    chunk_column = kwargs.get('chunk_column', 'content')

    query_embedding = args[0] if args else None
    # Convert ndarrays BEFORE the truthiness test below: `not ndarray` raises
    # ValueError for arrays with more than one element. The request body also
    # needs a plain JSON-serializable list.
    if isinstance(query_embedding, np.ndarray):
        query_embedding = query_embedding.tolist()
    if not query_embedding:
        query_embedding = SAMPLE__EMBEDDINGS[0]

    script_query = {
        "script_score": {
            "query": {"match_all": {}},
            "script": {
                "source": source,
                "params": {"query_vector": query_embedding},
            },
        }
    }
    print("Sending script query:", script_query)

    es_client = Elasticsearch(connection_string)
    try:
        response = es_client.search(
            index=index_name,
            body={
                "size": top_k,
                "query": script_query,
                # Only fetch the field we return, not the stored embedding.
                "_source": [chunk_column],
            },
        )
        print("Raw response from Elasticsearch:", response)
        return [hit['_source'][chunk_column] for hit in response['hits']['hits']]
    except exceptions.BadRequestError as e:
        print(f"BadRequestError: {e.info}")
        return []
    except Exception as e:
        # Deliberate best-effort: any other failure (connection error, a hit
        # missing `chunk_column`, ...) is reported and yields no results.
        print(f"Unexpected error: {e}")
        return []
    finally:
        # A fresh client is created on every call; close it so its pooled
        # HTTP connections are released instead of leaking.
        es_client.close()