@@ -43,37 +43,67 @@ class VocabUnigramTableTests(unittest.TestCase):
4343 ".." , "examples" , "vocab_data.txt" )
UNIGRAM_TABLE_SIZE = 10_000
# This seed gave sampled frequencies closest to the expected ones at the
# sample size these tests use.
RANDOM_SEED = 32
# NOTE: 3, 9, 18, and 27 at a time are regular due to context vector sizes
NUM_SAMPLES = 20
NUM_TIMES = 200
# Expected sampling frequency per vector index. Derived from the counts in
# vocab_data.txt and the ones set in setUpClass, each raised to the power of
# 3/4 and normalised so that the values sum to 1.
EXPECTED_FREQUENCIES = {
    0: 0.61078822,
    1: 0.3182886,
    2: 0.05260281,
    # NOTE: no 3 since that's got no vectors
    4: 0.01832037,
}
# Largest absolute difference allowed between a sampled and an expected
# frequency.
TOLERANCE = 0.001
5257
@classmethod
def setUpClass(cls):
    """Build the vocabulary shared by every test in this class.

    Loads the words from ``EXAMPLE_DATA_PATH``, adds three extra words
    (one of them without a vector) and then builds the unigram table
    with ``UNIGRAM_TABLE_SIZE`` entries.
    """
    vocab = Vocab()
    cls.vocab = vocab
    vocab.add_words(cls.EXAMPLE_DATA_PATH)
    # (word, count, vector) - "vectorless" deliberately has no vector.
    extra_words = (
        ("test", 1310, [1.42, 1.44, 1.55]),
        ("vectorless", 1234, None),
        ("withvector", 321, [1.3, 1.2, 0.8]),
    )
    for word, count, vector in extra_words:
        vocab.add_word(word, cnt=count, vec=vector)
    vocab.make_unigram_table(table_size=cls.UNIGRAM_TABLE_SIZE)
5966
6067 def setUp (self ):
6168 np .random .seed (self .RANDOM_SEED )
6269
6370 @classmethod
64- def _get_freqs (cls ) -> list [ float ]:
71+ def _get_freqs (cls ) -> dict [ int , float ]:
6572 c = Counter ()
6673 for _ in range (cls .NUM_TIMES ):
6774 got = cls .vocab .get_negative_samples (cls .NUM_SAMPLES )
6875 c += Counter (got )
69- total = sum (c [ i ] for i in c )
70- got_freqs = [ c [ i ] / total for i in range ( len ( cls . EXPECTED_FREQUENCIES ))]
76+ total = sum (c . values () )
77+ got_freqs = { index : val / total for index , val in c . items ()}
7178 return got_freqs
7279
73- def assert_accurate_enough (self , got_freqs : list [float ]):
80+ @classmethod
81+ def _get_abs_max_diff (cls , dict1 : dict [int , float ],
82+ dict2 : dict [int , float ]):
83+ assert dict1 .keys () == dict2 .keys ()
84+ vals1 , vals2 = [], []
85+ for index in dict1 :
86+ vals1 .append (dict1 [index ])
87+ vals2 .append (dict2 [index ])
88+ return np .max (np .abs (np .array (vals1 ) - np .array (vals2 )))
89+
90+ def assert_accurate_enough (self , got_freqs : dict [int , float ]):
91+ self .assertEqual (got_freqs .keys (), self .EXPECTED_FREQUENCIES .keys ())
7492 self .assertTrue (
75- np .max (np .abs (np .array (got_freqs ) - self .EXPECTED_FREQUENCIES )) < self .TOLERANCE
76- )
93+ self ._get_abs_max_diff (self .EXPECTED_FREQUENCIES , got_freqs ) < self .TOLERANCE )
94+
95+ def test_does_not_include_vectorless_indices (self , num_samples : int = 100 ):
96+ inds = self .vocab .get_negative_samples (num_samples )
97+ for index in inds :
98+ with self .subTest (f"Index: { index } " ):
99+ # in the right list
100+ self .assertIn (index , self .vocab .vec_index2word )
101+ word = self .vocab .vec_index2word [index ]
102+ info = self .vocab .vocab [word ]
103+ # the info has vector
104+ self .assertIn ("vec" , info )
105+ # the vector is an array or a list
106+ self .assertIsInstance (self .vocab .vec (word ), (np .ndarray , list ),)
77107
78108 def test_negative_sampling (self ):
79109 got_freqs = self ._get_freqs ()
0 commit comments