-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSimilarity Join Ecommerce Datasets
72 lines (57 loc) · 3.05 KB
/
Similarity Join Ecommerce Datasets
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import sys
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
col, collect_list, concat, lit, explode,
array_intersect, array_union, size, expr, when, round,
format_string, regexp_replace, year, to_date
)
from pyspark.sql.types import StringType
class Project3:
    """Similarity join over e-commerce invoices using Spark DataFrames.

    Invoices are compared by the item descriptions they contain, and pairs of
    invoices from different years whose Jaccard similarity reaches a given
    threshold are written out as text.
    """

    def run(self, inputpath, outputpath, k):
        """Compute the similarity join and write the result as text.

        Args:
            inputpath: Location of the CSV input, read with Spark's default
                CSV options. Its five columns are named, in order,
                InvoiceNo, Description, Quantity, InvoiceDate and UnitPrice.
            outputpath: Destination passed to ``DataFrame.write.text``. Each
                output line has the form ``(InvoiceNo1,InvoiceNo2):Similarity``
                with the smaller invoice number first, sorted ascending by
                InvoiceNo1 and then InvoiceNo2.
            k: Inclusive lower bound on the Jaccard similarity. Anything
                ``float()`` accepts, typically the raw command-line string.

        Raises:
            ValueError: If ``k`` cannot be converted to a float. This is
                raised before a Spark session is created.
        """
        # Parse the threshold first so a malformed value fails fast instead of
        # after a Spark session has been started and the input opened.
        threshold = float(k)

        spark = SparkSession.builder.getOrCreate()
        try:
            df = spark.read.csv(inputpath).toDF(
                "InvoiceNo", "Description", "Quantity", "InvoiceDate", "UnitPrice"
            )

            # Derive the year of each row from its invoice timestamp.
            df = df.withColumn(
                "InvoiceDate", to_date(col("InvoiceDate"), "dd/MM/yyyy hh:mm a")
            )
            df = df.withColumn("Year", year(col("InvoiceDate")))

            # One row per (invoice, year) holding every description on it.
            df_grouped = df.groupBy("InvoiceNo", "Year").agg(
                collect_list("Description").alias("Items")
            )
            # The grouped frame is used for both sides of the self-join below.
            df_grouped.cache()

            # Self-join: the strict "<" yields each unordered pair once, and
            # only invoices from different years are paired. InvoiceNo is
            # still a string here, so this is a string comparison; numeric
            # ordering of the pair is established after the integer cast.
            pairs = df_grouped.alias("df1").join(
                df_grouped.alias("df2"),
                (col("df1.InvoiceNo") < col("df2.InvoiceNo"))
                & (col("df1.Year") != col("df2.Year")),
            )

            # Jaccard similarity = |intersection| / |union|. Both array
            # functions drop duplicates, so the collected lists act as sets.
            pairs = pairs.withColumn(
                "Similarity",
                expr("size(array_intersect(df1.Items, df2.Items)) / size(array_union(df1.Items, df2.Items))"),
            )

            # Keep pairs at or above the threshold and cast the invoice
            # numbers from string to integer.
            result = pairs.filter(col("Similarity") >= threshold).select(
                col("df1.InvoiceNo").cast("int").alias("InvoiceNo1"),
                col("df2.InvoiceNo").cast("int").alias("InvoiceNo2"),
                "Similarity",
            )

            # Put the numerically smaller invoice first. Both output columns
            # are computed from the pre-swap values in a single projection.
            out_of_order = col("InvoiceNo1") > col("InvoiceNo2")
            result = result.select(
                when(out_of_order, col("InvoiceNo2"))
                .otherwise(col("InvoiceNo1"))
                .alias("InvoiceNo1"),
                when(out_of_order, col("InvoiceNo1"))
                .otherwise(col("InvoiceNo2"))
                .alias("InvoiceNo2"),
                "Similarity",
            )

            # De-duplicate and sort ascending by both invoice numbers.
            result = result.distinct().orderBy(
                col("InvoiceNo1").asc(), col("InvoiceNo2").asc()
            )

            # Render the similarity with up to 16 decimals, then strip
            # trailing zeros (and a dangling decimal point).
            formatted_result = result.withColumn(
                "Similarityrounded", round("Similarity", 16)
            ).select(
                regexp_replace(
                    format_string("%.16f", "Similarityrounded"),
                    "0+?$|\\.(0+)$",
                    "",
                ).alias("Similarityformatted"),
                "InvoiceNo1",
                "InvoiceNo2",
            )

            # Assemble the "(InvoiceNo1,InvoiceNo2):Similarity" output line.
            final_output = formatted_result.select(
                concat(
                    lit("("), col("InvoiceNo1"), lit(","),
                    col("InvoiceNo2"), lit("):"),
                    col("Similarityformatted"),
                ).alias("Output")
            )

            # coalesce(1) so the result is written as a single part file.
            final_output.coalesce(1).write.text(outputpath)
        finally:
            # Always release the session, including when the job fails.
            spark.stop()
if __name__ == '__main__':
    # Expected invocation: <script> <inputpath> <outputpath> <k>
    # Exit with a usage message instead of an IndexError traceback when
    # arguments are missing; extra arguments are ignored.
    if len(sys.argv) < 4:
        sys.exit(f"usage: {sys.argv[0]} <inputpath> <outputpath> <k>")
    Project3().run(sys.argv[1], sys.argv[2], sys.argv[3])