Skip to content

Commit 60b63df

Browse files
Create test.py
1 parent ec972ac commit 60b63df

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed

test.py

+66
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,66 @@
1+
from pprint import pprint
2+
3+
4+
def apriori(transactions: list, min_support: int) -> set:
    """Return every itemset contained in at least ``min_support`` transactions.

    Level-wise Apriori search: the frequent itemsets of size ``k`` are joined
    pairwise into candidates of size ``k + 1``, candidates with an infrequent
    ``k``-subset are pruned, and the survivors are counted against the
    transactions. The search stops at the first level with no frequent
    itemset.

    Args:
        transactions: Iterable of transactions, each an iterable of hashable
            items. Any iterable works (including a one-shot generator); it is
            consumed exactly once.
        min_support: Minimum number of transactions that must contain an
            itemset for it to be reported (an absolute count, not a ratio).

    Returns:
        A set of ``frozenset`` objects, one per frequent itemset of any size.
        Empty when nothing reaches ``min_support``.

    Raises:
        TypeError: If a transaction is not iterable or an item is unhashable.
    """
    # Imported locally so the function carries its own dependencies.
    from collections import Counter
    from itertools import combinations

    # Snapshot the input as frozensets. This lets the data be scanned once per
    # level even if the caller passed a generator, and guarantees a repeated
    # item inside one transaction contributes only once to its support.
    baskets = [frozenset(transaction) for transaction in transactions]

    # C1 -> L1: support of every single item.
    item_count = Counter(
        frozenset([item]) for basket in baskets for item in basket
    )
    current = {
        itemset for itemset, count in item_count.items() if count >= min_support
    }

    frequent = set()
    k = 1
    while current:
        frequent |= current

        # Join step: the union of two frequent k-itemsets is a candidate only
        # when they differ in exactly one item (union has k + 1 items).
        candidates = set()
        for first, second in combinations(current, 2):
            merged = first | second
            if len(merged) != k + 1:
                continue
            # Prune step: every k-subset of a frequent (k + 1)-itemset must
            # itself be frequent, so anything else can be dropped uncounted.
            if all(merged - {item} in current for item in merged):
                candidates.add(merged)

        # Count step: number of transactions containing each candidate.
        candidate_count = Counter(
            candidate
            for basket in baskets
            for candidate in candidates
            if candidate <= basket
        )

        # Iterate over `candidates` (not the counter) so a candidate with zero
        # hits is still compared against min_support; Counter yields 0 for it.
        current = {
            candidate
            for candidate in candidates
            if candidate_count[candidate] >= min_support
        }
        k += 1

    return frequent
53+
54+
55+
# Demo run on a small market-basket dataset.
# Each entry is one transaction: the set of items that occurred together.
transactions = [
    {"bread", "milk"},
    {"bread", "diaper", "beer", "eggs"},
    {"milk", "diaper", "beer", "coke"},
    {"bread", "milk", "diaper", "beer"},
    {"bread", "milk", "diaper", "coke"},
]

# Threshold handed to apriori as its min_support argument.
min_support = 2

# Mine the frequent itemsets and pretty-print the result.
frequent_itemsets = apriori(transactions=transactions, min_support=min_support)
pprint(frequent_itemsets)

0 commit comments

Comments
 (0)