|
| 1 | +from zingg.client import * |
| 2 | +from zingg.pipes import * |
| 3 | +from pyspark.sql.types import * |
| 4 | +import pandas |
| 5 | + |
| 6 | +import pyspark.sql.functions as fn |
| 7 | +from sentence_transformers import SentenceTransformer, util |
| 8 | +import torch |
| 9 | +import pickle |
| 10 | + |
| 11 | +from PIL import Image |
| 12 | + |
# Read the product listing metadata (JSON) through the shared Spark session.
df = getSparkSession().read.json('/home/ubuntu/image_data/listings/metadata')

# Keep only rows whose country is 'US'.
df = df.where("country='US'")

# Project down to the columns used later in this script.
df = df.select(
    'item_id',
    'brand',
    'bullet_point',
    'domain_name',
    'marketplace',
    'item_keywords',
    'item_name',
    'product_description',
    'main_image_id',
    'other_image_id',
    'node',
)

# Read the image metadata CSV; the first row holds the column names.
image_metadata = getSparkSession().read.csv(
    '/home/ubuntu/image_data/images/metadata',
    sep=',',
    header=True,
)
| 42 | + |
@fn.udf(ArrayType(StringType()))
def get_english_values_from_array(array=None):
    """Return the values tagged with the most-preferred English locale.

    Each element of ``array`` is indexed by ``'language_tag'`` and
    ``'value'``. The locale codes are tried in priority order (US English
    first); as soon as one code matches at least one element, every value
    carrying that code is returned and lower-priority codes are ignored.

    Args:
        array: Sequence of tagged entries, or ``None``.

    Returns:
        A list of the matching values, or an empty list when ``array`` is
        ``None``/empty or contains none of the English locale codes.
    """
    # Priority-ordered English locale codes (biased towards US English).
    preferred_tags = ('en_US', 'en_CA', 'en_GB', 'en_AU', 'en_IN', 'en_SG', 'en_AE')

    entries = [] if array is None else array

    for tag in preferred_tags:
        matches = [entry['value'] for entry in entries if entry['language_tag'] == tag]
        # The first locale that yields anything wins.
        if matches:
            return matches

    return []
| 67 | + |
# CLIP model used to turn images into embedding vectors. Prefer the GPU, but
# fall back to the CPU so the script still runs on hosts without CUDA.
model = SentenceTransformer(
    'clip-ViT-B-32',
    device='cuda' if torch.cuda.is_available() else 'cpu',
)
| 69 | + |
@fn.udf(ArrayType(DoubleType()))
def get_image_embedding(path):
    """Encode the image at ``path`` into an embedding using ``model``.

    ``path`` is appended to the fixed image root directory. This is a
    best-effort lookup: any failure while opening, converting or encoding
    the image yields an empty list instead of failing the Spark job.

    Args:
        path: Image path relative to the image root, or ``None``.

    Returns:
        The embedding converted with ``.tolist()``, or an empty list when
        ``path`` is ``None`` or the image could not be processed.
    """
    if path is None:
        return []

    full_path = '/home/ubuntu/image_data/images/small_partial2/' + path

    try:
        # The context manager releases the file handle even if convert() fails.
        with Image.open(full_path) as raw_image:
            image = raw_image.convert('RGB')
        embedding = model.encode(
            image,
            batch_size=128,
            convert_to_tensor=False,
            show_progress_bar=False,
        )
        return embedding.tolist()
    except Exception:
        # Deliberately best-effort: missing or unreadable images produce an
        # empty embedding. Catching Exception (not a bare except) lets
        # SystemExit / KeyboardInterrupt propagate.
        return []
| 90 | + |
# Step 1: pull the English text out of the multilingual fields. Scalar
# attributes take the first English value; list-like ones keep the full array.
items = df.alias('a').select(
    'item_id',
    'domain_name',
    'marketplace',
    get_english_values_from_array('brand')[0].alias('brand'),
    get_english_values_from_array('item_name')[0].alias('item_name'),
    get_english_values_from_array('product_description')[0].alias('product_description'),
    get_english_values_from_array('bullet_point').alias('bulletpoint'),
    get_english_values_from_array('item_keywords').alias('item_keywords'),
    fn.split(fn.col('node')[0]['node_name'], '/').alias('hierarchy'),
    'main_image_id',
)

# Step 2: look up the file path of each item's main image. The left join
# keeps items that have no matching image row.
items = items.join(
    image_metadata.alias('b').select('image_id', 'path'),
    on=fn.expr('a.main_image_id=b.image_id'),
    how='left',
)

# Step 3: compute the embedding of the main image from its path.
items = items.withColumn('main_image_embedding', get_image_embedding(fn.col('path')))

# Step 4: drop the join keys and the columns not passed on to Zingg.
items = items.drop('main_image_id', 'image_id', 'bulletpoint', 'item_keywords', 'hierarchy')

items.show()
| 116 | + |
# ---------------------------------------------------------------------------
# Zingg configuration
# ---------------------------------------------------------------------------
args = Arguments()

# One FieldDefinition per column of `items`: name, data type, match type.
item_id = FieldDefinition("item_id", "string", MatchType.DONT_USE)
domain_name = FieldDefinition("domain_name", "string", MatchType.DONT_USE)
marketplace = FieldDefinition("marketplace", "string", MatchType.DONT_USE)
brand = FieldDefinition("brand", "string", MatchType.FUZZY)
item_name = FieldDefinition("item_name", "string", MatchType.TEXT)
product_description = FieldDefinition("product_description", "string", MatchType.DONT_USE)
path = FieldDefinition("path", "string", MatchType.DONT_USE)
main_image_embedding = FieldDefinition("main_image_embedding", "array<double>", MatchType.FUZZY)

# bulletpoint, item_keywords and hierarchy are not listed here because they
# are dropped from `items` before it is handed to Zingg.
fieldDefs = [
    item_id,
    domain_name,
    marketplace,
    brand,
    item_name,
    product_description,
    path,
    main_image_embedding,
]
args.setFieldDefinition(fieldDefs)

# Model id, Zingg working directory, and partitioning / sampling settings.
args.setModelId("9999")
args.setZinggDir("/tmp/modelSmallImages")
args.setNumPartitions(16)
args.setLabelDataSampleSize(0.2)

# Two in-memory input pipes, both backed by the same `items` DataFrame.
inputPipeSmallImages1 = InMemoryPipe("smallImages1")
inputPipeSmallImages1.setDataset(items)

inputPipeSmallImages2 = InMemoryPipe("smallImages2")
inputPipeSmallImages2.setDataset(items)

args.setData(inputPipeSmallImages1, inputPipeSmallImages2)

# Write the results as parquet under /tmp/resultSmallImages.
outputPipe = Pipe("resultSmallImages", "parquet")
outputPipe.addProperty("location", "/tmp/resultSmallImages")
args.setOutput(outputPipe)

# Run the "link" phase.
options = ClientOptions([ClientOptions.PHASE, "link"])

zingg = Zingg(args, options)
zingg.initAndExecute()
| 156 | + |
| 157 | + |
0 commit comments