Commit eede5fe

add tests

Signed-off-by: Tim Schopf <tim.schopf@t-online.de>

1 parent e01123a · commit eede5fe

File tree

  .readthedocs.yaml
  tests/requirements.txt
  tests/test_vectorizers.py
  tests/utils.py

4 files changed: +115 -2 lines changed

.readthedocs.yaml

Lines changed: 1 addition & 1 deletion

@@ -26,7 +26,7 @@ python:
 build:
   os: ubuntu-22.04
   tools:
-    python: "3.8"
+    python: "3.7"
 
 submodules:
   include: all

tests/requirements.txt

Lines changed: 3 additions & 1 deletion

@@ -1,4 +1,6 @@
 pytest>=7.0.1
 keybert>=0.5.0
 flair==0.11.3
-scipy==1.7.3
+scipy==1.7.3
+bertopic>=0.16.1
+datasets==2.13.2

tests/test_vectorizers.py

Lines changed: 47 additions & 0 deletions

@@ -2,6 +2,8 @@
 
 import flair
 import spacy
+from bertopic import BERTopic
+from datasets import load_dataset
 from flair.models import SequenceTagger
 from flair.tokenization import SegtokSentenceSplitter
 from keybert import KeyBERT
@@ -132,3 +134,48 @@ def custom_pos_tagger(raw_documents: List[str], tagger: flair.models.SequenceTagger
     keyphrases = vectorizer.get_feature_names_out()
 
     assert sorted(keyphrases) == sorted_english_test_keyphrases
+
+
+def test_online_vectorizer():
+    first_doc_count_matrix = utils.get_sorted_english_first_doc_count_matrix()
+    second_doc_count_matrix = utils.get_sorted_english_second_doc_count_matrix()
+    first_doc_test_keyphrases = utils.get_english_first_doc_test_keyphrases()
+    english_keyphrases = utils.get_english_test_keyphrases()
+    frequencies_after_min_df = utils.get_frequencies_after_min_df()
+    frequent_keyphrases_after_min_df = utils.get_frequent_keyphrases_after_min_df()
+    frequencies_after_bow = utils.get_frequencies_after_bow()
+
+    # initial vectorizer fit
+    vectorizer = KeyphraseCountVectorizer(decay=0.5, delete_min_df=3)
+
+    assert [sorted(count_list) for count_list in
+            vectorizer.fit_transform([english_docs[0]]).toarray()] == first_doc_count_matrix
+    assert sorted(vectorizer.get_feature_names_out()) == first_doc_test_keyphrases
+
+    # learn additional keyphrases from new documents with partial fit
+    vectorizer.partial_fit([english_docs[1]])
+
+    assert [sorted(count_list) for count_list in
+            vectorizer.transform([english_docs[1]]).toarray()] == second_doc_count_matrix
+    assert sorted(vectorizer.get_feature_names_out()) == english_keyphrases
+
+    # update the list of learned keyphrases according to 'delete_min_df'
+    vectorizer.update_bow([english_docs[1]])
+    assert (vectorizer.transform([english_docs[1]]).toarray() == frequencies_after_min_df).all()
+
+    # check the updated list of learned keyphrases (only those that appear more often than 'delete_min_df' remain)
+    assert sorted(vectorizer.get_feature_names_out()) == frequent_keyphrases_after_min_df
+
+    # update again and check the impact of 'decay' on the learned document-keyphrase matrix
+    vectorizer.update_bow([english_docs[1]])
+    assert (vectorizer.X_.toarray() == frequencies_after_bow).all()
+
+
+def test_bertopic():
+    data = load_dataset("ag_news")
+    texts = data['train']['text']
+    texts = texts[:100]
+    topic_model = BERTopic(vectorizer_model=KeyphraseCountVectorizer())
+    topics, probs = topic_model.fit_transform(documents=texts)
+    new_topics = topic_model.reduce_outliers(texts, topics)
+    topic_model.update_topics(texts, topics=new_topics)
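
For context, the online-learning workflow that test_online_vectorizer exercises can be sketched outside the test harness as follows. This is a minimal sketch, not part of the commit: the two short documents stand in for the english_docs fixtures, while decay=0.5 and delete_min_df=3 are the values used in the test. The fixtures also pin down the decay arithmetic: after update_bow() the surviving keyphrases have counts [5, 5], and a second update_bow() halves them before adding the new batch's counts, 0.5 * 5 + 5 = 7.5, which is exactly what get_frequencies_after_bow() returns.

    from keyphrase_vectorizers import KeyphraseCountVectorizer

    # Placeholder documents; the tests use the longer english_docs fixtures.
    docs = ["Supervised learning learns a function from labeled training data.",
            "Keywords are frequently used to rank the relevance of a document."]

    # decay down-weights the previously learned bag-of-words counts on each
    # update_bow() call; delete_min_df drops keyphrases whose counts do not
    # exceed 3.
    vectorizer = KeyphraseCountVectorizer(decay=0.5, delete_min_df=3)

    vectorizer.fit_transform([docs[0]])        # learn keyphrases from the first batch
    vectorizer.partial_fit([docs[1]])          # extend the vocabulary with a new batch
    vectorizer.update_bow([docs[1]])           # refresh counts, applying decay and delete_min_df
    print(vectorizer.get_feature_names_out())  # keyphrases that survived delete_min_df
    print(vectorizer.X_.toarray())             # decayed document-keyphrase matrix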

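The test_bertopic test doubles as an integration recipe: passing KeyphraseCountVectorizer as BERTopic's vectorizer_model makes topics be represented by keyphrases rather than single words. A standalone sketch of the same flow follows; it only adds a final get_topic_info() call over what the test does, and it assumes the ag_news dataset is reachable through the datasets hub.

    from bertopic import BERTopic
    from datasets import load_dataset
    from keyphrase_vectorizers import KeyphraseCountVectorizer

    # A 100-document slice of ag_news keeps the run fast, as in the test.
    texts = load_dataset("ag_news")["train"]["text"][:100]

    # Represent topics with keyphrases instead of single words.
    topic_model = BERTopic(vectorizer_model=KeyphraseCountVectorizer())
    topics, probs = topic_model.fit_transform(documents=texts)

    # Reassign outlier documents (topic -1), then refresh topic representations.
    new_topics = topic_model.reduce_outliers(texts, topics)
    topic_model.update_topics(texts, topics=new_topics)
    print(topic_model.get_topic_info())
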
tests/utils.py

Lines changed: 64 additions & 0 deletions

@@ -1,3 +1,4 @@
+import numpy as np
 def get_english_test_docs():
     english_docs = ["""Supervised learning is the machine learning task of learning a function that
                     maps an input to an output based on example input-output pairs. It infers a
@@ -56,6 +57,36 @@ def get_english_test_keyphrases():
     return sorted_english_test_keyphrases
 
 
+def get_english_first_doc_test_keyphrases():
+    sorted_english_first_doc_test_keyphrases = ['algorithm',
+                                                'class labels',
+                                                'example',
+                                                'function',
+                                                'inductive bias',
+                                                'input',
+                                                'input object',
+                                                'machine',
+                                                'new examples',
+                                                'optimal scenario',
+                                                'output',
+                                                'output pairs',
+                                                'output value',
+                                                'pair',
+                                                'set',
+                                                'supervised learning',
+                                                'supervised learning algorithm',
+                                                'supervisory signal',
+                                                'task',
+                                                'training data',
+                                                'training examples',
+                                                'unseen instances',
+                                                'unseen situations',
+                                                'vector',
+                                                'way']
+
+    return sorted_english_first_doc_test_keyphrases
+
+
 def get_sorted_english_keyphrases_custom_flair_tagger():
     sorted_english_custom_tagger_keyphrases = ['algorithm', 'class labels', 'document', 'document content',
                                                'document relevance',
@@ -102,6 +133,21 @@ def get_sorted_english_count_matrix():
     return sorted_english_count_matrix
 
 
+def get_sorted_english_first_doc_count_matrix():
+    sorted_english_first_doc_count_matrix = [
+        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 3, 3, 3, 3, 3, 3]]
+
+    return sorted_english_first_doc_count_matrix
+
+
+def get_sorted_english_second_doc_count_matrix():
+    sorted_english_second_doc_count_matrix = [
+        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+         1, 2, 2, 5, 5]]
+
+    return sorted_english_second_doc_count_matrix
+
+
 def get_sorted_french_count_matrix():
     sorted_french_count_matrix = [[1, 1, 1, 1]]
 
@@ -130,3 +176,21 @@ def get_english_keybert_keyphrases():
                                    'document content']]
 
     return english_keybert_keyphrases
+
+
+def get_frequencies_after_min_df():
+    frequency_array = np.array([[5, 5]])
+
+    return frequency_array
+
+
+def get_frequencies_after_bow():
+    frequency_array = np.array([[7.5, 7.5]])
+
+    return frequency_array
+
+
+def get_frequent_keyphrases_after_min_df():
+    keyphrases = ['document', 'keywords']
+
+    return keyphrases