|
34 | 34 | "outputs": [], |
35 | 35 | "source": [ |
36 | 36 | "import pandas as pd\n", |
37 | | - "import uuid\n", |
38 | 37 | "from qdrant_client import QdrantClient\n", |
39 | | - "from qdrant_client.http import models\n", |
40 | | - "from qdrant_client.http.models import PointStruct" |
| 38 | + "from qdrant_client.http import models" |
41 | 39 | ] |
42 | 40 | }, |
43 | 41 | { |
|
71 | 69 | ], |
72 | 70 | "source": [ |
73 | 71 | "import datasets\n", |
| 72 | + "\n", |
74 | 73 | "dataset = datasets.load_dataset(\"KShivendu/dbpedia-entities-openai-1M\", split=\"train[0:100000]\")" |
75 | 74 | ] |
76 | 75 | }, |
|
133 | 132 | } |
134 | 133 | ], |
135 | 134 | "source": [ |
136 | | - "from qdrant_client import QdrantClient\n", |
137 | | - "\n", |
138 | 135 | "# client = QdrantClient(\n", |
139 | | - "# url=\"https://2aaa9439-b209-4ba6-8beb-d0b61dbd9388.us-east-1-0.aws.cloud.qdrant.io:6333\", \n", |
| 136 | + "# url=\"https://2aaa9439-b209-4ba6-8beb-d0b61dbd9388.us-east-1-0.aws.cloud.qdrant.io:6333\",\n", |
140 | 137 | "# api_key=\"FCF8_ADVuSRrtNGeg_rBJvAMJecEDgQhzuXMZGW8F7OzvaC9wYOPeQ\",\n", |
141 | 138 | "# prefer_grpc=True\n", |
142 | 139 | "# )\n", |
|
175 | 172 | "bs = 10000\n", |
176 | 173 | "for i in range(0, len(dataset), bs):\n", |
177 | 174 | " client.upload_collection(\n", |
178 | | - " collection_name=collection_name, \n", |
179 | | - " ids=range(i, i+bs),\n", |
180 | | - " vectors=dataset[i:i+bs][\"openai\"],\n", |
181 | | - " payload=[\n", |
182 | | - " {\"text\": x} for x in dataset[i:i+bs][\"text\"]\n", |
183 | | - " ],\n", |
| 175 | + " collection_name=collection_name,\n", |
| 176 | + " ids=range(i, i + bs),\n", |
| 177 | + " vectors=dataset[i : i + bs][\"openai\"],\n", |
| 178 | + " payload=[{\"text\": x} for x in dataset[i : i + bs][\"text\"]],\n", |
184 | 179 | " parallel=10,\n", |
185 | 180 | " )" |
186 | 181 | ] |
|
203 | 198 | ], |
204 | 199 | "source": [ |
205 | 200 | "client.update_collection(\n", |
206 | | - " collection_name=f\"{collection_name}\",\n", |
207 | | - " optimizer_config=models.OptimizersConfigDiff(\n", |
208 | | - " indexing_threshold=20000\n", |
209 | | - " )\n", |
| 201 | + " collection_name=f\"{collection_name}\", optimizer_config=models.OptimizersConfigDiff(indexing_threshold=20000)\n", |
210 | 202 | ")" |
211 | 203 | ] |
212 | 204 | }, |
|
289 | 281 | "source": [ |
290 | 282 | "import random\n", |
291 | 283 | "from random import randint\n", |
| 284 | + "\n", |
292 | 285 | "random.seed(37)\n", |
293 | 286 | "\n", |
294 | 287 | "query_indices = [randint(0, len(dataset)) for _ in range(100)]\n", |
|
304 | 297 | "source": [ |
305 | 298 | "## Add Gaussian noise to any vector\n", |
306 | 299 | "import numpy as np\n", |
| 300 | + "\n", |
307 | 301 | "np.random.seed(37)\n", |
| 302 | + "\n", |
| 303 | + "\n", |
308 | 304 | "def add_noise(vector, noise=0.05):\n", |
309 | 305 | " return vector + noise * np.random.randn(*vector.shape)" |
310 | 306 | ] |
|
959 | 955 | ], |
960 | 956 | "source": [ |
961 | 957 | "import time\n", |
| 958 | + "\n", |
| 959 | + "\n", |
962 | 960 | "def correct(results, text):\n", |
963 | 961 | " result_texts = [x.payload[\"text\"] for x in results]\n", |
964 | 962 | " return text in result_texts\n", |
|
977 | 975 | " rescore=rescore,\n", |
978 | 976 | " oversampling=oversampling,\n", |
979 | 977 | " )\n", |
980 | | - " )\n", |
| 978 | + " ),\n", |
981 | 979 | " )\n", |
982 | 980 | " correct_results += correct(results, text)\n", |
983 | 981 | " return correct_results\n", |
|
996 | 994 | " start = time.time()\n", |
997 | 995 | " correct_results = count_correct(query_dataset, limit=limit, oversampling=oversampling, rescore=rescore)\n", |
998 | 996 | " end = time.time()\n", |
999 | | - " results.append({\n", |
1000 | | - " \"limit\": limit,\n", |
1001 | | - " \"oversampling\": oversampling,\n", |
1002 | | - " \"rescore\": rescore,\n", |
1003 | | - " \"correct\": correct_results,\n", |
1004 | | - " \"total queries\": len(query_dataset[\"text\"]),\n", |
1005 | | - " \"time\": end - start,\n", |
1006 | | - " })\n", |
| 997 | + " results.append(\n", |
| 998 | + " {\n", |
| 999 | + " \"limit\": limit,\n", |
| 1000 | + " \"oversampling\": oversampling,\n", |
| 1001 | + " \"rescore\": rescore,\n", |
| 1002 | + " \"correct\": correct_results,\n", |
| 1003 | + " \"total queries\": len(query_dataset[\"text\"]),\n", |
| 1004 | + " \"time\": end - start,\n", |
| 1005 | + " }\n", |
| 1006 | + " )\n", |
1007 | 1007 | "\n", |
1008 | 1008 | "results_df = pd.DataFrame(results)\n", |
1009 | 1009 | "results_df" |
|
0 commit comments