Skip to content

Commit bd652ad

Browse files
authored
CU-8698f8fgc: Fix negative sampling including indices for words without a vector (CogStack/MedCAT#524)
* CU-8698f8fgc: Add new test to check that the negative sampling indices do not include non-vectored indices
* CU-8698f8fgc: Add fix for negative sampling including indices for words without a vector
* CU-8698f8fgc: Update tests to make sure index frequencies are respected
* CU-8698f8fgc: Add 3.9-friendly counter totalling method
1 parent 896b1d2 commit bd652ad

File tree

2 files changed

+58
-12
lines changed

2 files changed

+58
-12
lines changed

medcat-v1/medcat/vocab.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,15 +190,28 @@ def make_unigram_table(self, table_size: int = -1) -> None:
190190
"the creation of a massive array. So therefore, there "
191191
"is no need to pass the `table_size` parameter anymore.")
192192
freqs = []
193-
for word in self.vec_index2word.values():
193+
# index list maps the slot in which a word index
194+
# sits in vec_index2word to the actual index for said word
195+
# e.g:
196+
# if we have words indexed 0, 1, and 2
197+
# but only 0, and 2 have corresponding vectors
198+
# then only 0 and 2 will occur in vec_index2word
199+
# and while 0 will be in the 0th position (as expected)
200+
# in the final probability list, 2 will be in 1st position
201+
# so we need to mark that conversion down
202+
index_list = []
203+
for word_index, word in self.vec_index2word.items():
194204
freqs.append(self[word])
205+
index_list.append(word_index)
195206

196207
# Power and normalize frequencies
197208
freqs = np.array(freqs) ** (3/4)
198209
freqs /= freqs.sum()
199210

200211
# Calculate cumulative probabilities
201212
self.cum_probs = np.cumsum(freqs)
213+
# the mapping from vector index order to word indices
214+
self._index_list = index_list
202215

203216
def get_negative_samples(self, n: int = 6, ignore_punct_and_num: bool = False) -> List[int]:
204217
"""Get N negative samples.
@@ -216,8 +229,11 @@ def get_negative_samples(self, n: int = 6, ignore_punct_and_num: bool = False) -
216229
if len(self.cum_probs) == 0:
217230
self.make_unigram_table()
218231
random_vals = np.random.rand(n)
219-
# NOTE: there's a change in numpy
220-
inds = cast(List[int], np.searchsorted(self.cum_probs, random_vals).tolist())
232+
# NOTE: These indices are in terms of the cum_probs array
233+
# which only has word data for words with vectors.
234+
vec_slots = cast(List[int], np.searchsorted(self.cum_probs, random_vals).tolist())
235+
# so we need to translate these back to word indices
236+
inds = list(map(self._index_list.__getitem__, vec_slots))
221237

222238
if ignore_punct_and_num:
223239
# Do not return anything that does not have letters in it

medcat-v1/tests/test_vocab.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,37 +43,67 @@ class VocabUnigramTableTests(unittest.TestCase):
4343
"..", "examples", "vocab_data.txt")
4444
UNIGRAM_TABLE_SIZE = 10_000
4545
# found that this seed had the closest frequency at the sample size we're at
46-
RANDOM_SEED = 4976
46+
RANDOM_SEED = 32
4747
NUM_SAMPLES = 20 # NOTE: 3, 9, 18, and 27 at a time are regular due to context vector sizes
4848
NUM_TIMES = 200
49-
# based on the counts on vocab_data.txt and the one set in setUpClass
50-
EXPECTED_FREQUENCIES = [0.62218692, 0.32422858, 0.0535845]
49+
# based on the counts on vocab_data.txt and the ones set in setUpClass
50+
# plus the power of 3/4
51+
EXPECTED_FREQUENCIES = {
52+
0: 0.61078822, 1: 0.3182886,
53+
2: 0.05260281,
54+
# NOTE: no 3 since that's got no vectors
55+
4: 0.01832037}
5156
TOLERANCE = 0.001
5257

5358
@classmethod
5459
def setUpClass(cls):
5560
cls.vocab = Vocab()
5661
cls.vocab.add_words(cls.EXAMPLE_DATA_PATH)
5762
cls.vocab.add_word("test", cnt=1310, vec=[1.42, 1.44, 1.55])
63+
cls.vocab.add_word("vectorless", cnt=1234, vec=None)
64+
cls.vocab.add_word("withvector", cnt=321, vec=[1.3, 1.2, 0.8])
5865
cls.vocab.make_unigram_table(table_size=cls.UNIGRAM_TABLE_SIZE)
5966

6067
def setUp(self):
6168
np.random.seed(self.RANDOM_SEED)
6269

6370
@classmethod
64-
def _get_freqs(cls) -> list[float]:
71+
def _get_freqs(cls) -> dict[int, float]:
6572
c = Counter()
6673
for _ in range(cls.NUM_TIMES):
6774
got = cls.vocab.get_negative_samples(cls.NUM_SAMPLES)
6875
c += Counter(got)
69-
total = sum(c[i] for i in c)
70-
got_freqs = [c[i]/total for i in range(len(cls.EXPECTED_FREQUENCIES))]
76+
total = sum(c.values())
77+
got_freqs = {index: val/total for index, val in c.items()}
7178
return got_freqs
7279

73-
def assert_accurate_enough(self, got_freqs: list[float]):
80+
@classmethod
81+
def _get_abs_max_diff(cls, dict1: dict[int, float],
82+
dict2: dict[int, float]):
83+
assert dict1.keys() == dict2.keys()
84+
vals1, vals2 = [], []
85+
for index in dict1:
86+
vals1.append(dict1[index])
87+
vals2.append(dict2[index])
88+
return np.max(np.abs(np.array(vals1) - np.array(vals2)))
89+
90+
def assert_accurate_enough(self, got_freqs: dict[int, float]):
91+
self.assertEqual(got_freqs.keys(), self.EXPECTED_FREQUENCIES.keys())
7492
self.assertTrue(
75-
np.max(np.abs(np.array(got_freqs) - self.EXPECTED_FREQUENCIES)) < self.TOLERANCE
76-
)
93+
self._get_abs_max_diff(self.EXPECTED_FREQUENCIES, got_freqs) < self.TOLERANCE)
94+
95+
def test_does_not_include_vectorless_indices(self, num_samples: int = 100):
96+
inds = self.vocab.get_negative_samples(num_samples)
97+
for index in inds:
98+
with self.subTest(f"Index: {index}"):
99+
# in the right list
100+
self.assertIn(index, self.vocab.vec_index2word)
101+
word = self.vocab.vec_index2word[index]
102+
info = self.vocab.vocab[word]
103+
# the info has vector
104+
self.assertIn("vec", info)
105+
# the vector is an array or a list
106+
self.assertIsInstance(self.vocab.vec(word), (np.ndarray, list),)
77107

78108
def test_negative_sampling(self):
79109
got_freqs = self._get_freqs()

0 commit comments

Comments (0)