diff --git a/swig/ctc_beam_search_decoder.cpp b/swig/ctc_beam_search_decoder.cpp index c86475f..ac343a4 100644 --- a/swig/ctc_beam_search_decoder.cpp +++ b/swig/ctc_beam_search_decoder.cpp @@ -137,7 +137,8 @@ std::vector>> ctc_beam_search_decoder( float hotwords_score = 0.0; std::vector ngram; PathTrie *prefix_to_score = nullptr; - if (hotwords_scorer != nullptr && !hotwords_scorer->hotwords_dict.empty()) { + if (hotwords_scorer != nullptr && !hotwords_scorer->hotwords_dict.empty() && + (c == space_id || hotwords_scorer->is_character_based)) { if (hotwords_scorer->is_character_based) { prefix_to_score = prefix_new; } else { @@ -151,7 +152,7 @@ std::vector>> ctc_beam_search_decoder( // language model scoring float ngram_score = 0.0; - if (ext_scorer != nullptr ) { + if (ext_scorer != nullptr && (c == space_id || ext_scorer->is_character_based())) { if (hotwords_scorer != nullptr && !hotwords_scorer->hotwords_dict.empty() && !(hotwords_scorer->is_character_based ^ ext_scorer->is_character_based()) && hotwords_scorer->window_length >= ext_scorer->get_max_order()) { @@ -159,17 +160,15 @@ std::vector>> ctc_beam_search_decoder( std::vector::const_iterator last = ngram.end(); std::vector slice_ngram(first, last); ngram_score = ext_scorer->get_log_cond_prob(slice_ngram) * ext_scorer->alpha + ext_scorer->beta; - } else { - if (c == space_id || ext_scorer->is_character_based()) { - // skip scoring the space - if (ext_scorer->is_character_based()) { - prefix_to_score = prefix_new; - } else { - prefix_to_score = prefix; - } - ngram = ext_scorer->make_ngram(prefix_to_score); - ngram_score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha + ext_scorer->beta; + } + else { + if (ext_scorer->is_character_based()) { + prefix_to_score = prefix_new; + } else { + prefix_to_score = prefix; } + ngram = ext_scorer->make_ngram(prefix_to_score); + ngram_score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha + ext_scorer->beta; } } log_p += ngram_score; diff --git a/swig/hotwords.cpp b/swig/hotwords.cpp index 31f3995..1e8207e 100644 --- a/swig/hotwords.cpp +++ b/swig/hotwords.cpp @@ -90,8 +90,11 @@ float HotWordsScorer::get_hotwords_score(const std::vector& words, // word = std::accumulate(words.begin() + offset, words.end() - index, std::string{}); word = std::accumulate(words.begin() + offset + index, words.end(), std::string{}); } else { - // word in fixed window, traverse each word in words. - word = words[index]; + // word in fixed window, traverse each word in words, skip token. + if (index + offset >= words_size) { + break; + } + word = words[index + offset]; } iter = this->hotwords_dict.find(word); if (iter != this->hotwords_dict.end()) {