diff --git a/lib/linnaeus/classifier.rb b/lib/linnaeus/classifier.rb index 89750cd..e32cd73 100644 --- a/lib/linnaeus/classifier.rb +++ b/lib/linnaeus/classifier.rb @@ -38,10 +38,10 @@ def classification_scores(text) scores = {} @db.get_categories.each do |category| - words_with_count_for_category = @db.get_words_with_count_for_category category - total_word_count_sum_for_category = words_with_count_for_category.values.reduce(0){|sum, count| sum += count.to_i} + total_word_count_sum_for_category = @db.get_total_word_count_for_category category scores[category] = 0 + words_with_count_for_category = @db.fetch_scores_for_words(category, text.encode(@encoding).downcase.split) count_word_occurrences(text).each do |word, count| tmp_score = (words_with_count_for_category[word].nil?) ? 0.1 : words_with_count_for_category[word].to_i scores[category] += Math.log(tmp_score / total_word_count_sum_for_category.to_f) diff --git a/lib/linnaeus/persistence.rb b/lib/linnaeus/persistence.rb index fdd8f82..39520a6 100644 --- a/lib/linnaeus/persistence.rb +++ b/lib/linnaeus/persistence.rb @@ -64,7 +64,7 @@ def get_categories @redis.smembers category_collection_key end - # Get a list of words with their number of occurrences. + # Get a list of all words with their number of occurrences. # # == Parameters # category:: @@ -76,6 +76,64 @@ def get_words_with_count_for_category(category) @redis.hgetall base_category_key + category end + # Get a list of words with their number of occurrences. + # + # == Parameters + # category:: + # A string representing a category. + # + # wordlist:: + # A list of words for which to fetch scores. + # + # == Returns + # A hash with the word counts for the requested words + def fetch_scores_for_words(category, wordlist) + Hash[wordlist.zip(@redis.pipelined do + wordlist.each do |word| + @redis.hget base_category_key + category, word + end + end)] + end + + # Create or return the sum of all the word counts in + # this category + # + # == Parameters + # category:: + # A string representing a category. + # + # == Returns + # The sum of all word counts in this category + def get_total_word_count_for_category(category) + ret = @redis.get sum_key(category) + # create :total key for this category if it doesn't already exist + if !ret + words_with_count_for_category = get_words_with_count_for_category category + total_word_count_sum_for_category = words_with_count_for_category.values.reduce(0){|sum, count| sum += count.to_i} + @redis.set sum_key(category), total_word_count_sum_for_category + end + ret + end + + # Get a list of words and their scores + # + # == Parameters + # category:: + # A string representing a category. + # + # wordlist:: + # A list of words to look up. + # + # == Returns + # A hash with the counts each word. + def fetch_scores_for_words(category, wordlist) + Hash[wordlist.zip(@redis.pipelined do + wordlist.each do |word| + @redis.hget base_category_key + category, word + end + end)] + end + # Clear all training data from the backend. def clear_all_training_data @redis.flushdb @@ -99,7 +157,10 @@ def clear_training_data # A hash containing a count of the number of word occurences in a document def increment_word_counts_for_category(category, word_occurrences) word_occurrences.each do|word,count| - @redis.hincrby base_category_key + category, word, count + @redis.pipelined do + @redis.hincrby base_category_key + category, word, count + @redis.incrby sum_key(category), count + end end end @@ -112,7 +173,10 @@ def increment_word_counts_for_category(category, word_occurrences) # A hash containing a count of the number of word occurences in a document def decrement_word_counts_for_category(category, word_occurrences) word_occurrences.each do|word,count| - @redis.hincrby base_category_key + category, word, - count + @redis.pipelined do + @redis.hincrby base_category_key + category, word, - count + @redis.incrby sum_key(category), - count + end end end @@ -146,6 +210,12 @@ def base_category_key [ base_key, 'cat:' ].flatten.join(':') end + # The name for the key that holds the sum for all the words in + # a category. + def sum_key(category) + base_category_key + category + ':total' + end + def base_key [ 'Linnaeus', @scope ].compact end diff --git a/spec/linnaeus_classifier_spec.rb b/spec/linnaeus_classifier_spec.rb index 63ccc6e..b75e69c 100644 --- a/spec/linnaeus_classifier_spec.rb +++ b/spec/linnaeus_classifier_spec.rb @@ -24,7 +24,7 @@ { "movie"=>-6.272877006546167, "bird"=>-4.2626798770413155 } ) subject.classification_scores('a directorial bird').should eq( - { "movie"=>-10.24316892009829, "bird"=>-10.827944847076676 } + { "movie"=>-12.545754013092335, "bird"=>-10.827944847076676 } ) end end diff --git a/spec/linnaeus_persistence_spec.rb b/spec/linnaeus_persistence_spec.rb index b62bded..13c626a 100644 --- a/spec/linnaeus_persistence_spec.rb +++ b/spec/linnaeus_persistence_spec.rb @@ -14,9 +14,15 @@ it 'sets keys properly with defaults' do lp2 = get_linnaeus_persistence train_a_document_in('foobar') - lp2.redis.keys('*').should match_array ['Linnaeus:category', 'Linnaeus:cat:foobar'] + lp2.redis.keys('*').should match_array ['Linnaeus:category', 'Linnaeus:cat:foobar', 'Linnaeus:cat:foobar:total'] end + it 'has the right totals' do + lp2 = get_linnaeus_persistence + train_a_document_in('foobar') + lp2.redis.get('Linnaeus:cat:foobar:total').should eq '5' + end + context "custom scopes" do it 'sets keys properly' do lp2 = get_linnaeus_persistence(scope: 'new-scope') @@ -25,7 +31,7 @@ train_a_document_in('foobar', scope: 'new-scope') lp2.redis.keys('*').should match_array [ - 'Linnaeus:new-scope:cat:foobar', 'Linnaeus:new-scope:category' + 'Linnaeus:new-scope:cat:foobar', 'Linnaeus:new-scope:category', 'Linnaeus:new-scope:cat:foobar:total' ] end @@ -40,13 +46,14 @@ lp.redis.keys.should match_array [ "Linnaeus:cat:foobar", "Linnaeus:category", - "Linnaeus:new-scope:cat:foobar", "Linnaeus:new-scope:category" + "Linnaeus:new-scope:cat:foobar", "Linnaeus:new-scope:category", + "Linnaeus:cat:foobar:total", "Linnaeus:new-scope:cat:foobar:total" ] lp2.clear_training_data lp.redis.keys.should match_array [ - "Linnaeus:cat:foobar", "Linnaeus:category" + "Linnaeus:cat:foobar", "Linnaeus:category", "Linnaeus:cat:foobar:total" ] end