-
Notifications
You must be signed in to change notification settings - Fork 125
/
Copy pathInMemPipeTestImagesJson.py
118 lines (95 loc) · 2.88 KB
/
InMemPipeTestImagesJson.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from zingg.client import *
from zingg.pipes import *
from pyspark.sql.types import *
import pandas
import pyspark.sql.functions as fn
from sentence_transformers import SentenceTransformer, util
import torch
import pickle
from PIL import Image
# Load the product listing metadata (JSON), keep only the US listings and
# project down to the columns referenced by the rest of this script.
df = (
    getSparkSession()
    .read
    .json('/home/ubuntu/image_data/listings/metadata')
    .filter("country='US'")
    .select(
        'item_id',
        'brand',
        'bullet_point',
        'domain_name',
        'marketplace',
        'item_keywords',
        'item_name',
        'product_description',
        'main_image_id',
        'other_image_id',
        'node',
    )
)
# Read the image metadata CSV files; the header row supplies the column names.
image_metadata = getSparkSession().read.csv(
    path='/home/ubuntu/image_data/images/metadata',
    sep=',',
    header=True,
)
@fn.udf(ArrayType(StringType()))
def get_english_values_from_array(array=None):
    """Return the values tagged with the most preferred English locale.

    Each entry of ``array`` is indexed by ``'language_tag'`` and ``'value'``.
    Locales are tried in priority order (US English first); the values of the
    first locale that matches any entry are returned, in their original order.
    An empty list is returned when ``array`` is None or holds no entry for any
    of the listed English locales.
    """
    # English locale codes, most preferred first (biased towards US English).
    preferred_tags = ('en_US', 'en_CA', 'en_GB', 'en_AU', 'en_IN', 'en_SG', 'en_AE')
    entries = [] if array is None else array

    for tag in preferred_tags:
        matches = [entry['value'] for entry in entries if entry['language_tag'] == tag]
        # Stop at the first locale that yields anything; lower-priority
        # locales are never mixed into the result.
        if matches:
            return matches

    return []
# Sentence-transformers model ('clip-ViT-B-32') placed on the CUDA device;
# get_image_embedding() uses it to encode images.
model = SentenceTransformer(
    'clip-ViT-B-32',
    device='cuda',
)
@fn.udf(ArrayType(DoubleType()))
def get_image_embedding(path):
    """Return the embedding of the image at ``path`` as a list of floats.

    ``path`` is appended to ``/home/ubuntu/image_data/images/small/`` to
    locate the file. The image is converted to RGB and encoded with the
    module-level ``model``.

    This is best effort: an empty list is returned when ``path`` is None or
    when the image cannot be opened or encoded. Failures are logged rather
    than raised so that one bad image does not abort the whole job.
    """
    # Local import keeps this block self-contained.
    import logging

    if path is None:
        return []

    full_path = '/home/ubuntu/image_data/images/small/' + path
    try:
        # Open inside a context manager so the underlying file handle is
        # closed promptly; convert() returns an independent in-memory image.
        with Image.open(full_path) as raw_image:
            image = raw_image.convert('RGB')
        embedding = model.encode(
            image,
            batch_size=128,
            convert_to_tensor=False,
            show_progress_bar=False,
        )
        return embedding.tolist()
    except Exception:
        # Deliberately broad to keep the best-effort contract, but no longer
        # a bare ``except`` (which also traps KeyboardInterrupt/SystemExit),
        # and no longer silent.
        logging.getLogger(__name__).warning(
            'Could not compute embedding for image %s', full_path, exc_info=True
        )
        return []
# Project each listing onto its English-language attribute values.
items = df.alias('a').select(
    'item_id',
    'domain_name',
    'marketplace',
    get_english_values_from_array('brand')[0].alias('brand'),
    get_english_values_from_array('item_name')[0].alias('item_name'),
    get_english_values_from_array('product_description')[0].alias('product_description'),
    get_english_values_from_array('bullet_point').alias('bulletpoint'),
    get_english_values_from_array('item_keywords').alias('item_keywords'),
    fn.split(fn.col('node')[0]['node_name'], '/').alias('hierarchy'),
    'main_image_id',
)

# Left join so listings without a matching image metadata row are kept.
items = items.join(
    image_metadata.alias('b').select('image_id', 'path'),
    on=fn.expr('a.main_image_id=b.image_id'),
    how='left',
)

# Compute the embedding of each listing's main image from its file path.
items = items.withColumn('main_image_embedding', get_image_embedding(fn.col('path')))

# Remove the join keys and the columns that are not written to the output.
items = items.drop(
    'main_image_id',
    'image_id',
    'path',
    'bulletpoint',
    'item_keywords',
    'hierarchy',
)

# Preview the result, then write it out as JSON.
items.show()
items.write.json('/home/ubuntu/image_data/json_data/items')