-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpinecone_db.py
60 lines (55 loc) · 1.71 KB
/
pinecone_db.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import pinecone
import os
import numpy as np
from PIL import Image
import torch
import torchvision
from torchvision.transforms import (
Compose,
Resize,
CenterCrop,
ToTensor,
Normalize
)
def process_images(img_dir, model, max_images=9900):
    """Embed up to ``max_images`` JPEG images found in ``img_dir`` using ``model``.

    Each ``.jpg``/``.jpeg`` entry (extension matched case-insensitively) is
    resized to 256, center-cropped to 224x224, converted to a tensor,
    normalized, and passed through ``model`` as a batch of one. Entries with
    any other extension are skipped; scanning stops once ``max_images``
    embeddings have been produced.

    Args:
        img_dir: Directory to scan (non-recursively, in ``os.listdir`` order).
        model: Callable taking a ``(1, 3, 224, 224)`` float tensor and
            returning a tensor whose first row is used as the embedding
            (presumably shape ``(1, D)`` — confirm for other models).
        max_images: Maximum number of images to embed. Defaults to 9900.

    Returns:
        A list of ``{'id': file_name, 'values': [float, ...]}`` dicts, the
        record shape that ``create_index`` passes to ``index.upsert``.
    """
    preprocess = Compose([
        Resize(256),
        CenterCrop(224),
        ToTensor(),
        # NOTE(review): these differ from the usual ImageNet statistics;
        # presumably computed for this image set — confirm.
        Normalize(mean=[0.543, 0.546, 0.451], std=[0.312, 0.265, 0.257]),
    ])
    vectors = []
    # Inference only: no_grad avoids building an autograd graph per image.
    with torch.no_grad():
        for file_name in os.listdir(img_dir):
            if len(vectors) >= max_images:
                break
            # Skip (rather than stop at) non-JPEG entries such as .DS_Store.
            if not file_name.lower().endswith(('.jpg', '.jpeg')):
                continue
            print(len(vectors) + 1)  # progress counter
            img_path = os.path.join(img_dir, file_name)
            # Close the file handle promptly. Convert to RGB so grayscale or
            # CMYK JPEGs yield the 3 channels that Normalize expects.
            with Image.open(img_path) as img:
                batch = preprocess(img.convert('RGB')).unsqueeze(0)
            embedding = model(batch)
            vectors.append({'id': file_name, 'values': embedding[0].tolist()})
    return vectors
def _resolve_pinecone_api_key():
    """Return the Pinecone API key from ``config.api_keys`` or the environment.

    The project-local ``config`` module is imported lazily because it is not
    imported at the top of this file. When it cannot be imported, the
    ``PINECONE_API_KEY`` environment variable is used instead.

    Raises:
        RuntimeError: if ``config`` is unavailable and the environment
            variable is unset.
    """
    try:
        import config  # project-local settings module, if present
    except ImportError:
        try:
            return os.environ["PINECONE_API_KEY"]
        except KeyError:
            raise RuntimeError(
                "Pinecone API key not found: provide config.api_keys"
                "['PINECONE_API_KEY'] or set the PINECONE_API_KEY "
                "environment variable"
            ) from None
    return config.api_keys["PINECONE_API_KEY"]


def create_index(index_name, dim, *, img_dir="img", batch_size=50,
                 environment="northamerica-northeast1-gcp"):
    """Create (or reuse) a Pinecone index and upsert image embeddings into it.

    Embeddings come from torchvision's pretrained SqueezeNet 1.1 applied to
    the JPEGs in ``img_dir`` (via ``process_images``).

    Args:
        index_name: Name of the Pinecone index. It is created only if it
            does not already exist.
        dim: Vector dimension used when creating the index. It should equal
            the embedding length (1000 outputs for SqueezeNet 1.1's default
            classifier).
        img_dir: Directory of images to embed. Defaults to ``"img"``.
        batch_size: Number of vectors sent per ``upsert`` call. Must be > 0.
            Defaults to 50.
        environment: Pinecone environment passed to ``pinecone.init``.

    Returns:
        The result of ``index.describe_index_stats()``.
    """
    # NOTE(review): pinecone.init / list_indexes / Index is the
    # pinecone-client v2 API — confirm the pinned client version.
    pinecone.init(api_key=_resolve_pinecone_api_key(), environment=environment)
    # NOTE(review): newer torchvision deprecates `pretrained=` in favor of
    # `weights=`; kept as-is for compatibility with the installed version.
    model = torchvision.models.squeezenet1_1(pretrained=True).eval()
    if index_name not in pinecone.list_indexes():
        pinecone.create_index(name=index_name, dimension=dim)
    index = pinecone.Index(index_name)
    vectors = process_images(img_dir=img_dir, model=model)
    # Upsert in fixed-size chunks, presumably to stay within per-request limits.
    for start in range(0, len(vectors), batch_size):
        print("batch upsert")
        index.upsert(vectors[start:start + batch_size])
    return index.describe_index_stats()
INDEX_NAME = 'pinecone-image-search'
# Presumably chosen to match SqueezeNet 1.1's 1000-element output vector.
INDEX_DIM = 1000

if __name__ == "__main__":
    # Build the index only when run as a script, not as a side effect of import.
    print(create_index(INDEX_NAME, INDEX_DIM))