|
| 1 | +import itertools |
| 2 | +from collections import defaultdict |
| 3 | + |
| 4 | +def apriori(transactions, min_support=0.5, max_length=None): |
| 5 | + """ |
| 6 | + Min Support : Minimum frequency threshold for an itemset to be considered frequent |
| 7 | + Max Length : Limits the maximum size of the frequent itemsets |
| 8 | + """ |
| 9 | + if not transactions: |
| 10 | + raise ValueError("Transaction list cannot be empty") |
| 11 | + if not 0 < min_support <= 1: |
| 12 | + raise ValueError("Minimum support must be between 0 and 1") |
| 13 | + |
| 14 | + num_transactions = len(transactions) |
| 15 | + min_support_count = min_support * num_transactions |
| 16 | + |
| 17 | + item_counts = defaultdict(int) |
| 18 | + for transaction in transactions: |
| 19 | + for item in transaction: |
| 20 | + item_counts[frozenset([item])] += 1 |
| 21 | + |
| 22 | + frequent_itemsets = {itemset: count for itemset, count in item_counts.items() |
| 23 | + if count >= min_support_count} |
| 24 | + |
| 25 | + k = 1 # Current itemset size |
| 26 | + all_frequent_itemsets = dict(frequent_itemsets) |
| 27 | + |
| 28 | + while frequent_itemsets and (max_length is None or k < max_length): |
| 29 | + k += 1 |
| 30 | + candidates = generate_candidates(frequent_itemsets.keys(), k) |
| 31 | + |
| 32 | + candidate_counts = defaultdict(int) |
| 33 | + for transaction in transactions: |
| 34 | + transaction_set = frozenset(transaction) |
| 35 | + for candidate in candidates: |
| 36 | + if candidate.issubset(transaction_set): |
| 37 | + candidate_counts[candidate] += 1 |
| 38 | + |
| 39 | + frequent_itemsets = {itemset: count for itemset, count in candidate_counts.items() |
| 40 | + if count >= min_support_count} |
| 41 | + |
| 42 | + all_frequent_itemsets.update(frequent_itemsets) |
| 43 | + |
| 44 | + return {itemset: count / num_transactions for itemset, count in all_frequent_itemsets.items()} |
| 45 | + |
| 46 | +def generate_candidates(prev_frequent_itemsets, k): |
| 47 | + candidates = set() |
| 48 | + prev_frequent_list = sorted(list(prev_frequent_itemsets), key=lambda x: sorted(x)) |
| 49 | + |
| 50 | + for i in range(len(prev_frequent_list)): |
| 51 | + for j in range(i + 1, len(prev_frequent_list)): |
| 52 | + itemset1 = prev_frequent_list[i] |
| 53 | + itemset2 = prev_frequent_list[j] |
| 54 | + |
| 55 | + if k > 2: |
| 56 | + if sorted(itemset1)[:-1] != sorted(itemset2)[:-1]: |
| 57 | + continue |
| 58 | + |
| 59 | + new_candidate = itemset1 | itemset2 |
| 60 | + if len(new_candidate) == k: |
| 61 | + candidates.add(new_candidate) |
| 62 | + |
| 63 | + return candidates |
| 64 | + |
| 65 | +def test_apriori(): |
| 66 | + transactions1 = [{'bread', 'milk'}, |
| 67 | + {'bread', 'diaper', 'beer', 'eggs'}, |
| 68 | + {'milk', 'diaper', 'beer', 'cola'}, |
| 69 | + {'bread', 'milk', 'diaper', 'beer'}, |
| 70 | + {'bread', 'milk', 'diaper', 'cola'}] |
| 71 | + |
| 72 | + result1 = apriori(transactions1, min_support=0.6) |
| 73 | + expected1 = { |
| 74 | + frozenset({'bread'}): 0.8, |
| 75 | + frozenset({'milk'}): 0.8, |
| 76 | + frozenset({'diaper'}): 0.8, |
| 77 | + frozenset({'bread', 'milk'}): 0.6, |
| 78 | + frozenset({'milk', 'diaper'}): 0.6, |
| 79 | + frozenset({'bread', 'diaper'}): 0.6 |
| 80 | + } |
| 81 | + assert set(result1.keys()) == set(expected1.keys()), "Test Case 1 Failed" |
| 82 | + assert all(abs(result1[k] - expected1[k]) < 0.001 for k in expected1), "Test Case 1 Failed" |
| 83 | + |
| 84 | + transactions2 = [{'a', 'b'}, {'c', 'd'}, {'e', 'f'}, {'g', 'h'}] |
| 85 | + result2 = apriori(transactions2, min_support=0.5) |
| 86 | + expected2 = {} # No itemset appears in at least 2 transactions |
| 87 | + assert set(result2.keys()) == set(expected2.keys()), "Test Case 2 Failed" |
| 88 | + |
| 89 | + transactions3 = [{'a', 'b', 'c', 'd'}, {'a', 'b', 'c', 'd'}, {'a', 'b', 'c', 'd'}] |
| 90 | + result3 = apriori(transactions3, min_support=0.5, max_length=2) |
| 91 | + expected3 = { |
| 92 | + frozenset({'a'}): 1.0, |
| 93 | + frozenset({'b'}): 1.0, |
| 94 | + frozenset({'c'}): 1.0, |
| 95 | + frozenset({'d'}): 1.0, |
| 96 | + frozenset({'a', 'b'}): 1.0, |
| 97 | + frozenset({'a', 'c'}): 1.0, |
| 98 | + frozenset({'a', 'd'}): 1.0, |
| 99 | + frozenset({'b', 'c'}): 1.0, |
| 100 | + frozenset({'b', 'd'}): 1.0, |
| 101 | + frozenset({'c', 'd'}): 1.0 |
| 102 | + } |
| 103 | + assert set(result3.keys()) == set(expected3.keys()), "Test Case 3 Failed" |
| 104 | + |
| 105 | + try: |
| 106 | + apriori([], min_support=0.5) |
| 107 | + assert False, "Test Case 4 Failed: Should raise ValueError" |
| 108 | + except ValueError: |
| 109 | + pass |
| 110 | + |
| 111 | + transactions5 = [ |
| 112 | + {'apple', 'banana', 'orange'}, |
| 113 | + {'apple', 'banana', 'grape'}, |
| 114 | + {'apple', 'orange', 'grape'}, |
| 115 | + {'banana', 'orange', 'grape'}, |
| 116 | + {'apple', 'banana', 'orange', 'grape'} |
| 117 | + ] |
| 118 | + result5 = apriori(transactions5, min_support=0.6) |
| 119 | + expected5 = { |
| 120 | + frozenset({'apple'}): 0.8, |
| 121 | + frozenset({'banana'}): 0.8, |
| 122 | + frozenset({'orange'}): 0.8, |
| 123 | + frozenset({'grape'}): 0.8, |
| 124 | + frozenset({'apple', 'banana'}): 0.6, |
| 125 | + frozenset({'apple', 'orange'}): 0.6, |
| 126 | + frozenset({'apple', 'grape'}): 0.6, |
| 127 | + frozenset({'banana', 'orange'}): 0.6, |
| 128 | + frozenset({'banana', 'grape'}): 0.6, |
| 129 | + frozenset({'orange', 'grape'}): 0.6 |
| 130 | + } |
| 131 | + assert set(result5.keys()) == set(expected5.keys()), "Test Case 5 Failed" |
| 132 | + |
| 133 | +if __name__ == "__main__": |
| 134 | + test_apriori() |
| 135 | + print("All Test Cases Passed!") |
0 commit comments