Skip to content

Commit f7b64af

Browse files
committed
link phase
1 parent 829e718 commit f7b64af

File tree

1 file changed

+157
-0
lines changed

1 file changed

+157
-0
lines changed

test/InMemPipeTestImagesLink.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
from zingg.client import *
2+
from zingg.pipes import *
3+
from pyspark.sql.types import *
4+
import pandas
5+
6+
import pyspark.sql.functions as fn
7+
from sentence_transformers import SentenceTransformer, util
8+
import torch
9+
import pickle
10+
11+
from PIL import Image
12+
13+
# Read the product-listing metadata (JSON), keep only the US rows and
# project the columns listed below.
df = (
    getSparkSession()
    .read
    .json('/home/ubuntu/image_data/listings/metadata')
    .filter("country='US'")
    .select(
        'item_id',
        'brand',
        'bullet_point',
        'domain_name',
        'marketplace',
        'item_keywords',
        'item_name',
        'product_description',
        'main_image_id',
        'other_image_id',
        'node',
    )
)

# Read the image metadata CSV (comma separated, first row is the header).
image_metadata = getSparkSession().read.csv(
    path='/home/ubuntu/image_data/images/metadata',
    sep=',',
    header=True,
)
42+
43+
@fn.udf(ArrayType(StringType()))
def get_english_values_from_array(array=None):
    """Return the values of the most-preferred English locale found in *array*.

    Each entry of *array* is read through ``entry['language_tag']`` and
    ``entry['value']``. Locales are tried in priority order; as soon as one
    locale has at least one matching entry, all values carrying that tag are
    returned (in their original order) and lower-priority locales are not
    consulted.

    Returns an empty list when *array* is ``None`` or no entry carries one
    of the listed English tags.
    """
    # English locale codes in priority order (biased towards US English).
    preferred_locales = ['en_US', 'en_CA', 'en_GB', 'en_AU', 'en_IN', 'en_SG', 'en_AE']

    entries = [] if array is None else array

    for locale in preferred_locales:
        matches = [
            entry['value']
            for entry in entries
            if entry['language_tag'] == locale
        ]
        # The first locale that yields anything wins.
        if matches:
            return matches

    return []
67+
68+
# SentenceTransformer model 'clip-ViT-B-32', placed on the 'cuda' device.
# NOTE(review): the model is created at module level and referenced from
# inside the UDF below; presumably it is shipped to the Spark workers with
# the UDF closure -- confirm this works on the target cluster.
model = SentenceTransformer('clip-ViT-B-32',device='cuda')


@fn.udf(ArrayType(DoubleType()))
def get_image_embedding(path):
    """Return the embedding of the image at *path* as a list of floats.

    *path* is appended to the fixed image root directory. The image is
    opened, converted to RGB and encoded with the module-level ``model``.

    This is deliberately best-effort: when *path* is ``None`` or the image
    cannot be opened or encoded, an empty list is returned so that a single
    bad image does not fail the whole job.
    """
    embedding = []

    if path is None:
        return embedding

    full_path = '/home/ubuntu/image_data/images/small_partial2/' + path

    try:
        # The context manager releases the file handle as soon as the pixel
        # data has been copied into the converted image.
        with Image.open(full_path) as raw_image:
            image = raw_image.convert('RGB')
        encoded = model.encode(image, batch_size=128, convert_to_tensor=False, show_progress_bar=False)
        embedding = encoded.tolist()
    except Exception:
        # Missing or unreadable image: keep the empty embedding. Catching
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate instead of being silently swallowed.
        embedding = []

    return embedding
90+
91+
# Listings aliased as 'a', with the multilingual columns reduced to their
# English values through get_english_values_from_array.
_listings = df.alias('a').select(
    'item_id',
    'domain_name',
    'marketplace',
    get_english_values_from_array('brand')[0].alias('brand'),
    get_english_values_from_array('item_name')[0].alias('item_name'),
    get_english_values_from_array('product_description')[0].alias('product_description'),
    get_english_values_from_array('bullet_point').alias('bulletpoint'),
    get_english_values_from_array('item_keywords').alias('item_keywords'),
    fn.split( fn.col('node')[0]['node_name'], '/').alias('hierarchy'),
    'main_image_id',
)

# Image id / path pairs aliased as 'b' for the join condition.
_image_paths = image_metadata.alias('b').select('image_id','path')

# Attach the main image path to each listing (left join, so listings with no
# matching image are kept), compute the image embedding from that path, then
# drop the helper columns.
items = (
    _listings
    .join(_image_paths, on=fn.expr('a.main_image_id=b.image_id'), how='left')
    .withColumn('main_image_embedding', get_image_embedding(fn.col('path')))
    .drop('main_image_id','image_id','bulletpoint','item_keywords','hierarchy')
)

items.show()
116+
117+
# Build the arguments for Zingg.
args = Arguments()

# Field definitions as (column name, data type, match type), in the order
# they are passed to Zingg.
_field_specs = [
    ("item_id", "string", MatchType.DONT_USE),
    ("domain_name", "string", MatchType.DONT_USE),
    ("marketplace", "string", MatchType.DONT_USE),
    ("brand", "string", MatchType.FUZZY),
    ("item_name", "string", MatchType.TEXT),
    ("product_description", "string", MatchType.DONT_USE),
    ("path", "string", MatchType.DONT_USE),
    ("main_image_embedding", "array<double>", MatchType.FUZZY),
]
fieldDefs = [FieldDefinition(*spec) for spec in _field_specs]

# Keep the individual definitions available under their original names.
(
    item_id,
    domain_name,
    marketplace,
    brand,
    item_name,
    product_description,
    path,
    main_image_embedding,
) = fieldDefs

args.setFieldDefinition(fieldDefs)

# Model id, Zingg working directory, partition count and label sample size.
args.setModelId("9999")
args.setZinggDir("/tmp/modelSmallImages")
args.setNumPartitions(16)
args.setLabelDataSampleSize(0.2)
138+
# Two in-memory input pipes for the link phase.
# NOTE(review): both pipes are given the same `items` DataFrame -- confirm
# that linking the dataset against itself is intended.
inputPipeSmallImages1 = InMemoryPipe("smallImages1")
inputPipeSmallImages2 = InMemoryPipe("smallImages2")
for _input_pipe in (inputPipeSmallImages1, inputPipeSmallImages2):
    _input_pipe.setDataset(items)

args.setData(inputPipeSmallImages1, inputPipeSmallImages2)

# Output pipe: parquet written to /tmp/resultSmallImages.
outputPipe = Pipe("resultSmallImages", "parquet")
outputPipe.addProperty("location", "/tmp/resultSmallImages")
args.setOutput(outputPipe)

# Run the "link" phase.
options = ClientOptions([ClientOptions.PHASE, "link"])

# Zingg execution for the given phase.
zingg = Zingg(args, options)
zingg.initAndExecute()
156+
157+

0 commit comments

Comments
 (0)