diff --git a/machine_learning/apriori_algorithm.py b/machine_learning/apriori_algorithm.py index 09a89ac236bd..71b7e0c0a5f1 100644 --- a/machine_learning/apriori_algorithm.py +++ b/machine_learning/apriori_algorithm.py @@ -11,6 +11,7 @@ Examples: https://www.kaggle.com/code/earthian/apriori-association-rules-mining """ +from collections import Counter from itertools import combinations @@ -32,7 +33,7 @@ def prune(itemset: list, candidates: list, length: int) -> list: the frequent itemsets of the previous iteration (valid subsequences of the frequent itemsets from the previous iteration). - Prunes candidate itemsets that are not frequent. + Prunes candidate itemsets that are not frequent using Counter for optimization. >>> itemset = ['X', 'Y', 'Z'] >>> candidates = [['X', 'Y'], ['X', 'Z'], ['Y', 'Z']] @@ -44,11 +45,14 @@ def prune(itemset: list, candidates: list, length: int) -> list: >>> prune(itemset, candidates, 3) [] """ + itemset_counter = Counter(tuple(x) for x in itemset) pruned = [] + for candidate in candidates: is_subsequence = True for item in candidate: - if item not in itemset or itemset.count(item) < length - 1: + tupla = tuple(item) + if tupla not in itemset_counter or itemset_counter[tupla] < length - 1: is_subsequence = False break if is_subsequence: