diff --git a/.clang-format b/.clang-format new file mode 100644 index 00000000..c0582c0e --- /dev/null +++ b/.clang-format @@ -0,0 +1,41 @@ +--- +Language: Cpp +BasedOnStyle: WebKit +AlignAfterOpenBracket: Align +AlignOperands: true +AlignTrailingComments: true +AllowAllArgumentsOnNextLine: true +AllowAllParametersOfDeclarationOnNextLine: false +AlwaysBreakBeforeMultilineStrings: false +AlwaysBreakTemplateDeclarations: Yes +BinPackArguments: false +BinPackParameters: false +BraceWrapping: + AfterClass: true + AfterControlStatement: false + AfterEnum: true + AfterFunction: true + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: true + AfterUnion: true + AfterExternBlock: true + BeforeCatch: false + BeforeElse: true + IndentBraces: false + SplitEmptyFunction: true + SplitEmptyRecord: false + SplitEmptyNamespace: true +ColumnLimit: 100 +IndentCaseLabels: true +SpaceAfterTemplateKeyword: true +SpaceBeforeAssignmentOperators: true +SpaceBeforeCpp11BracedList: true +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeParens: ControlStatements +SpaceBeforeRangeBasedForLoopColon: true +Standard: Cpp11 +TabWidth: 4 +UseTab: Never +... diff --git a/.gitmodules b/.gitmodules index 8a98f787..ba408aff 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,6 @@ -[submodule "third_party/kenlm"] - path = third_party/kenlm - url = https://github.com/kpu/kenlm.git [submodule "third_party/ThreadPool"] path = third_party/ThreadPool url = https://github.com/progschj/ThreadPool.git +[submodule "third_party/kenlm"] + path = third_party/kenlm + url = https://github.com/kpu/kenlm diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 00000000..6686b5a4 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,44 @@ +cmake_minimum_required(VERSION 3.16 FATAL_ERROR) +set(CMAKE_CXX_STANDARD 17) + +# project name +project(CTCBeamDecoder CXX) + +# define path to the libtorch extracted folder +set(CMAKE_PREFIX_PATH ${CMAKE_SOURCE_DIR}/third_party/libtorch) + +# find torch library and all necessary files +find_package(Torch REQUIRED) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") + +# add cxxopts library for command line parsing +include(FetchContent) + +FetchContent_Declare( + cxxopts + GIT_REPOSITORY https://github.com/jarro2783/cxxopts.git + GIT_TAG v3.1.1 +) + +FetchContent_GetProperties(cxxopts) + +if (NOT cxxopts_POPULATED) + FetchContent_Populate(cxxopts) + add_subdirectory(${cxxopts_SOURCE_DIR} ${cxxopts_BINARY_DIR}) +endif() + +# add subdirectories +add_subdirectory(ctcdecode) +add_subdirectory(${CMAKE_SOURCE_DIR}/third_party) +add_subdirectory(${CMAKE_SOURCE_DIR}/tests/cpp) + +# build_fst library +add_library(build_fst_lib ${CMAKE_SOURCE_DIR}/tools/build_fst.cpp) +target_include_directories(build_fst_lib PUBLIC ${CMAKE_SOURCE_DIR}/tools) +target_link_libraries(build_fst_lib PUBLIC ctcdecode "${TORCH_LIBRARIES}" cxxopts pthread dl) + +# executable that we want to compile and run +add_executable(build_fst ${CMAKE_SOURCE_DIR}/tools/build_fst_main.cpp) + +# link libraries to our executable +target_link_libraries(build_fst build_fst_lib) \ No newline at end of file diff --git a/README.md b/README.md index e7a91929..8630bb3e 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,11 @@ git clone --recursive https://github.com/parlance/ctcdecode.git cd ctcdecode && pip install .
``` +To build the ctcdecode library, run: +```bash +bash build.sh +``` + ## How to Use ```python diff --git a/build.sh b/build.sh new file mode 100644 index 00000000..11ab400a --- /dev/null +++ b/build.sh @@ -0,0 +1,51 @@ +#!/bin/bash + +# Download libtorch built CPU libraries +URL="https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-2.0.1%2Bcpu.zip" # stable version 2.0.1 +LIBTORCH_FILE_NAME="libtorch-shared-with-deps-2.0.1+cpu.zip" +BUILD_DIR="build" + +# Check if the file exists +if [ ! -f "third_party/$LIBTORCH_FILE_NAME" ]; then + # If the file doesn't exist, download it + cd third_party + wget "$URL" + # Unzip the file + unzip "$LIBTORCH_FILE_NAME" + cd .. +fi + +download_and_extract(){ + URL=$1 + FILE_NAME=$2 + if [ ! -f "third_party/$FILE_NAME" ]; then + # If the file doesn't exist, download it + cd third_party + wget "$URL" + # Unzip the file + tar -xvzf "$FILE_NAME" + cd .. + fi + +} + +# Download OpenFST +URL="https://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.8.2.tar.gz" +OPENFST_FILE_NAME="openfst-1.8.2.tar.gz" +download_and_extract "$URL" "$OPENFST_FILE_NAME" + + + +# Download boost +URL="https://github.com/parlance/ctcdecode/releases/download/v1.0/boost_1_67_0.tar.gz" +BOOST_FILE_NAME="boost_1_67_0.tar.gz" +download_and_extract "$URL" "$BOOST_FILE_NAME" + + +if [ ! -d "$BUILD_DIR" ]; then + mkdir "$BUILD_DIR" +fi + +cd build +cmake .. +make \ No newline at end of file diff --git a/ctcdecode/CMakeLists.txt b/ctcdecode/CMakeLists.txt new file mode 100644 index 00000000..7e7d1716 --- /dev/null +++ b/ctcdecode/CMakeLists.txt @@ -0,0 +1,45 @@ +cmake_minimum_required(VERSION 3.16 FATAL_ERROR) + +set(CMAKE_CXX_STANDARD 17) + +# find python package +find_package(Python COMPONENTS Interpreter Development) +# Check if Python was found +if(Python_FOUND) + message("Python found: ${Python_EXECUTABLE}") + + # Include directories provided by Python + include_directories(${Python_INCLUDE_DIRS}) + +else() + message("Python not found.") +endif() + +# build pybind11 +include(FetchContent) + +FetchContent_Declare( + pybind11 + GIT_REPOSITORY https://github.com/pybind/pybind11.git + GIT_TAG v2.10.4 +) + +FetchContent_MakeAvailable(pybind11) + +# define path to the libtorch extracted folder +set(CMAKE_PREFIX_PATH ${CMAKE_SOURCE_DIR}/third_party/libtorch) + +# find torch library and all necessary files +find_package(Torch REQUIRED) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") + +add_compile_options("-O3" "-DKENLM_MAX_ORDER=6" "-std=c++17" "-fPIC" "-DINCLUDE_KENLM") + +# ctc decode +file(GLOB CTC_SOURCES ${CMAKE_SOURCE_DIR}/ctcdecode/src/*.cpp) +add_library(ctcdecode STATIC "${CTC_SOURCES}") + +target_include_directories(ctcdecode PUBLIC ${CMAKE_SOURCE_DIR}/third_party/kenlm ${CMAKE_SOURCE_DIR}/third_party/openfst-1.8.2/src/include ${CMAKE_SOURCE_DIR}/third_party/utf8 ${CMAKE_SOURCE_DIR}/third_party/ThreadPool ${CMAKE_SOURCE_DIR}/third_party/boost_1_67_0 ${CMAKE_SOURCE_DIR}/ctcdecode/src ) +target_link_libraries(ctcdecode PUBLIC "${TORCH_LIBRARIES}" kenlm fst pybind11::module) + +# message("${CTC_SOURCES}") \ No newline at end of file diff --git a/ctcdecode/__init__.py b/ctcdecode/__init__.py index 1715b778..e724fd4c 100644 --- a/ctcdecode/__init__.py +++ b/ctcdecode/__init__.py @@ -1,3 +1,6 @@ +from time import time +from typing import List, Optional, Union + import torch from ._ext import ctc_decode @@ -21,20 +24,29 @@ class CTCBeamDecoder(object): num_processes (int): Parallelize the batch using num_processes workers.
blank_id (int): Index of the CTC blank token (probably 0) used when training your model. log_probs_input (bool): False if your model has passed through a softmax and output probabilities sum to 1. + is_bpe_based (bool): True if your labels contain BPE tokens, else False + lm_type (str): Type of the language model file: character, bpe or word based + token_separator (str): Prefix of the BPE tokens. Default value is "#" and it is always assumed that the tokens + starting with this prefix are meant to be merged with tokens that don't contain this prefix """ def __init__( self, - labels, - model_path=None, - alpha=0, - beta=0, - cutoff_top_n=40, - cutoff_prob=1.0, - beam_width=100, - num_processes=4, - blank_id=0, - log_probs_input=False, + labels: List[str], + model_path: Optional[str] = None, + alpha: float = 0, + beta: float = 0, + cutoff_top_n: int = 40, + cutoff_prob: float = 1.0, + beam_width: int = 100, + num_processes: int = 4, + blank_id: int = 0, + log_probs_input: bool = False, + is_bpe_based: bool = False, + unk_score: float = -5.0, + lm_type: str = "character", + token_separator: str = "#", + lexicon_fst_path: Optional[str] = None, ): self.cutoff_top_n = cutoff_top_n self._beam_width = beam_width @@ -43,20 +55,87 @@ def __init__( self._labels = list(labels) # Ensure labels are a list self._num_labels = len(labels) self._blank_id = blank_id - self._log_probs = 1 if log_probs_input else 0 + self._log_probs = True if log_probs_input else False + self.token_separator = token_separator + + lexicon_fst_path = lexicon_fst_path if lexicon_fst_path is not None else "" + if model_path: self._scorer = ctc_decode.paddle_get_scorer( - alpha, beta, model_path.encode(), self._labels, self._num_labels + alpha, + beta, + model_path.encode(), + self._labels, + lm_type, + lexicon_fst_path.encode(), ) + self._is_bpe_based = is_bpe_based self._cutoff_prob = cutoff_prob - def decode(self, probs, seq_lens=None): + self.decoder_options = ctc_decode.paddle_get_decoder_options( + self._labels, + cutoff_top_n, + cutoff_prob, + beam_width, + num_processes, + blank_id, + self._log_probs, + is_bpe_based, + unk_score, + token_separator, + ) + + def create_hotword_scorer( + self, + hotwords: List[List[str]], + hotword_weight: Union[float, List[float]] = 10.0, + ): + """ + Method to create a hotword scorer object for the given hotwords + Args: + hotwords (List[List[str]]) - Tokenized list of hotwords. + For example: + Hotword list for BPE token inputs = [ ["co", "##r", "##p"], ["t", "##es","##t"] ] + Hotword list for character inputs = [ ['c', 'o', 'r', 'p'], ['t', 'e', 's', 't'] ] + hotword_weight (Union[float, List[float]]) - Weight for each hotword. The weight for all the hotwords will be the same when only one weight is provided. + ( default = 10.0 ) + """ + if isinstance(hotword_weight, float) or isinstance(hotword_weight, int): + hotword_weight = [hotword_weight] * len(hotwords) + elif ( + isinstance(hotword_weight, List) + and isinstance(hotwords, List) + and len(hotwords) != len(hotword_weight) + ): + raise ValueError("Lengths of the hotwords and hotword_weight lists don't match.") + + hotword_scorer = ctc_decode.get_hotword_scorer( + self.decoder_options, hotwords, hotword_weight, self.token_separator + ) + + return hotword_scorer + + def decode( + self, + probs, + seq_lens=None, + hotword_scorer=None, + hotwords: List[List[str]] = None, + hotword_weight: Union[float, List[float]] = 10.0, + ): """ Conducts the beamsearch on model outputs and return results.
Args: probs (Tensor) - A rank 3 tensor representing model outputs. Shape is batch x num_timesteps x num_labels. seq_lens (Tensor) - A rank 1 tensor representing the sequence length of the items in the batch. Optional, if not provided the size of axis 1 (num_timesteps) of `probs` is used for all items + hotwords (List[List[str]]) - Tokenized list of hotwords. + For example: + Hotword list for BPE token inputs = [ ["co", "##r", "##p"], ["t", "##es","##t"] ] + Hotword list for character inputs = [ ['c', 'o', 'r', 'p'], ['t', 'e', 's', 't'] ] + hotword_weight (Union[float, List[float]]) - The boost factor for scoring a hotword when it appears in the beam path. A range of + 0 - 15 is recommended for each hotword. If a single value is provided then the same weight will be used for all + the hotwords Returns: tuple: (beam_results, beam_scores, timesteps, out_lens) @@ -80,46 +159,71 @@ def decode(self, probs, seq_lens=None): seq_lens = torch.IntTensor(batch_size).fill_(max_seq_len) else: seq_lens = seq_lens.cpu().int() + + if hotwords and hotword_scorer: + raise ValueError( + "You can provide either hotwords or a hotword scorer, not both at the same time.\n" + ) + + # if hotwords list is provided then create a scorer for it + if hotwords: + hotword_scorer = self.create_hotword_scorer(hotwords, hotword_weight) + output = torch.IntTensor(batch_size, self._beam_width, max_seq_len).cpu().int() - timesteps = torch.IntTensor(batch_size, self._beam_width, max_seq_len).cpu().int() + timesteps = ( + torch.IntTensor(batch_size, self._beam_width, max_seq_len).cpu().int() + ) scores = torch.FloatTensor(batch_size, self._beam_width).cpu().float() out_seq_len = torch.zeros(batch_size, self._beam_width).cpu().int() - if self._scorer: - ctc_decode.paddle_beam_decode_lm( + + if not self._scorer and not hotword_scorer: + ctc_decode.paddle_beam_decode( probs, seq_lens, - self._labels, - self._num_labels, - self._beam_width, - self._num_processes, - self._cutoff_prob, - self.cutoff_top_n, - self._blank_id, - self._log_probs, + self.decoder_options, + output, + timesteps, + scores, + out_seq_len, + ) + elif self._scorer and not hotword_scorer: + ctc_decode.paddle_beam_decode_with_lm( + probs, + seq_lens, + self.decoder_options, self._scorer, output, timesteps, scores, out_seq_len, ) + elif not self._scorer and hotword_scorer: + ctc_decode.paddle_beam_decode_with_hotwords( + probs, + seq_lens, + self.decoder_options, + hotword_scorer, + output, + timesteps, + scores, + out_seq_len, + ) else: - ctc_decode.paddle_beam_decode( + ctc_decode.paddle_beam_decode_with_lm_and_hotwords( probs, seq_lens, - self._labels, - self._num_labels, - self._beam_width, - self._num_processes, - self._cutoff_prob, - self.cutoff_top_n, - self._blank_id, - self._log_probs, + self.decoder_options, + self._scorer, + hotword_scorer, output, timesteps, scores, out_seq_len, ) + if hotwords: + self.delete_hotword_scorer(hotword_scorer) + return output, scores, timesteps, out_seq_len def character_based(self): @@ -138,6 +242,12 @@ def reset_params(self, alpha, beta): def __del__(self): if self._scorer is not None: ctc_decode.paddle_release_scorer(self._scorer) + if self.decoder_options: + ctc_decode.paddle_release_decoder_options(self.decoder_options) + + def delete_hotword_scorer(self, hw_scorer=None): + if hw_scorer: + ctc_decode.paddle_release_hotword_scorer(hw_scorer) class OnlineCTCBeamDecoder(object): @@ -158,7 +268,14 @@ class OnlineCTCBeamDecoder(object): num_processes (int): Parallelize the batch using num_processes
workers. blank_id (int): Index of the CTC blank token (probably 0) used when training your model. log_probs_input (bool): False if your model has passed through a softmax and output probabilities sum to 1. + is_bpe_based (bool): True if your labels contain BPE tokens, else False + lm_type (str): Type of the language model file: character, bpe or word based + token_separator (str): Prefix of the BPE tokens. Default value is "#" and it is always assumed that the tokens + starting with this prefix are meant to be merged with tokens that don't contain this prefix + lexicon_fst_path (str): Path to the FST model file for decoding. It can either be optimized or not. If not provided then + the FST will not be used for decoding. Default value is None. """ + def __init__( self, labels, @@ -171,6 +288,11 @@ def __init__( num_processes=4, blank_id=0, log_probs_input=False, + is_bpe_based: bool = False, + unk_score: float = -5.0, + lm_type: str = "character", + token_separator: str = "#", + lexicon_fst_path: Optional[str] = None, ): self._cutoff_top_n = cutoff_top_n self._beam_width = beam_width @@ -180,9 +302,29 @@ def __init__( self._num_labels = len(labels) self._blank_id = blank_id self._log_probs = 1 if log_probs_input else 0 + lexicon_fst_path = lexicon_fst_path if lexicon_fst_path is not None else "" + + self.decoder_options = ctc_decode.paddle_get_decoder_options( + self._labels, + cutoff_top_n, + cutoff_prob, + beam_width, + num_processes, + blank_id, + self._log_probs, + is_bpe_based, + unk_score, + token_separator, + ) + if model_path: self._scorer = ctc_decode.paddle_get_scorer( - alpha, beta, model_path.encode(), self._labels, self._num_labels + alpha, + beta, + model_path.encode(), + self._labels, + lm_type, + lexicon_fst_path.encode(), ) self._cutoff_prob = cutoff_prob @@ -230,7 +372,7 @@ def decode(self, probs, states, is_eos_s, seq_lens=None): [state.state for state in states], is_eos_s, scores, - out_seq_len + out_seq_len, ) res_beam_results = res_beam_results.int() res_timesteps = res_timesteps.int() @@ -257,14 +399,10 @@ class DecoderState: Args: decoder (OnlineCTCBeamDecoder) - decoder you will use for decoding.
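+ Example: a minimal streaming sketch (illustrative only; assumes `decoder` is an already constructed OnlineCTCBeamDecoder and `chunk` is a rank 3 probabilities tensor): + >>> state = DecoderState(decoder) + >>> results = decoder.decode(chunk, [state], [True]) # is_eos=True returns the final beams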
""" + def __init__(self, decoder): self.state = ctc_decode.paddle_get_decoder_state( - decoder._labels, - decoder._beam_width, - decoder._cutoff_prob, - decoder._cutoff_top_n, - decoder._blank_id, - decoder._log_probs, + decoder.decoder_options, decoder._scorer, ) diff --git a/ctcdecode/src/binding.cpp b/ctcdecode/src/binding.cpp index c14a8221..08e543f6 100644 --- a/ctcdecode/src/binding.cpp +++ b/ctcdecode/src/binding.cpp @@ -1,29 +1,26 @@ -#include -#include -#include -#include +#include "boost/python.hpp" +#include "boost/python/stl_iterator.hpp" +#include "boost/shared_ptr.hpp" +#include #include -#include -#include "scorer.h" + #include "ctc_beam_search_decoder.h" +#include "decoder_options.h" +#include "scorer.h" #include "utf8.h" -#include "boost/shared_ptr.hpp" -#include "boost/python.hpp" -#include "boost/python/stl_iterator.hpp" -using namespace std; +namespace py = pybind11; -template -inline -std::vector< T > py_list_to_std_vector( const boost::python::object& iterable ) +template +inline std::vector py_list_to_std_vector(const boost::python::object& iterable) { - return std::vector< T >( boost::python::stl_input_iterator< T >( iterable ), - boost::python::stl_input_iterator< T >( ) ); + return std::vector(boost::python::stl_input_iterator(iterable), + boost::python::stl_input_iterator()); } template -inline -boost::python::list std_vector_to_py_list(std::vector vector) { +inline boost::python::list std_vector_to_py_list(std::vector vector) +{ typename std::vector::iterator iter; boost::python::list list; for (iter = vector.begin(); iter != vector.end(); ++iter) { @@ -32,26 +29,29 @@ boost::python::list std_vector_to_py_list(std::vector vector) { return list; } -int beam_decode(at::Tensor th_probs, - at::Tensor th_seq_lens, - std::vector new_vocab, - int vocab_size, - size_t beam_size, - size_t num_processes, - double cutoff_prob, - size_t cutoff_top_n, - size_t blank_id, - bool log_input, - void *scorer, - at::Tensor th_output, - at::Tensor th_timesteps, - at::Tensor th_scores, - at::Tensor th_out_length) +int paddle_beam_decode_with_lm_and_hotwords(at::Tensor th_probs, + at::Tensor th_seq_lens, + void* decoder_options, + void* scorer, + void* hotword_scorer, + at::Tensor th_output, + at::Tensor th_timesteps, + at::Tensor th_scores, + at::Tensor th_out_length) { - Scorer *ext_scorer = NULL; - if (scorer != NULL) { - ext_scorer = static_cast(scorer); + + DecoderOptions* options = static_cast(decoder_options); + + Scorer* ext_scorer = nullptr; + if (scorer != nullptr) { + ext_scorer = static_cast(scorer); + } + + HotwordScorer* ext_hotword_scorer = nullptr; + if (hotword_scorer != nullptr) { + ext_hotword_scorer = static_cast(hotword_scorer); } + const int64_t max_time = th_probs.size(1); const int64_t batch_size = th_probs.size(0); const int64_t num_classes = th_probs.size(2); @@ -60,12 +60,14 @@ int beam_decode(at::Tensor th_probs, auto prob_accessor = th_probs.accessor(); auto seq_len_accessor = th_seq_lens.accessor(); - for (int b=0; b < batch_size; ++b) { - // avoid a crash by ensuring that an erroneous seq_len doesn't have us try to access memory we shouldn't + for (int b = 0; b < batch_size; ++b) { + // avoid a crash by ensuring that an + // erroneous seq_len doesn't have us try to access memory + // we shouldn't int seq_len = std::min((int)seq_len_accessor[b], (int)max_time); - std::vector> temp (seq_len, std::vector(num_classes)); - for (int t=0; t < seq_len; ++t) { - for (int n=0; n < num_classes; ++n) { + std::vector> temp(seq_len, std::vector(num_classes)); + for 
(int t = 0; t < seq_len; ++t) { + for (int n = 0; n < num_classes; ++n) { float val = prob_accessor[b][t][n]; temp[t][n] = val; } @@ -73,24 +75,22 @@ int beam_decode(at::Tensor th_probs, inputs.push_back(temp); } - - std::vector>> batch_results = - ctc_beam_search_decoder_batch(inputs, new_vocab, beam_size, num_processes, cutoff_prob, cutoff_top_n, blank_id, log_input, ext_scorer); + std::vector>> batch_results + = ctc_beam_search_decoder_batch(inputs, options, ext_scorer, ext_hotword_scorer); auto outputs_accessor = th_output.accessor(); - auto timesteps_accessor = th_timesteps.accessor(); - auto scores_accessor = th_scores.accessor(); - auto out_length_accessor = th_out_length.accessor(); + auto timesteps_accessor = th_timesteps.accessor(); + auto scores_accessor = th_scores.accessor(); + auto out_length_accessor = th_out_length.accessor(); - - for (int b = 0; b < batch_results.size(); ++b){ + for (int b = 0; b < batch_results.size(); ++b) { std::vector> results = batch_results[b]; - for (int p = 0; p < results.size();++p){ + for (int p = 0; p < results.size(); ++p) { std::pair n_path_result = results[p]; Output output = n_path_result.second; std::vector output_tokens = output.tokens; std::vector output_timesteps = output.timesteps; - for (int t = 0; t < output_tokens.size(); ++t){ - outputs_accessor[b][p][t] = output_tokens[t]; // fill output tokens + for (int t = 0; t < output_tokens.size(); ++t) { + outputs_accessor[b][p][t] = output_tokens[t]; // fill output tokens timesteps_accessor[b][p][t] = output_timesteps[t]; } scores_accessor[b][p] = n_path_result.first; @@ -102,61 +102,120 @@ int beam_decode(at::Tensor th_probs, int paddle_beam_decode(at::Tensor th_probs, at::Tensor th_seq_lens, - std::vector labels, - int vocab_size, - size_t beam_size, - size_t num_processes, - double cutoff_prob, - size_t cutoff_top_n, - size_t blank_id, - int log_input, + void* decoder_options, at::Tensor th_output, at::Tensor th_timesteps, at::Tensor th_scores, - at::Tensor th_out_length){ + at::Tensor th_out_length) +{ + + return paddle_beam_decode_with_lm_and_hotwords(th_probs, + th_seq_lens, + decoder_options, + nullptr, + nullptr, + th_output, + th_timesteps, + th_scores, + th_out_length); +} - return beam_decode(th_probs, th_seq_lens, labels, vocab_size, beam_size, num_processes, - cutoff_prob, cutoff_top_n, blank_id, log_input, NULL, th_output, th_timesteps, th_scores, th_out_length); +int paddle_beam_decode_with_lm(at::Tensor th_probs, + at::Tensor th_seq_lens, + void* decoder_options, + void* scorer, + at::Tensor th_output, + at::Tensor th_timesteps, + at::Tensor th_scores, + at::Tensor th_out_length) +{ + + return paddle_beam_decode_with_lm_and_hotwords(th_probs, + th_seq_lens, + decoder_options, + scorer, + nullptr, + th_output, + th_timesteps, + th_scores, + th_out_length); } -int paddle_beam_decode_lm(at::Tensor th_probs, - at::Tensor th_seq_lens, - std::vector labels, - int vocab_size, - size_t beam_size, - size_t num_processes, - double cutoff_prob, - size_t cutoff_top_n, - size_t blank_id, - int log_input, - void *scorer, - at::Tensor th_output, - at::Tensor th_timesteps, - at::Tensor th_scores, - at::Tensor th_out_length){ - - return beam_decode(th_probs, th_seq_lens, labels, vocab_size, beam_size, num_processes, - cutoff_prob, cutoff_top_n, blank_id, log_input, scorer, th_output, th_timesteps, th_scores, th_out_length); +int paddle_beam_decode_with_hotwords(at::Tensor th_probs, + at::Tensor th_seq_lens, + void* decoder_options, + void* hotword_scorer, + at::Tensor th_output, + at::Tensor 
th_timesteps, + at::Tensor th_scores, + at::Tensor th_out_length) +{ + + return paddle_beam_decode_with_lm_and_hotwords(th_probs, + th_seq_lens, + decoder_options, + nullptr, + hotword_scorer, + th_output, + th_timesteps, + th_scores, + th_out_length); } +void* paddle_get_decoder_options(std::vector vocab, + size_t cutoff_top_n, + double cutoff_prob, + size_t beam_width, + size_t num_processes, + size_t blank_id, + bool log_probs_input, + bool is_bpe_based, + float unk_score, + char token_separator) +{ + DecoderOptions* decoder_options = new DecoderOptions(vocab, + cutoff_top_n, + cutoff_prob, + beam_width, + num_processes, + blank_id, + log_probs_input, + is_bpe_based, + unk_score, + token_separator); + return static_cast(decoder_options); +} void* paddle_get_scorer(double alpha, double beta, const char* lm_path, - vector new_vocab, - int vocab_size) { - Scorer* scorer = new Scorer(alpha, beta, lm_path, new_vocab); + std::vector new_vocab, + std::string lm_type, + const char* fst_path) +{ + Scorer* scorer = new Scorer(alpha, beta, lm_path, new_vocab, lm_type, fst_path); return static_cast(scorer); } +void* get_hotword_scorer(void* decoder_options, + std::vector> hotwords, + std::vector hotword_weights, + char token_separator) +{ + DecoderOptions* options = static_cast(decoder_options); + HotwordScorer* scorer = new HotwordScorer( + options->vocab, hotwords, hotword_weights, token_separator, options->is_bpe_based); + return static_cast(scorer); +} -std::pair beam_decode_with_given_state(at::Tensor th_probs, - at::Tensor th_seq_lens, - size_t num_processes, - std::vector &states, - const std::vector &is_eos_s, - at::Tensor th_scores, - at::Tensor th_out_length) +std::pair +beam_decode_with_given_state(at::Tensor th_probs, + at::Tensor th_seq_lens, + size_t num_processes, + std::vector& states, + const std::vector& is_eos_s, + at::Tensor th_scores, + at::Tensor th_out_length) { const int64_t max_time = th_probs.size(1); const int64_t batch_size = th_probs.size(0); @@ -166,58 +225,58 @@ std::pair beam_decode_with_given_state(at::Tensor auto prob_accessor = th_probs.accessor(); auto seq_len_accessor = th_seq_lens.accessor(); - for (int b=0; b < batch_size; ++b) { - // avoid a crash by ensuring that an erroneous seq_len doesn't have us try to access memory we shouldn't + for (int b = 0; b < batch_size; ++b) { + // avoid a crash by ensuring that an erroneous seq_len doesn't have us try to access memory + // we shouldn't int seq_len = std::min((int)seq_len_accessor[b], (int)max_time); - std::vector> temp (seq_len, std::vector(num_classes)); - for (int t=0; t < seq_len; ++t) { - for (int n=0; n < num_classes; ++n) { + std::vector> temp(seq_len, std::vector(num_classes)); + for (int t = 0; t < seq_len; ++t) { + for (int n = 0; n < num_classes; ++n) { float val = prob_accessor[b][t][n]; temp[t][n] = val; } } inputs.push_back(temp); - } - std::vector>> batch_results = - ctc_beam_search_decoder_batch_with_states(inputs, num_processes, states, is_eos_s); - + std::vector>> batch_results + = ctc_beam_search_decoder_batch_with_states(inputs, num_processes, states, is_eos_s); + int max_result_size = 0; int max_output_tokens_size = 0; - for (int b = 0; b < batch_results.size(); ++b){ + for (int b = 0; b < batch_results.size(); ++b) { std::vector> results = batch_results[b]; if (batch_results[b].size() > max_result_size) { max_result_size = batch_results[b].size(); } - for (int p = 0; p < results.size();++p){ + for (int p = 0; p < results.size(); ++p) { std::pair n_path_result = results[p]; Output output = 
n_path_result.second; std::vector output_tokens = output.tokens; - + if (output_tokens.size() > max_output_tokens_size) { - max_output_tokens_size = output_tokens.size(); - } - } + max_output_tokens_size = output_tokens.size(); + } } - - torch::Tensor output_tokens_tensor = torch::randint(1, {batch_results.size(), max_result_size, max_output_tokens_size}); - torch::Tensor output_timesteps_tensor = torch::randint(1, {batch_results.size(), max_result_size, max_output_tokens_size}); - + } - auto scores_accessor = th_scores.accessor(); - auto out_length_accessor = th_out_length.accessor(); + torch::Tensor output_tokens_tensor + = torch::randint(1, { batch_results.size(), max_result_size, max_output_tokens_size }); + torch::Tensor output_timesteps_tensor + = torch::randint(1, { batch_results.size(), max_result_size, max_output_tokens_size }); + auto scores_accessor = th_scores.accessor(); + auto out_length_accessor = th_out_length.accessor(); - for (int b = 0; b < batch_results.size(); ++b){ + for (int b = 0; b < batch_results.size(); ++b) { std::vector> results = batch_results[b]; - for (int p = 0; p < results.size();++p){ + for (int p = 0; p < results.size(); ++p) { std::pair n_path_result = results[p]; Output output = n_path_result.second; std::vector output_tokens = output.tokens; std::vector output_timesteps = output.timesteps; for (int t = 0; t < output_tokens.size(); ++t) { - output_tokens_tensor[b][p][t] = output_tokens[t]; // fill output tokens + output_tokens_tensor[b][p][t] = output_tokens[t]; // fill output tokens output_timesteps_tensor[b][p][t] = output_timesteps[t]; } scores_accessor[b][p] = n_path_result.first; @@ -225,79 +284,112 @@ std::pair beam_decode_with_given_state(at::Tensor } } - return {output_tokens_tensor, output_timesteps_tensor}; + return { output_tokens_tensor, output_timesteps_tensor }; } +std::pair +paddle_beam_decode_with_given_state(at::Tensor th_probs, + at::Tensor th_seq_lens, + size_t num_processes, + std::vector states, + std::vector is_eos_s, + at::Tensor th_scores, + at::Tensor th_out_length) +{ -std::pair paddle_beam_decode_with_given_state(at::Tensor th_probs, - at::Tensor th_seq_lens, - size_t num_processes, - std::vector states, - std::vector is_eos_s, - at::Tensor th_scores, - at::Tensor th_out_length){ - - return beam_decode_with_given_state(th_probs, th_seq_lens, num_processes, states,is_eos_s, th_scores, th_out_length); + return beam_decode_with_given_state( + th_probs, th_seq_lens, num_processes, states, is_eos_s, th_scores, th_out_length); } - - - -void* paddle_get_decoder_state(const std::vector &vocabulary, - size_t beam_size, - double cutoff_prob, - size_t cutoff_top_n, - size_t blank_id, - int log_input, - void* scorer) +void* paddle_get_decoder_state(void* decoder_options, void* scorer) { - // DecoderState state(vocabulary, beam_size, cutoff_prob, cutoff_top_n, blank_id, log_input, ext_scorer); - Scorer *ext_scorer = NULL; - if (scorer != NULL) { - ext_scorer = static_cast(scorer); + // DecoderState state(vocabulary, beam_size, cutoff_prob, cutoff_top_n, blank_id, log_input, + // ext_scorer); + DecoderOptions* options = static_cast(decoder_options); + Scorer* ext_scorer = nullptr; + if (scorer != nullptr) { + ext_scorer = static_cast(scorer); } - DecoderState* state = new DecoderState(vocabulary, beam_size, cutoff_prob, cutoff_top_n, blank_id, log_input, ext_scorer); + DecoderState* state = new DecoderState(options, ext_scorer, nullptr); return static_cast(state); } -void paddle_release_state(void* state) { - delete static_cast(state); 
+void paddle_release_state(void* state) { delete static_cast(state); } + +void paddle_release_scorer(void* scorer) { delete static_cast(scorer); } + +void paddle_release_decoder_options(void* decoder_options) +{ + delete static_cast(decoder_options); } -void paddle_release_scorer(void* scorer) { - delete static_cast(scorer); +void paddle_release_hotword_scorer(void* scorer) +{ + delete static_cast(scorer); + scorer = nullptr; } -int is_character_based(void *scorer){ - Scorer *ext_scorer = static_cast(scorer); +int is_character_based(void* scorer) +{ + Scorer* ext_scorer = static_cast(scorer); return ext_scorer->is_character_based(); } -size_t get_max_order(void *scorer){ - Scorer *ext_scorer = static_cast(scorer); + +size_t get_max_order(void* scorer) +{ + Scorer* ext_scorer = static_cast(scorer); return ext_scorer->get_max_order(); } -size_t get_dict_size(void *scorer){ - Scorer *ext_scorer = static_cast(scorer); - return ext_scorer->get_dict_size(); + +size_t get_lexicon_size(void* scorer) +{ + Scorer* ext_scorer = static_cast(scorer); + return ext_scorer->get_lexicon_size(); } -void reset_params(void *scorer, double alpha, double beta){ - Scorer *ext_scorer = static_cast(scorer); +void reset_params(void* scorer, double alpha, double beta) +{ + Scorer* ext_scorer = static_cast(scorer); ext_scorer->reset_params(alpha, beta); } - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("paddle_beam_decode", &paddle_beam_decode, "paddle_beam_decode"); - m.def("paddle_beam_decode_lm", &paddle_beam_decode_lm, "paddle_beam_decode_lm"); - m.def("paddle_get_scorer", &paddle_get_scorer, "paddle_get_scorer"); - m.def("paddle_release_scorer", &paddle_release_scorer, "paddle_release_scorer"); - m.def("is_character_based", &is_character_based, "is_character_based"); - m.def("get_max_order", &get_max_order, "get_max_order"); - m.def("get_dict_size", &get_dict_size, "get_max_order"); - m.def("reset_params", &reset_params, "reset_params"); - m.def("paddle_get_decoder_state", &paddle_get_decoder_state, "paddle_get_decoder_state"); - m.def("paddle_beam_decode_with_given_state", &paddle_beam_decode_with_given_state, "paddle_beam_decode_with_given_state"); - m.def("paddle_release_state", &paddle_release_state, "paddle_release_state"); - //paddle_beam_decode_with_given_state +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("paddle_beam_decode", &paddle_beam_decode, "paddle_beam_decode"); + m.def("paddle_beam_decode_with_lm", &paddle_beam_decode_with_lm, "paddle_beam_decode_with_lm"); + m.def("paddle_beam_decode_with_hotwords", + &paddle_beam_decode_with_hotwords, + "paddle_beam_decode_with_hotwords"); + m.def("paddle_beam_decode_with_lm_and_hotwords", + &paddle_beam_decode_with_lm_and_hotwords, + "paddle_beam_decode_with_lm_and_hotwords", + py::arg("th_probs"), + py::arg("th_seq_lens"), + py::arg("decoder_options"), + py::arg("scorer").none(true), + py::arg("hotword_scorer").none(true), + py::arg("th_output"), + py::arg("th_timestamps"), + py::arg("th_scores"), + py::arg("th_out_length")); + m.def("paddle_get_decoder_options", &paddle_get_decoder_options, "paddle_get_decoder_options"); + m.def("paddle_get_scorer", &paddle_get_scorer, "paddle_get_scorer"); + m.def("get_hotword_scorer", &get_hotword_scorer, "get_hotword_scorer"); + m.def("paddle_release_scorer", &paddle_release_scorer, "paddle_release_scorer"); + m.def("paddle_release_decoder_options", + &paddle_release_decoder_options, + "paddle_release_decoder_options"); + m.def("paddle_release_hotword_scorer", + &paddle_release_hotword_scorer, + 
"paddle_release_hotword_scorer"); + m.def("is_character_based", &is_character_based, "is_character_based"); + m.def("get_max_order", &get_max_order, "get_max_order"); + m.def("get_lexicon_size", &get_lexicon_size, "get_max_order"); + m.def("reset_params", &reset_params, "reset_params"); + m.def("paddle_get_decoder_state", &paddle_get_decoder_state, "paddle_get_decoder_state"); + m.def("paddle_beam_decode_with_given_state", + &paddle_beam_decode_with_given_state, + "paddle_beam_decode_with_given_state"); + m.def("paddle_release_state", &paddle_release_state, "paddle_release_state"); + // paddle_beam_decode_with_given_state } diff --git a/ctcdecode/src/binding.h b/ctcdecode/src/binding.h index e6bd2eb1..09c88f07 100644 --- a/ctcdecode/src/binding.h +++ b/ctcdecode/src/binding.h @@ -1,55 +1,70 @@ -int paddle_beam_decode(THFloatTensor *th_probs, - THIntTensor *th_seq_lens, - std::vector labels, - int vocab_size, - size_t beam_size, - size_t num_processes, - double cutoff_prob, - size_t cutoff_top_n, - size_t blank_id, - int log_input, - THIntTensor *th_output, - THIntTensor *th_timesteps, - THFloatTensor *th_scores, - THIntTensor *th_out_length); - - -int paddle_beam_decode_lm(THFloatTensor *th_probs, - THIntTensor *th_seq_lens, - std::vector labels, - int vocab_size, - size_t beam_size, - size_t num_processes, - double cutoff_prob, - size_t cutoff_top_n, - size_t blank_id, - bool log_input, - int *scorer, - THIntTensor *th_output, - THIntTensor *th_timesteps, - THFloatTensor *th_scores, - THIntTensor *th_out_length); +int paddle_beam_decode(THFloatTensor* th_probs, + THIntTensor* th_seq_lens, + void* decoder_options, + THIntTensor* th_output, + THIntTensor* th_timesteps, + THFloatTensor* th_scores, + THIntTensor* th_out_length); + +int paddle_beam_decode_with_lm(THFloatTensor* th_probs, + THIntTensor* th_seq_lens, + void* decoder_options, + void* scorer, + THIntTensor* th_output, + THIntTensor* th_timesteps, + THFloatTensor* th_scores, + THIntTensor* th_out_length); + +int paddle_beam_decode_with_hotwords(THFloatTensor* th_probs, + THIntTensor* th_seq_lens, + void* decoder_options, + void* hotword_scorer, + THIntTensor* th_output, + THIntTensor* th_timesteps, + THFloatTensor* th_scores, + THIntTensor* th_out_length); + +int paddle_beam_decode_with_lm_and_hotwords(THFloatTensor* th_probs, + THIntTensor* th_seq_lens, + void* decoder_options, + void* scorer, + void* hotword_scorer, + THIntTensor* th_output, + THIntTensor* th_timesteps, + THFloatTensor* th_scores, + THIntTensor* th_out_length); + +void* paddle_get_decoder_options(std::vector vocab, + size_t cutoff_top_n, + double cutoff_prob, + size_t beam_width, + size_t num_processes, + size_t blank_id, + bool log_probs_input, + bool is_bpe_based, + float unk_score, + char token_separator); void* paddle_get_scorer(double alpha, double beta, const char* lm_path, std::vector labels, - int vocab_size); + std::string lm_type, + const char* fst_path); +void* get_hotword_scorer(void* decoder_options, + std::vector> hotwords, + std::vector hotword_weights, + char token_separator); -void* paddle_get_decoder_state(const std::vector &vocabulary, - size_t beam_size, - double cutoff_prob, - size_t cutoff_top_n, - size_t blank_id, - int log_input, - void* scorer); +void* paddle_get_decoder_state(void* decoder_options, void* scorer); void paddle_release_scorer(void* scorer); +void paddle_release_decoder_options(void* decoder_options); +void paddle_release_hotword_scorer(void* scorer); void paddle_release_state(void* state); - -int is_character_based(void 
*scorer); -size_t get_max_order(void *scorer); -size_t get_dict_size(void *scorer); -void reset_params(void *scorer, double alpha, double beta); +int is_character_based(void* scorer); +size_t get_max_order(void* scorer); +size_t get_lexicon_size(void* scorer); +void reset_params(void* scorer, double alpha, double beta); diff --git a/ctcdecode/src/ctc_beam_search_decoder.cpp b/ctcdecode/src/ctc_beam_search_decoder.cpp index 0d0365d3..bd323fb3 100644 --- a/ctcdecode/src/ctc_beam_search_decoder.cpp +++ b/ctcdecode/src/ctc_beam_search_decoder.cpp @@ -1,317 +1,436 @@ #include "ctc_beam_search_decoder.h" -#include #include #include -#include #include -#include -#include "decoder_utils.h" #include "ThreadPool.h" +#include "decoder_utils.h" #include "fst/fstlib.h" #include "path_trie.h" using FSTMATCH = fst::SortedMatcher; -DecoderState::DecoderState(const std::vector &vocabulary, - size_t beam_size, - double cutoff_prob, - size_t cutoff_top_n, - size_t blank_id, - int log_input, - Scorer *ext_scorer) - : abs_time_step(0) - , beam_size(beam_size) - , cutoff_prob(cutoff_prob) - , cutoff_top_n(cutoff_top_n) - , blank_id(blank_id) - , log_input(log_input) - , vocabulary(vocabulary) - , ext_scorer(ext_scorer) +DecoderState::DecoderState(DecoderOptions* options, + Scorer* ext_scorer, + HotwordScorer* hotword_scorer) + : abs_time_step(0) + , options(options) + , ext_scorer(ext_scorer) + , hotword_scorer(hotword_scorer) { - // assign space id - auto it = std::find(vocabulary.begin(), vocabulary.end(), " "); - // if no space in vocabulary - if (it == vocabulary.end()) { space_id = -2; - } else { - space_id = std::distance(vocabulary.begin(), it); - } - - // init prefixes' root - root.score = root.log_prob_b_prev = 0.0; - prefixes.push_back(&root); - - if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { - auto fst_dict = static_cast(ext_scorer->dictionary); - fst::StdVectorFst *dict_ptr = fst_dict->Copy(true); - root.set_dictionary(dict_ptr); - auto matcher = std::make_shared(*dict_ptr, fst::MATCH_INPUT); - root.set_matcher(matcher); - } + apostrophe_id = -3; + + // assign space id and apostrophe id if present in vocabulary + int id = 0; + for (auto it = options->vocab.begin(); it != options->vocab.end(); ++it) { + if (*it == " ") { + space_id = id; + } else if (*it == "'") { + apostrophe_id = id; + } + ++id; + } + + // init prefixes' root + root.score = root.log_prob_b_prev = 0.0; + root.score_hw = root.log_prob_b_prev_hw = 0.0; + prefixes.push_back(&root); + + if (ext_scorer != nullptr && ext_scorer->has_lexicon()) { + + auto fst_dict = static_cast(ext_scorer->lexicon); + fst::StdVectorFst* dict_ptr = fst_dict->Copy(true); + root.set_lexicon(dict_ptr); + auto matcher + = std::make_shared>(*dict_ptr, fst::MATCH_INPUT); + root.set_matcher(matcher); + } + + if (hotword_scorer != nullptr) { + auto hotword_matcher + = std::make_shared(hotword_scorer->dictionary, fst::MATCH_INPUT); + root.hotword_matcher = hotword_matcher; + } } +/** + * @brief This method returns true when the given node can start a word.
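+ * Example (illustrative, with the default token_separator "#"): for BPE labels a token such as "##ing" merges with the + * previous token and therefore cannot start a word, while "ing" can; for character labels, a node starts a word when its + * parent is the space token or the root.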
+ * Supports both BPE- and character-based labels + * + * @param path, PathTrie node + * @return true, if the current node's character/token can start a word + * @return false, if the current node's character/token cannot start a word + */ +bool DecoderState::is_start_of_word(PathTrie* path) +{ + + bool is_bpe_based_start_token = options->is_bpe_based + && !is_mergeable_bpe_token(options->vocab[path->character], + path->character, + path->parent->character, + apostrophe_id, + options->token_separator); + + bool is_char_based_start_token + = !options->is_bpe_based + && (path->parent->character == space_id || path->parent->character == -1); + + return is_bpe_based_start_token || is_char_based_start_token; } -void -DecoderState::next(const std::vector> &probs_seq) +/** + * @brief Updates both original and hotword non-blank scores of the current path node. If the + * current node ends a hotword, then the original score (log_p) is updated with the actual score + * (log_p_hw, which contains both the original and hotword scores). + * + * @param path, PathTrie node + * @param log_prob_c, log probability of the node + * @param lm_score, language model score for the current node + * @param reset_score, whether to consider the previous node's original score instead of the + * original + hotword score. Score resetting happens when a partial hotword is formed + */ +void DecoderState::update_score(PathTrie* path, float log_prob_c, float lm_score, bool reset_score) { - // dimension check - size_t num_time_steps = probs_seq.size(); - for (size_t i = 0; i < num_time_steps; ++i) { - VALID_CHECK_EQ(probs_seq[i].size(), - vocabulary.size(), - "The shape of probs_seq does not match with " - "the shape of the vocabulary"); - } - - // prefix search over time - for (size_t time_step = 0; time_step < num_time_steps; ++time_step, ++abs_time_step) { - auto &prob = probs_seq[time_step]; - - float min_cutoff = -NUM_FLT_INF; - bool full_beam = false; - if (ext_scorer != nullptr) { - size_t num_prefixes = std::min(prefixes.size(), beam_size); - std::sort( - prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); - float blank_prob = log_input ? 
prob[blank_id] : std::log(prob[blank_id]); - min_cutoff = prefixes[num_prefixes - 1]->score + - blank_prob - std::max(0.0, ext_scorer->beta); - full_beam = (num_prefixes == beam_size); - } + float log_p_lm_score = log_prob_c + lm_score; - std::vector> log_prob_idx = - get_pruned_log_probs(prob, cutoff_prob, cutoff_top_n, log_input); - // loop over chars - for (size_t index = 0; index < log_prob_idx.size(); index++) { - auto c = log_prob_idx[index].first; - auto log_prob_c = log_prob_idx[index].second; - - for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) { - auto prefix = prefixes[i]; - if (full_beam && log_prob_c + prefix->score < min_cutoff) { - break; - } - // blank - if (c == blank_id) { - prefix->log_prob_b_cur = - log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score); - continue; - } - // repeated character - if (c == prefix->character) { - prefix->log_prob_nb_cur = log_sum_exp( - prefix->log_prob_nb_cur, log_prob_c + prefix->log_prob_nb_prev); - } - // get new prefix - auto prefix_new = prefix->get_path_trie(c, abs_time_step, log_prob_c); - - if (prefix_new != nullptr) { - float log_p = -NUM_FLT_INF; - - if (c == prefix->character && - prefix->log_prob_b_prev > -NUM_FLT_INF) { - log_p = log_prob_c + prefix->log_prob_b_prev; - } else if (c != prefix->character) { - log_p = log_prob_c + prefix->score; - } - - // language model scoring - if (ext_scorer != nullptr && - (c == space_id || ext_scorer->is_character_based())) { - PathTrie *prefix_to_score = nullptr; - // skip scoring the space - if (ext_scorer->is_character_based()) { - prefix_to_score = prefix_new; + float log_p = -NUM_FLT_INF; + float log_p_hw = -NUM_FLT_INF; + + bool is_complete_hotword + = (path->hotword_match_len > 0 && path->hotword_match_len == path->shortest_unigram_length); + + if (path->character == path->parent->character) { + + if (path->parent->log_prob_b_prev > -NUM_FLT_INF) { + log_p = log_p_lm_score + path->parent->log_prob_b_prev; + + if (reset_score) { + // when the current token is not part of hotword whereas prev token + // is, then consider the original score for the current token + // scoring + log_p_hw = log_p + path->hotword_score; } else { - prefix_to_score = prefix; - } + log_p_hw = log_p_lm_score + path->parent->log_prob_b_prev_hw + path->hotword_score; - float score = 0.0; - std::vector ngram; - ngram = ext_scorer->make_ngram(prefix_to_score); - score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; - log_p += score; - log_p += ext_scorer->beta; - } - prefix_new->log_prob_nb_cur = - log_sum_exp(prefix_new->log_prob_nb_cur, log_p); + if (is_complete_hotword) { + // original score needs to be updated with hotword score when + // complete hotword is formed. 
+ log_p = log_p_hw; + } + } } - } // end of loop over prefix - } // end of loop over vocabulary + } else if (path->character != path->parent->character) { + log_p = log_p_lm_score + path->parent->score; - prefixes.clear(); - // update log probs - root.iterate_to_vec(prefixes); + if (reset_score) { + log_p_hw = log_p + path->hotword_score; + } else { + log_p_hw = log_p_lm_score + path->parent->score_hw + path->hotword_score; - // only preserve top beam_size prefixes - if (prefixes.size() >= beam_size) { - std::nth_element(prefixes.begin(), - prefixes.begin() + beam_size, - prefixes.end(), - prefix_compare); - for (size_t i = beam_size; i < prefixes.size(); ++i) { - prefixes[i]->remove(); - } + if (is_complete_hotword) { + log_p = log_p_hw; + } + } + } + + path->log_prob_nb_cur = log_sum_exp(path->log_prob_nb_cur, log_p); + path->log_prob_nb_cur_hw = log_sum_exp(path->log_prob_nb_cur_hw, log_p_hw); +} - prefixes.resize(beam_size); +void DecoderState::next(const std::vector>& probs_seq) +{ + // dimension check + size_t num_time_steps = probs_seq.size(); + for (size_t i = 0; i < num_time_steps; ++i) { + VALID_CHECK_EQ(probs_seq[i].size(), + options->vocab.size(), + "The shape of probs_seq does not match with " + "the shape of the vocabulary"); } - } // end of loop over time + + // prefix search over time + for (size_t time_step = 0; time_step < num_time_steps; ++time_step, ++abs_time_step) { + auto& prob = probs_seq[time_step]; + + float min_cutoff = -NUM_FLT_INF; + bool full_beam = false; + if (ext_scorer != nullptr) { + size_t num_prefixes = std::min(prefixes.size(), options->beam_width); + std::sort(prefixes.begin(), prefixes.begin() + num_prefixes, prefix_compare); + float blank_prob = options->log_probs_input ? prob[options->blank_id] + : std::log(prob[options->blank_id]); + min_cutoff = prefixes[num_prefixes - 1]->score_hw + blank_prob + - std::max(0.0, ext_scorer->beta); + full_beam = (num_prefixes == options->beam_width); + } + + std::vector> log_prob_idx = get_pruned_log_probs( + prob, options->cutoff_prob, options->cutoff_top_n, options->log_probs_input); + + // loop over chars + for (size_t index = 0; index < log_prob_idx.size(); ++index) { + auto c = log_prob_idx[index].first; + auto log_prob_c = log_prob_idx[index].second; + + for (size_t i = 0; i < prefixes.size() && i < options->beam_width; ++i) { + + auto prefix = prefixes[i]; + + if (full_beam && log_prob_c + prefix->score_hw < min_cutoff) { + break; + } + // blank + if (c == options->blank_id) { + prefix->log_prob_b_cur + = log_sum_exp(prefix->log_prob_b_cur, log_prob_c + prefix->score); + prefix->log_prob_b_cur_hw + = log_sum_exp(prefix->log_prob_b_cur_hw, log_prob_c + prefix->score_hw); + continue; + } + + // repeated character + if (c == prefix->character) { + prefix->log_prob_nb_cur = log_sum_exp(prefix->log_prob_nb_cur, + log_prob_c + prefix->log_prob_nb_prev); + + prefix->log_prob_nb_cur_hw = log_sum_exp( + prefix->log_prob_nb_cur_hw, log_prob_c + prefix->log_prob_nb_prev_hw); + } + + // get new prefix + auto new_path = prefix->get_path_trie( + c, abs_time_step, log_prob_c, true, !options->is_bpe_based); + + if (new_path != nullptr) { + + float lm_score = 0.0; + bool is_hotpath = false; + bool reset_score = false; + + // check if the current node is a start of the word + if ((ext_scorer != nullptr || hotword_scorer != nullptr) + && is_start_of_word(new_path)) { + new_path->mark_as_word_start_char(); + } + + // check if the current node is part of a hotword + if (hotword_scorer != nullptr) { + 
new_path->copy_parent_hotword_params(); + is_hotpath = hotword_scorer->is_hotpath(new_path, space_id, apostrophe_id); + + if (!is_hotpath) { + new_path->reset_hotword_params(); + if (prefix->is_hotpath()) { + reset_score = true; + } + } + } + + // hotword scoring + if (is_hotpath) { + new_path->mark_as_hotpath(); + + // need to consider original score when previous word is a + // partial hotword + if (prefix->is_hotpath() && new_path->hotword_dictionary_state == 0) { + reset_score = true; + } + + // update hotword related params of new node and calculate hotword score + hotword_scorer->estimate_hw_score(new_path); + } + // unknown scoring + else { + // check if the current node forms OOV word and add unk score + if (options->is_bpe_based && ext_scorer != nullptr + && ext_scorer->has_lexicon()) { + bool is_oov = new_path->is_oov_token(); + if (is_oov) { + lm_score += options->unk_score; + } + } + } + + // language model scoring + if (ext_scorer != nullptr + && (c == space_id || ext_scorer->is_character_based() + || ext_scorer->is_bpe_based())) { + + PathTrie* prefix_to_score = nullptr; + // skip scoring the space + if (ext_scorer->is_character_based() || ext_scorer->is_bpe_based()) { + prefix_to_score = new_path; + } else { + prefix_to_score = prefix; + } + std::vector ngram; + ngram = ext_scorer->make_ngram(prefix_to_score); + lm_score += ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; + lm_score += ext_scorer->beta; + } + + // update original and hotword score for the new path + update_score(new_path, log_prob_c, lm_score, reset_score); + } + + } // end of loop over prefix + } // end of loop over vocabulary + + prefixes.clear(); + // update log probs + root.iterate_to_vec(prefixes); + + // only preserve top beam_size prefixes + if (prefixes.size() >= options->beam_width) { + std::nth_element(prefixes.begin(), + prefixes.begin() + options->beam_width, + prefixes.end(), + prefix_compare); + for (size_t i = options->beam_width; i < prefixes.size(); ++i) { + prefixes[i]->remove(); + } + + prefixes.resize(options->beam_width); + } + + } // end of loop over time } -std::vector> -DecoderState::decode() +std::vector> DecoderState::decode() { - std::vector prefixes_copy = prefixes; - std::unordered_map scores; - for (PathTrie* prefix : prefixes_copy) { - scores[prefix] = prefix->score; - } - - // score the last word of each prefix that doesn't end with space - if (ext_scorer != nullptr && !ext_scorer->is_character_based()) { - for (size_t i = 0; i < beam_size && i < prefixes_copy.size(); ++i) { - auto prefix = prefixes_copy[i]; - if (!prefix->is_empty() && prefix->character != space_id) { - float score = 0.0; - std::vector ngram = ext_scorer->make_ngram(prefix); - score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; - score += ext_scorer->beta; - scores[prefix] += score; - } + std::vector prefixes_copy = prefixes; + std::unordered_map scores; + for (PathTrie* prefix : prefixes_copy) { + scores[prefix] = prefix->score_hw; } - } - - using namespace std::placeholders; - size_t num_prefixes = std::min(prefixes_copy.size(), beam_size); - std::sort(prefixes_copy.begin(), prefixes_copy.begin() + num_prefixes, - std::bind(prefix_compare_external_scores, _1, _2, scores)); - - // compute aproximate ctc score as the return score, without affecting the - // return order of decoding result. To delete when decoder gets stable. 
- for (size_t i = 0; i < beam_size && i < prefixes_copy.size(); ++i) { - double approx_ctc = scores[prefixes_copy[i]]; - if (ext_scorer != nullptr) { - std::vector output; - std::vector timesteps; - prefixes_copy[i]->get_path_vec(output, timesteps); - auto prefix_length = output.size(); - auto words = ext_scorer->split_labels(output); - // remove word insert - approx_ctc = approx_ctc - prefix_length * ext_scorer->beta; - // remove language model weight: - approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha; + + // score the last word of each prefix that doesn't end with space + if (ext_scorer != nullptr + && !(ext_scorer->is_character_based() || ext_scorer->is_bpe_based())) { + for (size_t i = 0; i < options->beam_width && i < prefixes_copy.size(); ++i) { + auto prefix = prefixes_copy[i]; + if (!prefix->is_empty() && prefix->character != space_id) { + float score = 0.0; + std::vector ngram = ext_scorer->make_ngram(prefix); + score = ext_scorer->get_log_cond_prob(ngram) * ext_scorer->alpha; + score += ext_scorer->beta; + scores[prefix] += score; + } + } + } + + using namespace std::placeholders; + size_t num_prefixes = std::min(prefixes_copy.size(), options->beam_width); + std::sort(prefixes_copy.begin(), + prefixes_copy.begin() + num_prefixes, + std::bind(prefix_compare_external_scores, _1, _2, scores)); + + // compute approximate ctc score as the return score, without affecting the + // return order of decoding result. To delete when decoder gets stable. + for (size_t i = 0; i < options->beam_width && i < prefixes_copy.size(); ++i) { + double approx_ctc = scores[prefixes_copy[i]]; + if (ext_scorer != nullptr + && !(ext_scorer->is_character_based() || ext_scorer->is_bpe_based())) { + std::vector output; + std::vector timesteps; + prefixes_copy[i]->get_path_vec(output, timesteps); + auto prefix_length = output.size(); + auto words = ext_scorer->split_labels(output); + // remove word insert + approx_ctc = approx_ctc - prefix_length * ext_scorer->beta; + // remove language model weight: + approx_ctc -= (ext_scorer->get_sent_log_prob(words)) * ext_scorer->alpha; + } + prefixes_copy[i]->approx_ctc = approx_ctc; } - prefixes_copy[i]->approx_ctc = approx_ctc; - } return get_beam_search_result(prefixes_copy, options->beam_width); } -std::vector> ctc_beam_search_decoder( - const std::vector> &probs_seq, - const std::vector &vocabulary, - size_t beam_size, - double cutoff_prob, - size_t cutoff_top_n, - size_t blank_id, - int log_input, - Scorer *ext_scorer) +std::vector> +ctc_beam_search_decoder(const std::vector>& probs_seq, + DecoderOptions* options, + Scorer* ext_scorer, + HotwordScorer* hotword_scorer) { - DecoderState state(vocabulary, beam_size, cutoff_prob, cutoff_top_n, blank_id, - log_input, ext_scorer); - state.next(probs_seq); - return state.decode(); + DecoderState state(options, ext_scorer, hotword_scorer); + state.next(probs_seq); + return state.decode(); } - -std::vector> ctc_beam_search_decoder_with_given_state( - const std::vector> &probs_seq, - DecoderState *state, - bool is_eos) - { - state->next(probs_seq); - if (is_eos) { +std::vector> +ctc_beam_search_decoder_with_given_state(const std::vector>& probs_seq, - DecoderState* state, - bool is_eos) +{ + state->next(probs_seq); + if (is_eos) { return state->decode(); - } - else { + } else { return {}; - } - } +} std::vector>> -ctc_beam_search_decoder_batch( - const std::vector>> &probs_split, - const std::vector &vocabulary, - size_t beam_size, - size_t 
num_processes, - double cutoff_prob, - size_t cutoff_top_n, - size_t blank_id, - int log_input, - Scorer *ext_scorer) +ctc_beam_search_decoder_batch(const std::vector>>& probs_split, + DecoderOptions* options, + Scorer* ext_scorer, + HotwordScorer* hotword_scorer) { - VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); - // thread pool - ThreadPool pool(num_processes); - // number of samples - size_t batch_size = probs_split.size(); - - // enqueue the tasks of decoding - std::vector>>> res; - for (size_t i = 0; i < batch_size; ++i) { - res.emplace_back(pool.enqueue(ctc_beam_search_decoder, - std::cref(probs_split[i]), - std::cref(vocabulary), - beam_size, - cutoff_prob, - cutoff_top_n, - blank_id, - log_input, - ext_scorer)); - } - - - - // get decoding results - std::vector>> batch_results; - for (size_t i = 0; i < batch_size; ++i) { - batch_results.emplace_back(res[i].get()); - } - return batch_results; -} + VALID_CHECK_GT(options->num_processes, 0, "num_processes must be positive!"); + // thread pool + ThreadPool pool(options->num_processes); + // number of samples + size_t batch_size = probs_split.size(); + + // enqueue the tasks of decoding + std::vector>>> res; + for (size_t i = 0; i < batch_size; ++i) { + res.emplace_back(pool.enqueue(ctc_beam_search_decoder, + std::cref(probs_split[i]), + options, + ext_scorer, + hotword_scorer)); + } + // get decoding results + std::vector>> batch_results; + for (size_t i = 0; i < batch_size; ++i) { + batch_results.emplace_back(res[i].get()); + } + return batch_results; +} -std::vector>> ctc_beam_search_decoder_batch_with_states -(const std::vector>> &probs_split, +std::vector>> ctc_beam_search_decoder_batch_with_states( + const std::vector>>& probs_split, size_t num_processes, - std::vector &states, - const std::vector &is_eos_s) + std::vector& states, + const std::vector& is_eos_s) { - VALID_CHECK_GT(num_processes, 0, "num_processes must be nonnegative!"); - // thread pool - ThreadPool pool(num_processes); - // number of samples - size_t batch_size = probs_split.size(); - - - // enqueue the tasks of decoding - std::vector>>> res; - for (size_t i = 0; i < batch_size; ++i) { - res.emplace_back(pool.enqueue(ctc_beam_search_decoder_with_given_state, - std::cref(probs_split[i]), - static_cast(states[i]), - is_eos_s[i])); - } - - // get decoding results - std::vector>> batch_results; - for (size_t i = 0; i < batch_size; ++i) { - batch_results.emplace_back(res[i].get()); - } - return batch_results; + VALID_CHECK_GT(num_processes, 0, "num_processes must be positive!"); + // thread pool + ThreadPool pool(num_processes); + // number of samples + size_t batch_size = probs_split.size(); + + // enqueue the tasks of decoding + std::vector>>> res; + for (size_t i = 0; i < batch_size; ++i) { + res.emplace_back(pool.enqueue(ctc_beam_search_decoder_with_given_state, + std::cref(probs_split[i]), + static_cast(states[i]), + is_eos_s[i])); + } + + // get decoding results + std::vector>> batch_results; + for (size_t i = 0; i < batch_size; ++i) { + batch_results.emplace_back(res[i].get()); + } + return batch_results; } diff --git a/ctcdecode/src/ctc_beam_search_decoder.h b/ctcdecode/src/ctc_beam_search_decoder.h index 09ff2eec..238eb188 100644 --- a/ctcdecode/src/ctc_beam_search_decoder.h +++ b/ctcdecode/src/ctc_beam_search_decoder.h @@ -1,134 +1,110 @@ #ifndef CTC_BEAM_SEARCH_DECODER_H_ #define CTC_BEAM_SEARCH_DECODER_H_ -#include #include #include -#include "scorer.h" +#include "decoder_options.h" +#include "hotword_scorer.h" +#include 
"output.h" +#include "scorer.h" /* CTC Beam Search Decoder * Parameters: * probs_seq: 2-D vector that each element is a vector of probabilities * over vocabulary of one time step. - * vocabulary: A vector of vocabulary. - * beam_size: The width of beam search. - * cutoff_prob: Cutoff probability for pruning. - * cutoff_top_n: Cutoff number for pruning. + * DecoderOptions: Contains vocabulary, beam width, cutoff_top_n, cutoff_prob, etc + * that are required for beam decoding * ext_scorer: External scorer to evaluate a prefix, which consists of * n-gram language model scoring and word insertion term. * Default null, decoding the input sample without scorer. + * hotword_scorer: External hotword scorer to boost the score for specific + * words. Default null, decoding the input sample without hotword scorer * Return: * A vector that each element is a pair of score and decoding result, * in desending order. */ -std::vector> ctc_beam_search_decoder( - const std::vector> &probs_seq, - const std::vector &vocabulary, - size_t beam_size, - double cutoff_prob = 1.0, - size_t cutoff_top_n = 40, - size_t blank_id = 0, - int log_input = 0, - Scorer *ext_scorer = nullptr); - - +std::vector> +ctc_beam_search_decoder(const std::vector>& probs_seq, + DecoderOptions* options, + Scorer* ext_scorer = nullptr, + HotwordScorer* hotword_scorer = nullptr); /* CTC Beam Search Decoder for batch data * Parameters: * probs_seq: 3-D vector that each element is a 2-D vector that can be used * by ctc_beam_search_decoder(). - * vocabulary: A vector of vocabulary. - * beam_size: The width of beam search. - * num_processes: Number of threads for beam search. - * cutoff_prob: Cutoff probability for pruning. - * cutoff_top_n: Cutoff number for pruning. + * DecoderOptions: Contains vocabulary, beam width, cutoff_top_n, cutoff_prob, etc + * that are required for beam decoding * ext_scorer: External scorer to evaluate a prefix, which consists of * n-gram language model scoring and word insertion term. * Default null, decoding the input sample without scorer. + * hotword_scorer: External hotword scorer to boost the score for specific + * words. Default null, decoding the input sample without hotword scorer * Return: * A 2-D vector that each element is a vector of beam search decoding * result for one audio sample. */ std::vector>> -ctc_beam_search_decoder_batch( - const std::vector>> &probs_split, - const std::vector &vocabulary, - size_t beam_size, - size_t num_processes, - double cutoff_prob = 1.0, - size_t cutoff_top_n = 40, - size_t blank_id = 0, - int log_input = 0, - Scorer *ext_scorer = nullptr); - - - - - -class DecoderState -{ - int abs_time_step; - int space_id; - size_t beam_size; - double cutoff_prob; - size_t cutoff_top_n; - size_t blank_id; - int log_input; - std::vector vocabulary; - Scorer *ext_scorer; - - std::vector prefixes; - PathTrie root; +ctc_beam_search_decoder_batch(const std::vector>>& probs_split, + DecoderOptions* options, + Scorer* ext_scorer = nullptr, + HotwordScorer* hotword_scorer = nullptr); + +class DecoderState { + int abs_time_step; + int space_id; + int apostrophe_id; + DecoderOptions* options; + Scorer* ext_scorer; + HotwordScorer* hotword_scorer; + + std::vector prefixes; + PathTrie root; public: - /* Initialize CTC beam search decoder for streaming - * - * Parameters: - * vocabulary: A vector of vocabulary. - * beam_size: The width of beam search. - * cutoff_prob: Cutoff probability for pruning. - * cutoff_top_n: Cutoff number for pruning. 
- * ext_scorer: External scorer to evaluate a prefix, which consists of - * n-gram language model scoring and word insertion term. - * Default null, decoding the input sample without scorer. - */ - DecoderState(const std::vector &vocabulary, - size_t beam_size, - double cutoff_prob, - size_t cutoff_top_n, - size_t blank_id, - int log_input, - Scorer *ext_scorer); - ~DecoderState() = default; - - /* Process logits in decoder stream - * - * Parameters: - * probs: 2-D vector where each element is a vector of probabilities - * over alphabet of one time step. - */ - void next(const std::vector> &probs_seq); - - /* Get current transcription from the decoder stream state - * - * Return: - * A vector where each element is a pair of score and decoding result, - * in descending order. - */ - std::vector> decode(); + /* Initialize CTC beam search decoder for streaming + * + * Parameters: + * DecoderOptions: Contains vocabulary, beam width, cutoff_top_n, cutoff_prob, etc + * that are required for beam decoding + * ext_scorer: External scorer to evaluate a prefix, which consists of + * n-gram language model scoring and word insertion term. + * Default null, decoding the input sample without scorer. + * hotword_scorer: External hotword scorer to boost the score for specific + * words. Default null, decoding the input sample without hotword scorer + */ + DecoderState(DecoderOptions* options, Scorer* ext_scorer, HotwordScorer* hotword_scorer); + ~DecoderState() = default; + + /* Process logits in decoder stream + * + * Parameters: + * probs: 2-D vector where each element is a vector of probabilities + * over alphabet of one time step. + */ + void next(const std::vector>& probs_seq); + + bool is_start_of_word(PathTrie* path); + + void update_score(PathTrie* path, float log_prob_c, float lm_score, bool reset_score); + + /* Get current transcription from the decoder stream state + * + * Return: + * A vector where each element is a pair of score and decoding result, + * in descending order. + */ + std::vector> decode(); }; - -std::vector>> -ctc_beam_search_decoder_batch_with_states( - const std::vector>> &probs_split, +std::vector>> ctc_beam_search_decoder_batch_with_states( + const std::vector>>& probs_split, size_t num_processes, - std::vector &states, - const std::vector &is_eos_s); + std::vector& states, + const std::vector& is_eos_s); -#endif // CTC_BEAM_SEARCH_DECODER_H_ +#endif // CTC_BEAM_SEARCH_DECODER_H_ diff --git a/ctcdecode/src/decoder_options.h b/ctcdecode/src/decoder_options.h new file mode 100644 index 00000000..3b85a656 --- /dev/null +++ b/ctcdecode/src/decoder_options.h @@ -0,0 +1,73 @@ +#ifndef DECODER_OPTIONS_H +#define DECODER_OPTIONS_H + +#include + +class DecoderOptions { +public: + /* Initialize DecoderOptions for CTC beam decoding + * + * Parameters: + * vocab: A vector of vocabulary (labels). + * cutoff_top_n: Cutoff number in pruning. Only the top cutoff_top_n characters + with the highest probability in the vocab will be used in beam search. + * cutoff_prob: Cutoff probability in pruning. 1.0 means no pruning. + * beam_width: This controls how broad the beam search is. Higher values are more + likely to find top beams, but they also will make your beam search exponentially + slower. + * num_processes: Parallelize the batch using num_processes workers. + * blank_id: Index of the CTC blank token used when training your + model. + * log_probs_input (bool): False if the model has passed through a softmax and output + probabilities sum to 1. 
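// [Editor's note -- illustrative sketch, not part of the patch] The streaming
// API declared above is meant to be driven chunk by chunk: feed each block of
// frames with next() and call decode() once the utterance ends. `options` is
// the DecoderOptions object from the earlier sketch and `audio_chunks` is
// hypothetical.
//
//   DecoderState state(&options, /*ext_scorer=*/nullptr, /*hotword_scorer=*/nullptr);
//   for (const auto& chunk : audio_chunks)   // each chunk: T x vocab probabilities
//       state.next(chunk);
//   auto beams = state.decode();             // (score, Output) pairs, best first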
+ * is_bpe_based (bool): True if the labels contain BPE tokens, else False + * unk_score (float): Extra score added when an unknown word is formed (default = -5) + * token_separator (char): Prefix character of the BPE tokens (default = '#') + */ + DecoderOptions(std::vector<std::string> vocab, + size_t cutoff_top_n, + double cutoff_prob, + size_t beam_width, + size_t num_processes, + size_t blank_id, + bool log_probs_input, + bool is_bpe_based, + float unk_score, + char token_separator) + : vocab(vocab) + , cutoff_top_n(cutoff_top_n) + , cutoff_prob(cutoff_prob) + , beam_width(beam_width) + , num_processes(num_processes) + , blank_id(blank_id) + , log_probs_input(log_probs_input) + , is_bpe_based(is_bpe_based) + , unk_score(unk_score) + , token_separator(token_separator) + { + } + + /* Initialize DecoderOptions with the vocabulary alone + * + * Parameters: + * vocab: A vector of vocabulary (labels). + */ + DecoderOptions(std::vector<std::string> vocab) + : vocab(vocab) + { + } + ~DecoderOptions() = default; + + std::vector<std::string> vocab; + size_t beam_width = 100; + size_t cutoff_top_n = 40; + double cutoff_prob = 1.0; + size_t num_processes = 4; + size_t blank_id = 0; + bool log_probs_input = false; + bool is_bpe_based = false; + float unk_score = -5; + char token_separator = '#'; +}; + +#endif // DECODER_OPTIONS_H diff --git a/ctcdecode/src/decoder_utils.cpp b/ctcdecode/src/decoder_utils.cpp index 2a35fe48..dbf388a0 100644 --- a/ctcdecode/src/decoder_utils.cpp +++ b/ctcdecode/src/decoder_utils.cpp @@ -1,193 +1,242 @@ #include "decoder_utils.h" #include +#include #include #include -#include using namespace std; - -std::vector<std::pair<size_t, float>> get_pruned_log_probs( - const std::vector<double> &prob_step, - double cutoff_prob, - size_t cutoff_top_n, - int log_input) { - std::vector<std::pair<int, double>> prob_idx; - double log_cutoff_prob = log(cutoff_prob); - for (size_t i = 0; i < prob_step.size(); ++i) { - prob_idx.push_back(std::pair<int, double>(i, prob_step[i])); - } - // pruning of vacobulary - size_t cutoff_len = prob_step.size(); - if (log_cutoff_prob < 0.0 || cutoff_top_n < cutoff_len) { - std::sort( - prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>); - if (log_cutoff_prob < 0.0) { - double cum_prob = 0.0; - cutoff_len = 0; - for (size_t i = 0; i < prob_idx.size(); ++i) { - cum_prob = log_sum_exp(cum_prob, log_input ? prob_idx[i].second : log(prob_idx[i].second) ); - cutoff_len += 1; - if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n) break; - } - }else{ - cutoff_len = cutoff_top_n; - } - prob_idx = std::vector<std::pair<int, double>>( - prob_idx.begin(), prob_idx.begin() + cutoff_len); - } - std::vector<std::pair<size_t, float>> log_prob_idx; - for (size_t i = 0; i < cutoff_len; ++i) { - log_prob_idx.push_back(std::pair<int, float>( - prob_idx[i].first, log_input ? prob_idx[i].second : log(prob_idx[i].second + NUM_FLT_MIN))); - } - return log_prob_idx; +std::vector<std::pair<size_t, float>> get_pruned_log_probs(const std::vector<double>& prob_step, + double cutoff_prob, + size_t cutoff_top_n, + int log_input) +{ + std::vector<std::pair<int, double>> prob_idx; + double log_cutoff_prob = log(cutoff_prob); + for (size_t i = 0; i < prob_step.size(); ++i) { + prob_idx.push_back(std::pair<int, double>(i, prob_step[i])); + } + // pruning of vocabulary + size_t cutoff_len = prob_step.size(); + if (log_cutoff_prob < 0.0 || cutoff_top_n < cutoff_len) { + std::sort(prob_idx.begin(), prob_idx.end(), pair_comp_second_rev<int, double>); + if (log_cutoff_prob < 0.0) { + double cum_prob = 0.0; + cutoff_len = 0; + for (size_t i = 0; i < prob_idx.size(); ++i) { + cum_prob = log_sum_exp(cum_prob, + log_input ? 
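// [Editor's note -- illustrative values, not part of the patch] A fully
// specified construction of the options class above for a BPE vocabulary,
// mirroring the parameter documentation:
//
//   DecoderOptions bpe_options(vocab,
//                              /*cutoff_top_n=*/40,
//                              /*cutoff_prob=*/1.0,   // 1.0 disables probability pruning
//                              /*beam_width=*/25,
//                              /*num_processes=*/4,
//                              /*blank_id=*/0,
//                              /*log_probs_input=*/false,
//                              /*is_bpe_based=*/true,
//                              /*unk_score=*/-5.0f,
//                              /*token_separator=*/'#');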
prob_idx[i].second : log(prob_idx[i].second)); + cutoff_len += 1; + if (cum_prob >= cutoff_prob || cutoff_len >= cutoff_top_n) + break; + } + } else { + cutoff_len = cutoff_top_n; + } + prob_idx + = std::vector>(prob_idx.begin(), prob_idx.begin() + cutoff_len); + } + std::vector> log_prob_idx; + for (size_t i = 0; i < cutoff_len; ++i) { + log_prob_idx.push_back(std::pair( + prob_idx[i].first, + log_input ? prob_idx[i].second : log(prob_idx[i].second + NUM_FLT_MIN))); + } + return log_prob_idx; } +std::vector> +get_beam_search_result(const std::vector& prefixes, size_t beam_size) +{ + // allow for the post processing + std::vector space_prefixes; + if (space_prefixes.empty()) { + for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { + space_prefixes.push_back(prefixes[i]); + } + } -std::vector> get_beam_search_result( - const std::vector &prefixes, - size_t beam_size) { - // allow for the post processing - std::vector space_prefixes; - if (space_prefixes.empty()) { - for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) { - space_prefixes.push_back(prefixes[i]); - } - } - - std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare); - std::vector> output_vecs; - for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) { - std::vector output; - std::vector timesteps; - space_prefixes[i]->get_path_vec(output, timesteps); - Output outputs; - outputs.tokens = output; - outputs.timesteps = timesteps; - std::pair output_pair(-space_prefixes[i]->approx_ctc, - outputs); - output_vecs.emplace_back(output_pair); - } - return output_vecs; + std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare); + std::vector> output_vecs; + for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) { + std::vector output; + std::vector timesteps; + space_prefixes[i]->get_path_vec(output, timesteps); + Output outputs; + outputs.tokens = output; + outputs.timesteps = timesteps; + std::pair output_pair(-space_prefixes[i]->approx_ctc, outputs); + output_vecs.emplace_back(output_pair); + } + return output_vecs; } -size_t get_utf8_str_len(const std::string &str) { - size_t str_len = 0; - for (char c : str) { - str_len += ((c & 0xc0) != 0x80); - } - return str_len; +size_t get_utf8_str_len(const std::string& str) +{ + size_t str_len = 0; + for (char c : str) { + str_len += ((c & 0xc0) != 0x80); + } + return str_len; } -std::vector split_utf8_str(const std::string &str) { - std::vector result; - std::string out_str; - - for (char c : str) { - if ((c & 0xc0) != 0x80) // new UTF-8 character - { - if (!out_str.empty()) { - result.push_back(out_str); - out_str.clear(); - } +std::vector split_utf8_str(const std::string& str) +{ + std::vector result; + std::string out_str; + + for (char c : str) { + if ((c & 0xc0) != 0x80) // new UTF-8 character + { + if (!out_str.empty()) { + result.push_back(out_str); + out_str.clear(); + } + } + + out_str.append(1, c); } - - out_str.append(1, c); - } - result.push_back(out_str); - return result; + result.push_back(out_str); + return result; } -std::vector split_str(const std::string &s, - const std::string &delim) { - std::vector result; - std::size_t start = 0, delim_len = delim.size(); - while (true) { - std::size_t end = s.find(delim, start); - if (end == std::string::npos) { - if (start < s.size()) { - result.push_back(s.substr(start)); - } - break; - } - if (end > start) { - result.push_back(s.substr(start, end - start)); - } - start = end + delim_len; - } - return result; +std::vector split_str(const std::string& s, const 
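// [Editor's note -- worked example with made-up numbers] For one frame with
// probabilities {0.70, 0.15, 0.10, 0.05}, cutoff_prob = 0.9 and
// cutoff_top_n = 3, the sorted cumulative mass is 0.70, 0.85, 0.95; the scan
// stops once cum_prob >= cutoff_prob, so get_pruned_log_probs() above keeps
// the top three entries and returns them as (index, log prob) pairs:
//
//   std::vector<double> frame { 0.70, 0.15, 0.10, 0.05 };
//   auto kept = get_pruned_log_probs(frame, 0.9, 3, /*log_input=*/0);
//   // kept.size() == 3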
std::string& delim) +{ + std::vector result; + std::size_t start = 0, delim_len = delim.size(); + while (true) { + std::size_t end = s.find(delim, start); + if (end == std::string::npos) { + if (start < s.size()) { + result.push_back(s.substr(start)); + } + break; + } + if (end > start) { + result.push_back(s.substr(start, end - start)); + } + start = end + delim_len; + } + return result; } -bool prefix_compare(const PathTrie *x, const PathTrie *y) { - if (x->score == y->score) { - if (x->character == y->character) { - return false; +bool prefix_compare(const PathTrie* x, const PathTrie* y) +{ + if (x->score_hw == y->score_hw) { + if (x->character == y->character) { + return false; + } else { + return (x->character < y->character); + } } else { - return (x->character < y->character); + return x->score_hw > y->score_hw; } - } else { - return x->score > y->score; - } } -bool prefix_compare_external_scores(const PathTrie *x, const PathTrie *y, - const std::unordered_map& scores) { - if (scores.at(x) == scores.at(y)) { - if (x->character == y->character) { - return false; +bool prefix_compare_external_scores(const PathTrie* x, + const PathTrie* y, + const std::unordered_map& scores) +{ + if (scores.at(x) == scores.at(y)) { + if (x->character == y->character) { + return false; + } else { + return (x->character < y->character); + } } else { - return (x->character < y->character); + return scores.at(x) > scores.at(y); } - } else { - return scores.at(x) > scores.at(y); - } } -void add_word_to_fst(const std::vector &word, - fst::StdVectorFst *dictionary) { - if (dictionary->NumStates() == 0) { - fst::StdVectorFst::StateId start = dictionary->AddState(); - assert(start == 0); - dictionary->SetStart(start); - } - fst::StdVectorFst::StateId src = dictionary->Start(); - fst::StdVectorFst::StateId dst; - for (auto c : word) { - dst = dictionary->AddState(); - dictionary->AddArc(src, fst::StdArc(c, c, 0, dst)); - src = dst; - } - dictionary->SetFinal(dst, fst::StdArc::Weight::One()); +void add_word_to_fst(const std::vector& word, fst::StdVectorFst* lexicon) +{ + if (lexicon->NumStates() == 0) { + fst::StdVectorFst::StateId start = lexicon->AddState(); + assert(start == 0); + lexicon->SetStart(start); + } + fst::StdVectorFst::StateId src = lexicon->Start(); + fst::StdVectorFst::StateId dst; + for (auto c : word) { + dst = lexicon->AddState(); + lexicon->AddArc(src, fst::StdArc(c, c, 0, dst)); + src = dst; + } + lexicon->SetFinal(dst, fst::StdArc::Weight::One()); } -bool add_word_to_dictionary( - const std::string &word, - const std::unordered_map &char_map, - bool add_space, - int SPACE_ID, - fst::StdVectorFst *dictionary) { - auto characters = split_utf8_str(word); - - std::vector int_word; +bool add_word_to_lexicon(const std::vector& characters, + const std::unordered_map& char_map, + bool add_space, + int SPACE_ID, + fst::StdVectorFst* lexicon) +{ + + std::vector int_word; + + for (auto& c : characters) { + if (c == " ") { + int_word.push_back(SPACE_ID); + } else { + auto int_c = char_map.find(c); + if (int_c != char_map.end()) { + int_word.push_back(int_c->second); + } else { + return false; // return without + // adding + } + } + } - for (auto &c : characters) { - if (c == " ") { - int_word.push_back(SPACE_ID); - } else { - auto int_c = char_map.find(c); - if (int_c != char_map.end()) { - int_word.push_back(int_c->second); - } else { - return false; // return without adding - } + if (add_space) { + int_word.push_back(SPACE_ID); } - } - if (add_space) { - int_word.push_back(SPACE_ID); - } + 
add_word_to_fst(int_word, lexicon); + return true; // return with successful adding +} - add_word_to_fst(int_word, dictionary); - return true; // return with successful adding +void set_char_map(const std::vector& char_list, + std::unordered_map& char_map, + int& space_id) +{ + char_map.clear(); + int i = 0; + for (auto it = char_list.begin(); it != char_list.end(); ++it) { + if (*it == " ") { + space_id = i; + } + // The initial state of FST is state 0, hence the index of chars in + // the FST should start from 1 to avoid the conflict with the initial + // state, otherwise wrong decoding results would be given. + char_map[*it] = i + 1; + ++i; + } } + +/** + * @brief This methods returns true when the current node bpe token can be mergeable + * with the parent node token. + * + * @param cur_token, current token string + * @param cur_char, id of current token + * @param parent_char, id of parent token + * @param apostrophe_id, id of apostrophe + * @param token_separator, bpe token separator character, Ex: '#' + * @return true, if the current node token merges with the parent token + * @return false, if the current node token does not merge with the parent token + */ +bool is_mergeable_bpe_token(std::string cur_token, + int cur_char, + int parent_char, + int apostrophe_id, + char token_separator) +{ + + bool is_token_seperator = false; + if (cur_token.size() > 0) { + is_token_seperator = (cur_token.at(0) == token_separator); + } + + return is_token_seperator || parent_char == apostrophe_id || cur_char == apostrophe_id; +} \ No newline at end of file diff --git a/ctcdecode/src/decoder_utils.h b/ctcdecode/src/decoder_utils.h index 2f3eea62..1f80e7ef 100644 --- a/ctcdecode/src/decoder_utils.h +++ b/ctcdecode/src/decoder_utils.h @@ -6,97 +6,106 @@ #include #include "fst/log.h" -#include "path_trie.h" #include "output.h" +#include "path_trie.h" -const float NUM_FLT_INF = std::numeric_limits::max(); -const float NUM_FLT_MIN = std::numeric_limits::min(); +const float NUM_FLT_INF = std::numeric_limits::max(); +const float NUM_FLT_MIN = std::numeric_limits::min(); +const int NUM_INT_INF = std::numeric_limits::max(); const float NUM_FLT_LOGE = 0.4342944819; // inline function for validation check -inline void check( - bool x, const char *expr, const char *file, int line, const char *err) { - if (!x) { - std::cout << "[" << file << ":" << line << "] "; - LOG(FATAL) << "\"" << expr << "\" check failed. " << err; - } +inline void check(bool x, const char* expr, const char* file, int line, const char* err) +{ + if (!x) { + std::cout << "[" << file << ":" << line << "] "; + LOG(FATAL) << "\"" << expr << "\" check failed. 
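// [Editor's note -- illustrative ids, not part of the patch] With the '#'
// separator, a continuation token such as "#ing" merges with its parent, as
// does any transition into or out of the apostrophe token; a plain token like
// "play" would start a new word, so is_mergeable_bpe_token() returns false
// for it:
//
//   bool merges = is_mergeable_bpe_token("#ing", /*cur_char=*/12,
//                                        /*parent_char=*/7,
//                                        /*apostrophe_id=*/3,
//                                        /*token_separator=*/'#');  // true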
" << err; + } } -#define VALID_CHECK(x, info) \ - check(static_cast(x), #x, __FILE__, __LINE__, info) +#define VALID_CHECK(x, info) check(static_cast(x), #x, __FILE__, __LINE__, info) #define VALID_CHECK_EQ(x, y, info) VALID_CHECK((x) == (y), info) #define VALID_CHECK_GT(x, y, info) VALID_CHECK((x) > (y), info) #define VALID_CHECK_LT(x, y, info) VALID_CHECK((x) < (y), info) - // Function template for comparing two pairs template -bool pair_comp_first_rev(const std::pair &a, - const std::pair &b) { - return a.first > b.first; +bool pair_comp_first_rev(const std::pair& a, const std::pair& b) +{ + return a.first > b.first; } // Function template for comparing two pairs template -bool pair_comp_second_rev(const std::pair &a, - const std::pair &b) { - return a.second > b.second; +bool pair_comp_second_rev(const std::pair& a, const std::pair& b) +{ + return a.second > b.second; } // Return the sum of two probabilities in log scale template -T log_sum_exp(const T &x, const T &y) { - static T num_min = -std::numeric_limits::max(); - if (x <= num_min) return y; - if (y <= num_min) return x; - T xmax = std::max(x, y); - return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax; +T log_sum_exp(const T& x, const T& y) +{ + static T num_min = -std::numeric_limits::max(); + if (x <= num_min) + return y; + if (y <= num_min) + return x; + T xmax = std::max(x, y); + return std::log(std::exp(x - xmax) + std::exp(y - xmax)) + xmax; } // Get pruned probability vector for each time step's beam search -std::vector> get_pruned_log_probs( - const std::vector &prob_step, - double cutoff_prob, - size_t cutoff_top_n, - int log_input); +std::vector> get_pruned_log_probs(const std::vector& prob_step, + double cutoff_prob, + size_t cutoff_top_n, + int log_input); // Get beam search result from prefixes in trie tree -std::vector> get_beam_search_result( - const std::vector &prefixes, - size_t beam_size); +std::vector> +get_beam_search_result(const std::vector& prefixes, size_t beam_size); // Functor for prefix comparison -bool prefix_compare(const PathTrie *x, const PathTrie *y); +bool prefix_compare(const PathTrie* x, const PathTrie* y); -bool prefix_compare_external_scores(const PathTrie *x, const PathTrie *y, +bool prefix_compare_external_scores(const PathTrie* x, + const PathTrie* y, const std::unordered_map& scores); /* Get length of utf8 encoding string * See: http://stackoverflow.com/a/4063229 */ -size_t get_utf8_str_len(const std::string &str); +size_t get_utf8_str_len(const std::string& str); /* Split a string into a list of strings on a given string * delimiter. NB: delimiters on beginning / end of string are * trimmed. Eg, "FooBarFoo" split on "Foo" returns ["Bar"]. 
*/ -std::vector<std::string> split_str(const std::string &s, - const std::string &delim); +std::vector<std::string> split_str(const std::string& s, const std::string& delim); /* Splits string into vector of strings representing * UTF-8 characters (not same as chars) */ -std::vector<std::string> split_utf8_str(const std::string &str); - -// Add a word in index to the dicionary of fst -void add_word_to_fst(const std::vector<int> &word, - fst::StdVectorFst *dictionary); - -// Add a word in string to dictionary -bool add_word_to_dictionary( - const std::string &word, - const std::unordered_map<std::string, int> &char_map, - bool add_space, - int SPACE_ID, - fst::StdVectorFst *dictionary); -#endif // DECODER_UTILS_H +std::vector<std::string> split_utf8_str(const std::string& str); + +// Add a word in index to the lexicon fst +void add_word_to_fst(const std::vector<int>& word, fst::StdVectorFst* lexicon); + +// Add a word in string to lexicon +bool add_word_to_lexicon(const std::vector<std::string>& characters, + const std::unordered_map<std::string, int>& char_map, + bool add_space, + int SPACE_ID, + fst::StdVectorFst* lexicon); + +void set_char_map(const std::vector<std::string>& char_list, + std::unordered_map<std::string, int>& char_map, + int& space_id); + +bool is_mergeable_bpe_token(std::string cur_token, + int cur_char, + int parent_char, + int apostrophe_id, + char token_separator); + +#endif // DECODER_UTILS_H diff --git a/ctcdecode/src/hotword_scorer.cpp b/ctcdecode/src/hotword_scorer.cpp new file mode 100644 index 00000000..dc9430d5 --- /dev/null +++ b/ctcdecode/src/hotword_scorer.cpp @@ -0,0 +1,237 @@ +#include "hotword_scorer.h" +#include "decoder_utils.h" + +/** + * @brief Initializes the vocabulary list, hotwords and hotword weights, and creates the hotword FST + */ +HotwordScorer::HotwordScorer(const std::vector<std::string>& vocab_list, + const std::vector<std::vector<std::string>>& hotwords, + const std::vector<float>& hotword_weights, + char token_separator, + bool is_bpe_based) + : vocabulary(vocab_list) +{ + + this->hotword_weights = hotword_weights; + this->hotwords = hotwords; + + dictionary = nullptr; + + dict_size_ = 0; + + SPACE_ID_ = -1; + is_bpe_based_ = is_bpe_based; + FSTZERO = fst::TropicalWeight::Zero(); + delimiter_ = "$$"; + this->token_separator = token_separator; + setup(vocab_list); +} + +/** + * @brief Deletes the hotword FST of the object + * + */ +HotwordScorer::~HotwordScorer() +{ + if (dictionary != nullptr) { + delete dictionary; + } +} + +/** + * @brief Maps each character to its index and creates the hotword dictionary + * + * @param vocab_list, list of labels provided + */ +void HotwordScorer::setup(const std::vector<std::string>& vocab_list) +{ + set_char_map(vocab_list, char_map_, SPACE_ID_); + fill_hotword_dictionary(); +} + +/** + * @brief Adds a single word to the hotword dictionary (FST). This method also maps + * the hotword and its corresponding weight in the `hotword_weight_map_` hash map + * + * @param characters, list of characters/tokens of the word + * @param dictionary, FST dictionary to which the characters need to be added + * @param weight, hotword weight value for the given hotword + * @return true, when added successfully. 
+ * @return false, when not added successfully + */ +bool HotwordScorer::add_word_to_hotword_dictionary(const std::vector& characters, + fst::StdVectorFst* dictionary, + float weight) +{ + + std::vector int_word; + std::string hotword = ""; + + for (auto& c : characters) { + if (c == " ") { + int_word.push_back(SPACE_ID_ + 1); + hotword += std::to_string(SPACE_ID_ + 1) + delimiter_; + } else { + auto int_c = char_map_.find(c); + if (int_c != char_map_.end()) { + int_word.push_back(int_c->second); + hotword += std::to_string(int_c->second) + delimiter_; + } else { + return false; // return without adding + } + } + } + + hotword_weight_map_[hotword] = weight; + // add word to dictionary + add_word_to_fst(int_word, dictionary); + return true; // return with successful adding +} + +/** + * @brief Creates a FST for the hotwords provided + * + */ +void HotwordScorer::fill_hotword_dictionary() +{ + + fst::StdVectorFst dictionary; + // For each unigram convert to ints and put in trie + int dict_size = 0; + int i = 0; + for (const auto& char_list : hotwords) { + bool added = add_word_to_hotword_dictionary(char_list, &dictionary, hotword_weights[i]); + dict_size += added ? 1 : 0; + ++i; + } + + dict_size_ = dict_size; + + /* Simplify FST + + * This gets rid of "epsilon" transitions in the FST. + * These are transitions that don't require a string input to be taken. + * Getting rid of them is necessary to make the FST determinisitc, but + * can greatly increase the size of the FST + */ + fst::RmEpsilon(&dictionary); + fst::StdVectorFst* new_dict = new fst::StdVectorFst; + + // /* This makes the FST deterministic, meaning for any string input there's + // * only one possible state the FST could be in. It is assumed our + // * dictionary is deterministic when using it. + // * (lest we'd have to check for multiple transitions at each state) + // */ + fst::Determinize(dictionary, new_dict); + + // /* Finds the simplest equivalent fst. This is unnecessary but decreases + // * memory usage of the dictionary + // */ + fst::Minimize(new_dict); + this->dictionary = new_dict; +} + +/** + * @brief This methods returns true when the current node can be extended from + * the given hotword FST dictionary state. + * + * @param path, PathTrie node + * @param dict_state, FST state id + * @return true, if the current node's character can be extended from the given FST state + * @return false, if the current node's character cannot be extended from the given FST state + */ +bool HotwordScorer::is_char_extendable_from_state(PathTrie* path, + fst::StdVectorFst::StateId dict_state) +{ + path->hotword_matcher->SetState(dict_state); + return path->hotword_matcher->Find(path->character + 1); +} + +/** + * @brief This methods returns true if the given node can form hotword. If true then the hotword can + * either be extended from the parent node or the current node starts the hotword. 
+ * + * @param path, PathTrie node + * @param space_id, space id from vocabulary list + * @param apostrophe_id, apostrophe id from vocabulary list + * @return true, if the path forms a hotword + * @return false, if the path doesn't form a hotword + */ +bool HotwordScorer::is_hotpath(PathTrie* path, int space_id, int apostrophe_id) +{ + bool is_hotpath_ = path->parent->is_hotpath(); + + is_hotpath_ &= is_char_extendable_from_state(path, path->parent->hotword_dictionary_state); + + if (!is_hotpath_ && path->is_word_start_char()) { + path->reset_hotword_params(); + is_hotpath_ = is_char_extendable_from_state(path, 0); + } + + return is_hotpath_; +} + +/** + * @brief Finds the shortest possible candidate hotword for the current node + * + * @param path, PathTrie node + * @return the candidate hotword and its length + */ +std::tuple<std::string, int> HotwordScorer::find_shortest_candidate_hotword_length(PathTrie* path) +{ + bool is_final = false; + + int len = path->hotword_match_len; + fst::StdVectorFst::StateId matcher_state = path->hotword_dictionary_state; + + std::string hotword = path->partial_hotword; + + // loop until final state is reached + while (!is_final) { + path->hotword_matcher->SetState(matcher_state); + + auto final_weight = dictionary->Final(matcher_state); + is_final = (final_weight != FSTZERO); + + if (!is_final) { + ++len; + // go to next state + matcher_state = path->hotword_matcher->Value().nextstate; + hotword += std::to_string(path->hotword_matcher->Value().ilabel) + delimiter_; + } + } + + return std::make_tuple(hotword, len); } + +/** + * @brief This method updates the current node's hotword param values such as hotword match length, + * hotword dictionary state, partial_hotword, hotword weight and shortest unigram length, and + * computes the hotword score + * + * @param path, PathTrie node + */ +void HotwordScorer::estimate_hw_score(PathTrie* path) +{ + // update state and match length + path->hotword_match_len += 1; + path->hotword_dictionary_state = path->hotword_matcher->Value().nextstate; + + // update partial hotword + path->partial_hotword + = path->partial_hotword + std::to_string(path->character + 1) + delimiter_; + + // update shortest unigram length, hotword_weight + int candidate_hotword_length; + std::string candidate_hotword; + + std::tie(candidate_hotword, candidate_hotword_length) + = find_shortest_candidate_hotword_length(path); + + path->shortest_unigram_length = candidate_hotword_length; + path->hotword_weight = hotword_weight_map_[candidate_hotword]; + + // calculate hotword score + path->hotword_score = (path->hotword_weight * (float)(path->hotword_match_len)) + / (float)(path->shortest_unigram_length); +} \ No newline at end of file diff --git a/ctcdecode/src/hotword_scorer.h b/ctcdecode/src/hotword_scorer.h new file mode 100644 index 00000000..e9b064f8 --- /dev/null +++ b/ctcdecode/src/hotword_scorer.h @@ -0,0 +1,62 @@ +#ifndef HOTWORD_SCORER_H_ +#define HOTWORD_SCORER_H_ + +#include + +#include "fst/fstlib.h" +#include "path_trie.h" + +class HotwordScorer { +public: + /* Initialize HotwordScorer for CTC beam decoding + * + * Parameters: + * vocab_list: A vector of vocabulary (labels). + * hotwords: A vector of hotwords containing character/token vectors + * hotword_weights: A vector of weights corresponding to each hotword in `hotwords`. + * token_separator: Token separator character for bpe based vocabulary. 
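// [Editor's note -- worked example with made-up numbers] The score assigned
// by estimate_hw_score() above grows linearly with the matched prefix: with
// hotword_weight = 10 and a shortest candidate hotword of 5 tokens, a prefix
// that has matched 2 of them scores 10 * 2 / 5 = 4, and a complete match
// earns the full weight (10 * 5 / 5 = 10).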
+ * is_bpe_based: Whether the vocabulary is bpe based + */ + HotwordScorer(const std::vector& vocab_list, + const std::vector>& hotwords, + const std::vector& hotword_weights, + char token_separator, + bool is_bpe_based); + ~HotwordScorer(); + + size_t get_hotword_dict_size() const { return dict_size_; } + + bool is_bpe_based() const { return is_bpe_based_; } + + bool is_hotpath(PathTrie* path, int space_id, int apostrophe_id); + void estimate_hw_score(PathTrie* path); + std::tuple find_shortest_candidate_hotword_length(PathTrie*); + bool is_char_extendable_from_state(PathTrie* path, fst::StdVectorFst::StateId dict_state); + + fst::StdVectorFst* dictionary; + std::vector hotword_weights; + std::vector> hotwords; + fst::TropicalWeight FSTZERO; + char token_separator; + +protected: + void setup(const std::vector& vocab_list); + + // fill hotword dictionary FST + void fill_hotword_dictionary(); + bool add_word_to_hotword_dictionary(const std::vector& characters, + fst::StdVectorFst* dictionary, + float weight); + +private: + /* data */ + size_t dict_size_; + int SPACE_ID_; + const std::vector& vocabulary; + std::unordered_map char_map_; + bool is_bpe_based_; + std::string delimiter_; + std::unordered_map hotword_weight_map_; +}; + +#endif // HOTWORD_SCORER_H_ diff --git a/ctcdecode/src/output.h b/ctcdecode/src/output.h index a921ee2c..f3fd3d58 100644 --- a/ctcdecode/src/output.h +++ b/ctcdecode/src/output.h @@ -1,11 +1,11 @@ #ifndef OUTPUT_H_ #define OUTPUT_H_ -/* Struct for the beam search output, containing the tokens based on the vocabulary indices, and the timesteps - * for each token in the beam search output +/* Struct for the beam search output, containing the tokens based on the vocabulary indices, and the + * timesteps for each token in the beam search output */ struct Output { std::vector tokens, timesteps; }; -#endif // OUTPUT_H_ +#endif // OUTPUT_H_ diff --git a/ctcdecode/src/path_trie.cpp b/ctcdecode/src/path_trie.cpp index bdf2bab4..ed67da44 100644 --- a/ctcdecode/src/path_trie.cpp +++ b/ctcdecode/src/path_trie.cpp @@ -1,174 +1,297 @@ #include "path_trie.h" -#include -#include -#include -#include -#include - #include "decoder_utils.h" -PathTrie::PathTrie() { - log_prob_b_prev = -NUM_FLT_INF; - log_prob_nb_prev = -NUM_FLT_INF; - log_prob_b_cur = -NUM_FLT_INF; - log_prob_nb_cur = -NUM_FLT_INF; - log_prob_c = -NUM_FLT_INF; - score = -NUM_FLT_INF; - - ROOT_ = -1; - character = ROOT_; - timestep = 0; - exists_ = true; - parent = nullptr; - - dictionary_ = nullptr; - dictionary_state_ = 0; - has_dictionary_ = false; - - matcher_ = nullptr; -} +PathTrie::PathTrie() +{ + log_prob_b_prev = -NUM_FLT_INF; + log_prob_nb_prev = -NUM_FLT_INF; + log_prob_b_cur = -NUM_FLT_INF; + log_prob_nb_cur = -NUM_FLT_INF; + + log_prob_b_prev_hw = -NUM_FLT_INF; + log_prob_nb_prev_hw = -NUM_FLT_INF; + log_prob_b_cur_hw = -NUM_FLT_INF; + log_prob_nb_cur_hw = -NUM_FLT_INF; + + log_prob_c = -NUM_FLT_INF; + score = -NUM_FLT_INF; + score_hw = -NUM_FLT_INF; + + ROOT_ = -1; + character = ROOT_; + timestep = 0; + exists_ = true; + parent = nullptr; + is_hotpath_ = false; + hotword_score = 0.0; + shortest_unigram_length = 0; + hotword_weight = 0.0; + partial_hotword = ""; -PathTrie::~PathTrie() { - for (auto child : children_) { - delete child.second; - } + lexicon_ = nullptr; + lexicon_state_ = 0; + has_lexicon_ = false; + is_word_start_char_ = false; + + matcher_ = nullptr; + hotword_matcher = nullptr; + hotword_dictionary_state = 0; + hotword_match_len = 0; } -PathTrie* PathTrie::get_path_trie(int new_char, int 
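// [Editor's note -- illustrative sketch, not part of the patch] Constructing
// the scorer declared above for two character-level hotwords; each hotword is
// a vector of vocabulary tokens with one weight per hotword, and `vocab` is
// the label list from the earlier sketches:
//
//   std::vector<std::vector<std::string>> hotwords {
//       { "h", "e", "l", "l", "o" },
//       { "w", "o", "r", "l", "d" },
//   };
//   std::vector<float> weights { 10.0f, 5.0f };
//   HotwordScorer hw_scorer(vocab, hotwords, weights,
//                           /*token_separator=*/'#', /*is_bpe_based=*/false);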
new_timestep, float cur_log_prob_c, bool reset) { - auto child = children_.begin(); - for (child = children_.begin(); child != children_.end(); ++child) { - if (child->first == new_char) { - if (child->second->log_prob_c < cur_log_prob_c) { - child->second->log_prob_c = cur_log_prob_c; - child->second->timestep = new_timestep; - } - break; +PathTrie::~PathTrie() +{ + for (auto child : children_) { + delete child.second; } - } - if (child != children_.end()) { - if (!child->second->exists_) { - child->second->exists_ = true; - child->second->log_prob_b_prev = -NUM_FLT_INF; - child->second->log_prob_nb_prev = -NUM_FLT_INF; - child->second->log_prob_b_cur = -NUM_FLT_INF; - child->second->log_prob_nb_cur = -NUM_FLT_INF; +} + +PathTrie* PathTrie::get_path_trie(int new_char, + int new_timestep, + float cur_log_prob_c, + bool reset, + bool check_lexicon) +{ + auto child = children_.begin(); + for (child = children_.begin(); child != children_.end(); ++child) { + if (child->first == new_char) { + if (child->second->log_prob_c < cur_log_prob_c) { + child->second->log_prob_c = cur_log_prob_c; + child->second->timestep = new_timestep; + } + break; + } } - return (child->second); - } else { - if (has_dictionary_) { - matcher_->SetState(dictionary_state_); - bool found = matcher_->Find(new_char + 1); - if (!found) { - // Adding this character causes word outside dictionary - auto FSTZERO = fst::TropicalWeight::Zero(); - auto final_weight = dictionary_->Final(dictionary_state_); - bool is_final = (final_weight != FSTZERO); - if (is_final && reset) { - dictionary_state_ = dictionary_->Start(); + if (child != children_.end()) { + if (!child->second->exists_) { + child->second->exists_ = true; + child->second->log_prob_b_prev = -NUM_FLT_INF; + child->second->log_prob_nb_prev = -NUM_FLT_INF; + child->second->log_prob_b_cur = -NUM_FLT_INF; + child->second->log_prob_nb_cur = -NUM_FLT_INF; + child->second->log_prob_b_prev_hw = -NUM_FLT_INF; + child->second->log_prob_nb_prev_hw = -NUM_FLT_INF; + child->second->log_prob_b_cur_hw = -NUM_FLT_INF; + child->second->log_prob_nb_cur_hw = -NUM_FLT_INF; + child->second->hotword_matcher = hotword_matcher; } - return nullptr; - } else { - PathTrie* new_path = new PathTrie; - new_path->character = new_char; - new_path->timestep = new_timestep; - new_path->parent = this; - new_path->dictionary_ = dictionary_; - new_path->has_dictionary_ = true; - new_path->matcher_ = matcher_; - new_path->log_prob_c = cur_log_prob_c; - - // set spell checker state - // check to see if next state is final - auto FSTZERO = fst::TropicalWeight::Zero(); - auto final_weight = dictionary_->Final(matcher_->Value().nextstate); - bool is_final = (final_weight != FSTZERO); - if (is_final && reset) { - // restart spell checker at the start state - new_path->dictionary_state_ = dictionary_->Start(); + return (child->second); + } else { + if (has_lexicon_ && check_lexicon) { + matcher_->SetState(lexicon_state_); + bool found = matcher_->Find(new_char + 1); + if (!found) { + // Adding this character causes word outside + // lexicon + auto FSTZERO = fst::TropicalWeight::Zero(); + auto final_weight = lexicon_->Final(lexicon_state_); + bool is_final = (final_weight != FSTZERO); + if (is_final && reset) { + lexicon_state_ = lexicon_->Start(); + } + return nullptr; + } else { + + PathTrie* new_path = create_new_node(new_char, new_timestep, cur_log_prob_c); + // set spell checker state + // check to see if next state is final + auto FSTZERO = fst::TropicalWeight::Zero(); + auto final_weight = 
lexicon_->Final(matcher_->Value().nextstate); + bool is_final = (final_weight != FSTZERO); + if (is_final && reset) { + // restart spell checker at the start state + new_path->lexicon_state_ = lexicon_->Start(); + } else { + // go to next state + new_path->lexicon_state_ = matcher_->Value().nextstate; + } + + children_.push_back(std::make_pair(new_char, new_path)); + return new_path; + } } else { - // go to next state - new_path->dictionary_state_ = matcher_->Value().nextstate; + PathTrie* new_path = create_new_node(new_char, new_timestep, cur_log_prob_c); + children_.push_back(std::make_pair(new_char, new_path)); + return new_path; } + } +} - children_.push_back(std::make_pair(new_char, new_path)); - return new_path; - } - } else { - PathTrie* new_path = new PathTrie; - new_path->character = new_char; - new_path->timestep = new_timestep; - new_path->parent = this; - new_path->log_prob_c = cur_log_prob_c; - children_.push_back(std::make_pair(new_char, new_path)); - return new_path; +/** + * @brief Creates new PathTrie node with the given character, timestep and log prob + * + * @param new_char, character id + * @param new_timestep, timestep + * @param cur_log_prob_c, character probability at this timestep + * @param has_lexicon, if lexicon is set + */ +PathTrie* PathTrie::create_new_node(int new_char, int new_timestep, float cur_log_prob_c) +{ + PathTrie* new_path = new PathTrie; + + new_path->character = new_char; + new_path->timestep = new_timestep; + new_path->parent = this; + new_path->log_prob_c = cur_log_prob_c; + new_path->hotword_matcher = hotword_matcher; + new_path->hotword_dictionary_state = hotword_dictionary_state; + new_path->hotword_match_len = hotword_match_len; + new_path->shortest_unigram_length = shortest_unigram_length; + new_path->hotword_weight = hotword_weight; + new_path->partial_hotword = partial_hotword; + + if (has_lexicon_) { + new_path->lexicon_ = lexicon_; + new_path->has_lexicon_ = true; + new_path->matcher_ = matcher_; } - } + + return new_path; } -PathTrie* PathTrie::get_path_vec(std::vector& output, std::vector& timesteps) { - return get_path_vec(output, timesteps, ROOT_); +PathTrie* PathTrie::get_path_vec(std::vector& output, std::vector& timesteps) +{ + return get_path_vec(output, timesteps, ROOT_); } PathTrie* PathTrie::get_path_vec(std::vector& output, std::vector& timesteps, int stop, - size_t max_steps) { - if (character == stop || character == ROOT_ || output.size() == max_steps) { - std::reverse(output.begin(), output.end()); - std::reverse(timesteps.begin(), timesteps.end()); - return this; - } else { - output.push_back(character); - timesteps.push_back(timestep); - return parent->get_path_vec(output, timesteps, stop, max_steps); - } + size_t max_steps) +{ + if (character == stop || character == ROOT_ || output.size() == max_steps) { + std::reverse(output.begin(), output.end()); + std::reverse(timesteps.begin(), timesteps.end()); + return this; + } else { + output.push_back(character); + timesteps.push_back(timestep); + return parent->get_path_vec(output, timesteps, stop, max_steps); + } } -void PathTrie::iterate_to_vec(std::vector& output) { - if (exists_) { - log_prob_b_prev = log_prob_b_cur; - log_prob_nb_prev = log_prob_nb_cur; +void PathTrie::iterate_to_vec(std::vector& output) +{ + if (exists_) { - log_prob_b_cur = -NUM_FLT_INF; - log_prob_nb_cur = -NUM_FLT_INF; + log_prob_b_prev = log_prob_b_cur; + log_prob_nb_prev = log_prob_nb_cur; - score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev); - output.push_back(this); - } - for (auto child : 
children_) { - child.second->iterate_to_vec(output); - } -} + log_prob_b_prev_hw = log_prob_b_cur_hw; + log_prob_nb_prev_hw = log_prob_nb_cur_hw; -void PathTrie::remove() { - exists_ = false; - - if (children_.size() == 0) { - auto child = parent->children_.begin(); - for (child = parent->children_.begin(); child != parent->children_.end(); - ++child) { - if (child->first == character) { - parent->children_.erase(child); - break; - } - } + score = log_sum_exp(log_prob_b_prev, log_prob_nb_prev); + score_hw = log_sum_exp(log_prob_b_prev_hw, log_prob_nb_prev_hw); - if (parent->children_.size() == 0 && !parent->exists_) { - parent->remove(); + log_prob_b_cur = -NUM_FLT_INF; + log_prob_nb_cur = -NUM_FLT_INF; + log_prob_b_cur_hw = -NUM_FLT_INF; + log_prob_nb_cur_hw = -NUM_FLT_INF; + + output.push_back(this); + } + for (auto child : children_) { + child.second->iterate_to_vec(output); } +} + +void PathTrie::remove() +{ + exists_ = false; + + if (children_.size() == 0) { + auto child = parent->children_.begin(); + for (child = parent->children_.begin(); child != parent->children_.end(); ++child) { + if (child->first == character) { + parent->children_.erase(child); + break; + } + } + + if (parent->children_.size() == 0 && !parent->exists_) { + parent->remove(); + } - delete this; - } + delete this; + } } -void PathTrie::set_dictionary(fst::StdVectorFst* dictionary) { - dictionary_ = dictionary; - dictionary_state_ = dictionary->Start(); - has_dictionary_ = true; +void PathTrie::set_lexicon(fst::StdVectorFst* lexicon) +{ + lexicon_ = lexicon; + lexicon_state_ = lexicon->Start(); + has_lexicon_ = true; } using FSTMATCH = fst::SortedMatcher; -void PathTrie::set_matcher(std::shared_ptr matcher) { - matcher_ = matcher; +void PathTrie::set_matcher(std::shared_ptr matcher) { matcher_ = matcher; } + +/** + * @brief Copies parent's hotword related params to the current node + */ +void PathTrie::copy_parent_hotword_params() +{ + hotword_match_len = parent->hotword_match_len; + hotword_dictionary_state = parent->hotword_dictionary_state; + shortest_unigram_length = parent->shortest_unigram_length; + hotword_weight = parent->hotword_weight; + partial_hotword = parent->partial_hotword; } + +/** + * @brief Resets the hotword related params of the current node + */ +void PathTrie::reset_hotword_params() +{ + hotword_match_len = 0; + hotword_dictionary_state = 0; + shortest_unigram_length = 0; + hotword_weight = 0.0; + partial_hotword = ""; +} + +/** + * @brief Checks if the current node forms OOV word and accordingly updates its + * lexicon state + * + * @param true, if current node forms OOV word + * @param false, if current node doesn't form OOV word + */ +bool PathTrie::is_oov_token() +{ + + if (has_lexicon_) { + + fst::StdVectorFst::StateId lexicon_state; + + // If this is the start token of the word, then set the lexicon state + // to the start state of the lexicon, else + // use the parent's lexicon state + if (is_word_start_char_) { + lexicon_state = lexicon_->Start(); + + } else { + lexicon_state = parent->lexicon_state_; + } + + // check if the character can be extended from the + // lexicon state + matcher_->SetState(lexicon_state); + bool found = matcher_->Find(character + 1); + + // If the character can be extended, then update the lexicon state + // of the current node to the next state of the matcher, else + // reset the lexicon state of the current node to the start state + if (found) { + lexicon_state_ = matcher_->Value().nextstate; + + } else { + lexicon_state_ = lexicon_->Start(); + } + + 
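// [Editor's note -- simplified sketch of how DecoderState::next() drives the
// trie above; values are made up and error handling is omitted] Extend a beam
// with one character, update the non-blank log prob, then collapse the trie
// back into a flat beam list:
//
//   PathTrie root;
//   std::vector<PathTrie*> prefixes { &root };
//   PathTrie* ext = root.get_path_trie(/*new_char=*/2, /*new_timestep=*/0,
//                                      /*cur_log_prob_c=*/-0.1f);
//   if (ext != nullptr) {               // nullptr => blocked by the lexicon
//       ext->log_prob_nb_cur = -0.1f;
//       prefixes.clear();
//       root.iterate_to_vec(prefixes);  // swaps cur->prev and refreshes scores
//   }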
return !found; + } + + return false; +} \ No newline at end of file diff --git a/ctcdecode/src/path_trie.h b/ctcdecode/src/path_trie.h index baa27dbe..23fe1aad 100644 --- a/ctcdecode/src/path_trie.h +++ b/ctcdecode/src/path_trie.h @@ -14,57 +14,98 @@ */ class PathTrie { public: - PathTrie(); - ~PathTrie(); + PathTrie(); + ~PathTrie(); - // get new prefix after appending new char - PathTrie* get_path_trie(int new_char, int new_timestep, float log_prob_c, bool reset = true); + // get new prefix after appending new char + PathTrie* get_path_trie(int new_char, + int new_timestep, + float log_prob_c, + bool reset = true, + bool check_lexicon = true); - // get the prefix in index from root to current node - PathTrie* get_path_vec(std::vector& output, std::vector& timesteps); + // get the prefix in index from root to current node + PathTrie* get_path_vec(std::vector& output, std::vector& timesteps); - // get the prefix in index from some stop node to current nodel - PathTrie* get_path_vec(std::vector& output, - std::vector& timesteps, - int stop, - size_t max_steps = std::numeric_limits::max()); + // get the prefix in index from some stop node to current nodel + PathTrie* get_path_vec(std::vector& output, + std::vector& timesteps, + int stop, + size_t max_steps = std::numeric_limits::max()); - // update log probs - void iterate_to_vec(std::vector& output); + // creates new PathTrie* node + PathTrie* create_new_node(int new_char, int new_timestep, float cur_log_prob_c); - // set dictionary for FST - void set_dictionary(fst::StdVectorFst* dictionary); + // update log probs + void iterate_to_vec(std::vector& output); - void set_matcher(std::shared_ptr>); + // set lexicon for FST + void set_lexicon(fst::StdVectorFst* lexicon); - bool is_empty() { return ROOT_ == character; } + void set_matcher(std::shared_ptr>); - // remove current path from root - void remove(); + bool is_empty() { return ROOT_ == character; } - float log_prob_b_prev; - float log_prob_nb_prev; - float log_prob_b_cur; - float log_prob_nb_cur; - float log_prob_c; - float score; - float approx_ctc; - int character; - int timestep; - PathTrie* parent; + bool is_hotpath() { return is_hotpath_; } + + // set as hotpath + void mark_as_hotpath() { is_hotpath_ = true; } + + bool is_word_start_char() { return is_word_start_char_; } + + // set as word start character + void mark_as_word_start_char() { is_word_start_char_ = true; } + + bool has_lexicon() { return has_lexicon_; } + + // check if current token forms OOV word + bool is_oov_token(); + + // remove current path from root + void remove(); + + void reset_hotword_params(); + void copy_parent_hotword_params(); + + float log_prob_b_prev; + float log_prob_nb_prev; + float log_prob_b_cur; + float log_prob_nb_cur; + float log_prob_b_prev_hw; + float log_prob_nb_prev_hw; + float log_prob_b_cur_hw; + float log_prob_nb_cur_hw; + + float log_prob_c; + float score; + float score_hw; + float approx_ctc; + int character; + int timestep; + PathTrie* parent; + float hotword_score; + int shortest_unigram_length; + float hotword_weight; + std::string partial_hotword; + std::shared_ptr> hotword_matcher; + fst::StdVectorFst::StateId hotword_dictionary_state; + int hotword_match_len; private: - int ROOT_; - bool exists_; - bool has_dictionary_; + int ROOT_; + bool exists_; + bool has_lexicon_; + + bool is_hotpath_; + bool is_word_start_char_; - std::vector> children_; + std::vector> children_; - // pointer to dictionary of FST - fst::StdVectorFst* dictionary_; - fst::StdVectorFst::StateId 
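// [Editor's note -- inferred behavior, tying the pieces above together] With
// a lexicon attached via set_lexicon(), get_path_trie() returns nullptr for
// extensions that fall outside every lexicon word, while is_oov_token() lets
// the decoder keep such a beam and penalize it with DecoderOptions::unk_score
// instead of dropping it.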
dictionary_state_; - // true if finding ars in FST - std::shared_ptr> matcher_; + // pointer to lexicon of FST + fst::StdVectorFst* lexicon_; + fst::StdVectorFst::StateId lexicon_state_; + // true if finding ars in FST + std::shared_ptr> matcher_; }; -#endif // PATH_TRIE_H +#endif // PATH_TRIE_H diff --git a/ctcdecode/src/resourceutils.cpp b/ctcdecode/src/resourceutils.cpp new file mode 100644 index 00000000..4901f28f --- /dev/null +++ b/ctcdecode/src/resourceutils.cpp @@ -0,0 +1,464 @@ +// #include +// #include +// #include +// #include +// #include +// #include + +// #include "resourceutils.h" + +// #define FORMATINFO(stream, key, value) \ +// (stream << std::left << std::setw(17) << std::setfill(' ') << key << value) + +// ProcessInfo::ProcessInfo() +// { +// pid = getpid(); +// tid = syscall(SYS_gettid); +// fdCount = 0; +// cpuPercent = 0.0; +// memoryPercent = 0.0; +// vsz = 0.0; +// rss = 0.0; +// framesProcessed = 0; +// latency = 0.0; +// } + +// /** @brief calculates rtf, rts */ +// void ProcessInfo::CalculateRealTimeValues() +// { +// rtf = latency / ((framesProcessed * batchSize) / sampleRate); +// rtf /= 1000.0; // rtf will be in ms, converting it to seconds +// rts = 1.0 / rtf; +// } + +// /** @brief calculates the avg vsz, rss, ram%, cpu%, latency */ +// void ProcessInfo::CalculateAvgResources() +// { +// vsz /= framesProcessed; +// rss /= framesProcessed; +// memoryPercent /= framesProcessed; +// cpuPercent /= framesProcessed; +// avgLatency = latency / framesProcessed; +// latency /= 1000.0; +// } + +// /** @brief Calculates the 90th, 97th, 99th Percentile values from the latencyValues vector */ +// void ProcessInfo::CalculatePercentile() +// { +// latency90P = Percentile(latencyValues, 0.9); +// latency97P = Percentile(latencyValues, 0.97); +// latency99P = Percentile(latencyValues, 0.99); +// } + +// /** @brief calculates the consolidates average stats of the process information */ +// void ProcessInfo::CalculateAvgStats() +// { +// CalculateRealTimeValues(); +// CalculateAvgResources(); +// CalculatePercentile(); +// latencyValues.clear(); +// } + +// /** @brief sorts the latencyValues vector */ +// void ProcessInfo::SortLatencyValues() { std::sort(latencyValues.begin(), latencyValues.end()); } + +// /** @brief pushes the latency value into the latencyValues vector */ +// void ProcessInfo::PushToLatencyValues(double latency) { latencyValues.push_back(latency); } + +// ResourceUtil::ResourceUtil() +// : logger(nullptr) +// { +// Init(); +// } + +// /** @brief get the logger object to log resources +// * @param logger logger object pointer +// * @returns void +// */ +// void ResourceUtil::SetLogger(Logger* logger) { this->logger = logger; } + +// /** @brief set the directory string to parse the stat file and filecount value */ +// void ResourceUtil::SetDirectoryStrings() +// { +// std::string pidStr = std::to_string(processInfo.pid); +// std::string tidStr = std::to_string(processInfo.tid); +// DIR* dir = opendir(("/proc/" + tidStr).c_str()); +// if (dir == NULL) { +// directoryStr = "/proc/" + pidStr + "/task/" + tidStr + "/stat"; +// fdCommand = "find /proc/" + pidStr + "/task/" + tidStr + "/fd | wc -l"; +// } else { +// directoryStr = "/proc/" + tidStr + "/stat"; +// fdCommand = "find /proc/" + tidStr + "/fd | wc -l"; +// closedir(dir); +// } +// } + +// /** @brief set the batchsize of the model, this value is later used +// * to compute the RTF and RTS value +// */ +// void ResourceUtil::SetBatchSize(int batchSize) { processInfo.batchSize = (double)batchSize; } + +// 
void ResourceUtil::SetSampleRate(std::string sampleRate) +// { +// processInfo.sampleRate = (double)std::stoi(sampleRate); +// } + +// /** @brief monitors the provided pid process */ +// void ResourceUtil::SetPid(pid_t pid) +// { +// processInfo.pid = pid; +// processInfo.tid = pid; +// SetDirectoryStrings(); +// } + +// /** @brief initialises constant values +// * @returns void +// */ +// void ResourceUtil::Init() +// { +// tSleep.tv_sec = 0; +// tSleep.tv_nsec = 10000000L; // 10ms +// clockTicks = sysconf(_SC_CLK_TCK); // Eg: 100 +// pageSizeKB = sysconf(_SC_PAGE_SIZE) / 1024.0; // Eg: 4 +// timePosition = 14; +// vszPosition = 23; +// SetDirectoryStrings(); +// } + +// /** @brief get the current date and time string +// * @returns std::string +// */ +// std::string ResourceUtil::GetDateAndTime() +// { +// time_t now; +// time(&now); +// struct tm tstruct = *localtime(&now); +// std::ostringstream currTime; +// currTime << tstruct.tm_year + 1900 << "-" << tstruct.tm_mon << "-" << tstruct.tm_mday << " "; +// currTime << tstruct.tm_hour << ":" << tstruct.tm_min << ":" << tstruct.tm_sec; +// return currTime.str(); +// } + +// /** @brief reads the vsz, rss, stime and utime of the process from the stat file +// * @param vsz virtual memory +// * @param rss RAM memory percentage +// * @param utime user time value +// * @param stime system time[kernel time] +// * @returns void +// * +// * From the stat file, the value +// * vsz - will be in bytes - Eg: 175415296 +// * rss - will be number of pages in memory - Eg: 24311 +// * utime - will be in clock ticks - Eg: 244 +// * stime - will be in clock ticks - Eg: 13 +// */ +// void ResourceUtil::ReadStatFile(double* vsz, double* rss, double* utime, double* stime) +// { +// int count = 1; +// std::string resourceValue; +// std::ifstream statStream(directoryStr, std::ios_base::in); +// while (statStream.good()) { +// if (count == timePosition) { +// statStream >> (*utime) >> (*stime); +// if (!vsz) { +// break; +// } +// count += 2; +// } else if (count == vszPosition) { +// statStream >> (*vsz) >> (*rss); +// break; +// } else { +// statStream >> resourceValue; +// count++; +// } +// } +// statStream.close(); +// } + +// /** @brief returns the filecount of the process +// * @returns int +// */ +// int ResourceUtil::FromShell() +// { +// FILE* pipe = popen(fdCommand.c_str(), "r"); +// if (pipe == nullptr) { +// return -1; +// } +// char buf[128]; +// while (fgets(buf, 128, pipe) != nullptr) { +// int count = std::stoi(buf); +// pclose(pipe); +// return count - 1; +// } +// if (ferror(pipe)) { +// return -1; +// } +// if (pclose(pipe) == -1) { +// return -1; +// } +// return -2; +// } + +// /** @brief returns the memory usage percentage of the process +// * @param rss Resident set size value +// * @returns double +// */ +// double ResourceUtil::GetMemoryPercent(double rss) +// { +// double total, mem; +// std::string resourceValue; +// std::ifstream stream("/proc/meminfo", std::ios_base::in); +// while (stream.good()) { +// stream >> resourceValue; +// if (resourceValue.find("MemTotal") != std::string::npos) { +// stream >> total; +// break; +// } +// } +// stream.close(); +// mem = (rss / total) * 100; +// return rss; +// } + +// /** @brief calculates the vsz, rss, cpupercent and of the process +// * @param vsz virtual memory +// * @param memoryPercent RAM memory percentage +// * @param cpuPercent CPU percentage +// * @returns void +// */ +// void ResourceUtil::ResourceCalculator(double* vsz, double* rss, double* cpuPercent) +// { +// double proc, time; 
+//     double utime1, utime2, stime1, stime2;
+//     timespec tStart, tEnd, tRemaining;
+//     clock_gettime(CLOCK_MONOTONIC, &tStart);
+//     ReadStatFile(nullptr, nullptr, &utime1, &stime1);
+//     int val = nanosleep(&tSleep, &tRemaining);
+//     if (val == -1) {
+//         nanosleep(&tRemaining, nullptr);
+//     }
+//     clock_gettime(CLOCK_MONOTONIC, &tEnd);
+//     ReadStatFile(vsz, rss, &utime2, &stime2);
+//     proc = ((utime2 - utime1) + (stime2 - stime1)) / clockTicks;
+//     // elapsed wall-clock time in seconds; tv_sec must be included, or the
+//     // delta can go negative when the nanosecond counter wraps
+//     time = (double)(tEnd.tv_sec - tStart.tv_sec)
+//         + ((double)(tEnd.tv_nsec - tStart.tv_nsec)) * 1e-9;
+//     *cpuPercent = ((proc / time) * 100);
+// }
+
+// /** @brief log the resource values calculated, to the logger object provided
+//  * @param vsz virtual memory
+//  * @param memoryPercent RAM memory percentage
+//  * @param cpuPercent CPU percentage
+//  * @param fdCount Files count
+//  * @param latency Latency time for processing this frame
+//  * @returns void
+//  */
+// void ResourceUtil::WriteToLogger(double vsz,
+//                                  double memoryPercent,
+//                                  double cpuPercent,
+//                                  int fdCount,
+//                                  double latency)
+// {
+//     if (latency != 0) {
+//         logger->Log(LogLevel::DEBUG,
+//                     "PID:",
+//                     processInfo.pid,
+//                     ", TID: ",
+//                     processInfo.tid,
+//                     ", VSZ_VALUE:",
+//                     vsz,
+//                     "kb, MEMORY_PERCENT:",
+//                     memoryPercent,
+//                     "%, CPU_PERCENT:",
+//                     cpuPercent,
+//                     "%, FILES_COUNT:",
+//                     fdCount,
+//                     ", FRAME_LATENCY:",
+//                     latency,
+//                     "ms");
+//     } else {
+//         logger->Log(LogLevel::DEBUG,
+//                     "PID:",
+//                     processInfo.pid,
+//                     ", TID: ",
+//                     processInfo.tid,
+//                     ", VSZ_VALUE:",
+//                     vsz,
+//                     "kb, MEMORY_PERCENT:",
+//                     memoryPercent,
+//                     "%, CPU_PERCENT:",
+//                     cpuPercent,
+//                     "%, FILES_COUNT:",
+//                     fdCount);
+//     }
+// }
+
+// /** @brief writes the resource values calculated to the resource.txt file
+//  * @param vsz virtual memory
+//  * @param memoryPercent RAM memory percentage
+//  * @param cpuPercent CPU percentage
+//  * @param fdCount Files count
+//  * @param latency Latency time for processing this frame
+//  * @returns void
+//  */
+// void ResourceUtil::WriteToFile(double vsz,
+//                                double memoryPercent,
+//                                double cpuPercent,
+//                                int fdCount,
+//                                double latency)
+// {
+//     std::ofstream file("resource.txt", std::ios_base::app);
+//     std::string date_time = GetDateAndTime();
+//     file << date_time << ' ' << "PID: " << processInfo.pid << ", ";
+//     file << "TID: " << processInfo.tid << ", ";
+//     file << "VSZ_VALUE: " << vsz << " kb, ";
+//     file << "MEMORY_PERCENT: " << memoryPercent << " %, ";
+//     file << "CPU_PERCENT: " << cpuPercent << " %, ";
+//     file << "FILES_COUNT: " << fdCount;
+//     if (latency != 0) {
+//         file << ", FRAME_LATENCY: " << latency << "ms";
+//     }
+//     file << std::endl;
+//     file.close();
+// }
+
+// /** @brief returns the average resource values calculated so far, keyed by stat name
+//  * @returns std::unordered_map<std::string, double>
+//  */
+// std::unordered_map<std::string, double> ResourceUtil::GetAverageStats()
+// {
+//     std::unordered_map<std::string, double> processDetails;
+//     processInfo.SortLatencyValues();
+//     processInfo.CalculateAvgStats();
+//     processDetails["FRAMES"] = processInfo.framesProcessed;
+//     processDetails["TOT_LAT"] = processInfo.latency;
+//     processDetails["AVG_LAT"] = processInfo.avgLatency;
+//     // processDetails["99P"] = processInfo.latency99P;
+//     // processDetails["97P"] = processInfo.latency97P;
+//     // processDetails["90P"] = processInfo.latency90P;
+//     processDetails["AVG_CPU"] = processInfo.cpuPercent;
+//     processDetails["AVG_RAM"] = processInfo.memoryPercent;
+//     processDetails["AVG_VSZ"] = processInfo.vsz;
+//     processDetails["RTS"] =
processInfo.rts; +// processDetails["TID"] = processInfo.tid; +// processDetails["PID"] = processInfo.pid; +// return processDetails; +// } + +// /** @brief calculates the resource details of pid and write it to file +// * @param latency Latency time taken to process a frame +// * @return void +// */ +// void ResourceUtil::Monitor(double latency) +// { +// double vsz, rss, cpuPercent, memoryPercent; +// ResourceCalculator(&vsz, &rss, &cpuPercent); +// vsz = vsz / 1024.0; +// rss = rss * pageSizeKB; +// memoryPercent = GetMemoryPercent(rss); +// processInfo.fdCount = FromShell(); +// latency *= 1000; +// if (logger) { +// WriteToLogger(vsz, memoryPercent, cpuPercent, processInfo.fdCount, latency); +// } else { +// WriteToFile(vsz, memoryPercent, cpuPercent, processInfo.fdCount, latency); +// } +// processInfo.cpuPercent += cpuPercent; +// processInfo.memoryPercent += memoryPercent; +// processInfo.vsz += vsz; +// processInfo.rss += rss; +// processInfo.framesProcessed++; +// processInfo.latency += latency; +// processInfo.PushToLatencyValues(latency); +// } + +// /** @brief gets the current resource consumption values of the pid */ +// ProcessInfo ResourceUtil::GetResourceValues(pid_t pid) +// { +// double vsz, rss, cpuPercent; +// processInfo.pid = pid; +// std::string pidStr = std::to_string(pid); +// directoryStr = "/proc/" + pidStr + "/stat"; +// fdCommand = "find /proc/" + pidStr + "/fd | wc -l"; +// ResourceCalculator(&vsz, &rss, &cpuPercent); +// processInfo.vsz = vsz / 1024.0; +// processInfo.rss = rss * pageSizeKB; +// processInfo.memoryPercent = GetMemoryPercent(processInfo.rss); +// processInfo.fdCount = FromShell(); +// processInfo.cpuPercent = cpuPercent; +// return processInfo; +// } + +// /** @brief aligning function used by GetHeaders */ +// std::string ProcessInfo::Center(const std::string s, const int w) +// { +// std::ostringstream ss, spaces; +// int padding = w - s.size(); +// for (int i = 0; i < padding / 2; ++i) +// spaces << " "; +// ss << spaces.str() << s << spaces.str(); +// if (padding > 0 && padding % 2 != 0) +// ss << " "; +// return ss.str(); +// } + +// /** @brief aligning function used by FormatToString */ +// std::string ProcessInfo::Prd(const double x, int decDigits, const int width) +// { +// if (x == std::floor(x)) +// decDigits = 0; +// std::stringstream ss; +// ss << std::fixed << std::right; +// ss.fill(' '); +// ss.width(width); +// ss.precision(decDigits); +// ss << x; +// return ss.str(); +// } + +// /** @brief Explain each stat keys */ +// std::string ProcessInfo::GetInfo() +// { +// std::ostringstream info; +// info << "\n"; +// FORMATINFO(info, "PID", ": Process ID\n"); +// FORMATINFO(info, "TID", ": Thread ID\n"); +// FORMATINFO(info, "AVG_VSZ", ": Average Virtual Memory size of the (in kb)\n"); +// FORMATINFO(info, "AVG_RAM", ": Average RAM usage of the process (in %)\n"); +// FORMATINFO(info, "AVG_CPU", ": Average CPU usage of the thread (in %)\n"); +// FORMATINFO(info, +// "RTS", +// ": 1 / RTF (Real Time Speed) - how much the synthesis is faster than realtime. 
" +// "If the value you get is 5 then it means, the library takes 1second to denoise a " +// "5second input.\n"); +// FORMATINFO(info, "AVG_LAT", ": Average processing time for a frame (in milli seconds)\n"); +// FORMATINFO(info, "TOT_LAT", ": Sum of processing time of frames (in seconds)\n"); +// FORMATINFO(info, "FRAMES", ": Total number of frames processed\n"); +// return info.str(); +// } + +// /** @brief Formats the provided vector of process infos into a string */ +// std::string ProcessInfo::FormatProcessInfoVector( +// const std::vector>& processInfos) +// { +// std::ostringstream finalStats; +// std::for_each(processInfos[0].cbegin(), +// processInfos[0].cend(), +// [&finalStats](const std::pair& pair) { +// finalStats << Center(pair.first, 12) << " | "; +// }); +// finalStats << "\n"; +// for (const std::unordered_map& processInfo : processInfos) { +// std::for_each(processInfo.cbegin(), +// processInfo.cend(), +// [&finalStats](const std::pair& pair) { +// finalStats << Prd(pair.second, 4, 12) << " | "; +// }); +// finalStats << "\n"; +// } +// finalStats << ProcessInfo::GetInfo(); +// return finalStats.str(); +// } + +// #undef FORMATINFO diff --git a/ctcdecode/src/resourceutils.h b/ctcdecode/src/resourceutils.h new file mode 100644 index 00000000..cb76cfc9 --- /dev/null +++ b/ctcdecode/src/resourceutils.h @@ -0,0 +1,110 @@ +// #pragma once + +// #include +// #include +// #include + +// #include "logger.h" + +// class ProcessInfo { +// public: +// /** @brief Find the Nth percentile value from the provided vector +// * @param values A sorted vector of any datatype +// * @param percentile a float value between 0 and 1. +// */ +// template +// static double Percentile(std::vector const& values, float percentile) +// { +// double position = ((double)values.size() - 1.0) * percentile; +// int positionFloor = std::floor(position); +// if (position == positionFloor) +// return (double)values[positionFloor]; +// double fractionValue = std::fmod(position, (double)positionFloor); +// double value +// = (double)values[positionFloor] +// + (fractionValue * (double)(values[positionFloor + 1] - values[positionFloor])); +// return value; +// } +// static std::string GetInfo(); +// static std::string FormatProcessInfoVector( +// const std::vector>& processInfo); + +// int fdCount, pid, tid; +// double cpuPercent, memoryPercent, vsz, rss, framesProcessed, avgLatency; +// double rtf, rts, latency90P, latency97P, latency99P, latency, batchSize; +// double sampleRate; + +// ProcessInfo(); + +// void CalculateAvgStats(); +// inline void SortLatencyValues(); +// inline void PushToLatencyValues(double latency); + +// private: +// static std::string Prd(const double x, int decDigits, const int width); +// static std::string Center(const std::string s, const int w); + +// std::vector latencyValues; + +// void CalculateRealTimeValues(); +// void CalculateAvgResources(); +// void CalculatePercentile(); +// }; + +// class ResourceUtil { +// public: +// ResourceUtil(); + +// std::unordered_map GetAverageStats(); +// ProcessInfo GetResourceValues(pid_t pid); +// void Monitor(double latency = 0); +// void SetBatchSize(int batchSize); +// void SetSampleRate(std::string sampRate); +// void SetLogger(Logger* logger); +// void SetPid(pid_t pid); + +// private: +// int timePosition, vszPosition; +// double clockTicks, pageSizeKB; +// std::string directoryStr, fdCommand; +// ProcessInfo processInfo; +// timespec tSleep; +// Logger* logger; + +// int FromShell(); +// double GetMemoryPercent(double rss); +// 
+//     std::string GetDateAndTime();
+//     void Init();
+//     void ReadStatFile(double* vsz, double* rss, double* utime, double* stime);
+//     void ResourceCalculator(double* vsz, double* rss, double* cpuPercent);
+//     void SetDirectoryStrings();
+//     void WriteToFile(double vsz, double memoryPercent, double cpuPercent, int fdCount, double latency);
+//     void WriteToLogger(double vsz, double memoryPercent, double cpuPercent, int fdCount, double latency);
+// };

+// /*
+// Usage:
+//     ResourceUtil resourceUtil;
+//     // point the utility at the process to be monitored
+//     resourceUtil.SetPid(pid);

+//     // this method will calculate CPU%, RAM%, VM (kb) and file count
+//     // (call it from your processing function)
+//     resourceUtil.Monitor();

+//     // if provided with the latency of a frame, it will also calculate the
+//     // average latency per frame and the total latency
+//     resourceUtil.Monitor(latency_value);

+//     // additionally you can set the sample rate and input batch size of your
+//     // model; these values will be used to calculate the RTS and the percentiles.
+//     resourceUtil.SetSampleRate("16000");
+//     resourceUtil.SetBatchSize(1024);

+//     // finally, after processing, call
+//     resourceUtil.GetAverageStats();
+//     // this returns an unordered_map from stat name (string) to value (double)
+//     // for the stats monitored.
+// */
diff --git a/ctcdecode/src/scorer.cpp b/ctcdecode/src/scorer.cpp
index c3550b3a..1475c1d3 100644
--- a/ctcdecode/src/scorer.cpp
+++ b/ctcdecode/src/scorer.cpp
@@ -1,7 +1,7 @@
 #include "scorer.h"

-#include
 #include
+#include

 #include "lm/config.hh"
 #include "lm/model.hh"
@@ -16,215 +16,268 @@ using namespace lm::ngram;

 Scorer::Scorer(double alpha,
                double beta,
                const std::string& lm_path,
-               const std::vector<std::string>& vocab_list) {
-  this->alpha = alpha;
-  this->beta = beta;
-
-  dictionary = nullptr;
-  is_character_based_ = true;
-  language_model_ = nullptr;
-
-  max_order_ = 0;
-  dict_size_ = 0;
-  SPACE_ID_ = -1;
-
-  setup(lm_path, vocab_list);
+               const std::vector<std::string>& vocab_list,
+               const std::string& lm_type,
+               const std::string& lexicon_fst_path)
+{
+    this->alpha = alpha;
+    this->beta = beta;
+    this->lm_type = StringToTokenizerType[lm_type];
+    lexicon = nullptr;
+    language_model_ = nullptr;
+    max_order_ = 0;
+    dict_size_ = 0;
+    SPACE_ID_ = -1;
+
+    char_list_ = vocab_list;
+    setup(lm_path, vocab_list, lexicon_fst_path);
 }

-Scorer::~Scorer() {
-  if (language_model_ != nullptr) {
-    delete static_cast<lm::base::Model*>(language_model_);
-  }
-  if (dictionary != nullptr) {
-    delete static_cast<fst::StdVectorFst*>(dictionary);
-  }
+Scorer::~Scorer()
+{
+    if (language_model_ != nullptr) {
+        delete static_cast<lm::base::Model*>(language_model_);
+    }
+    if (lexicon != nullptr) {
+        delete static_cast<fst::StdVectorFst*>(lexicon);
+    }
 }

 void Scorer::setup(const std::string& lm_path,
-                   const std::vector<std::string>& vocab_list) {
-  // load language model
-  load_lm(lm_path);
-  // set char map for scorer
-  set_char_map(vocab_list);
-  // fill the dictionary for FST
-  if (!is_character_based()) {
-    fill_dictionary(true);
-  }
+                   const std::vector<std::string>& vocab_list,
+                   const std::string& lexicon_fst_path)
+{
+    // load language model
+    load_lm(lm_path);
+    // set char map for scorer
+    set_char_map(vocab_list, char_map_, SPACE_ID_);
+    // fill the lexicon FST
+    if (is_word_based() || !lexicon_fst_path.empty()) {
+        load_lexicon(true, lexicon_fst_path);
+    }
 }

-void Scorer::load_lm(const std::string& lm_path) {
-  const char* filename = lm_path.c_str();
-  VALID_CHECK_EQ(access(filename, F_OK), 0, "Invalid language model path");
-
-  RetriveStrEnumerateVocab enumerate;
-  lm::ngram::Config config;
-  config.enumerate_vocab = &enumerate;
-  language_model_ = lm::ngram::LoadVirtual(filename, config);
-  max_order_ =
static_cast(language_model_)->Order(); - vocabulary_ = enumerate.vocabulary; - for (size_t i = 0; i < vocabulary_.size(); ++i) { - if (is_character_based_ && vocabulary_[i] != UNK_TOKEN && - vocabulary_[i] != START_TOKEN && vocabulary_[i] != END_TOKEN && - get_utf8_str_len(enumerate.vocabulary[i]) > 1) { - is_character_based_ = false; +void Scorer::load_lm(const std::string& lm_path) +{ + const char* filename = lm_path.c_str(); + VALID_CHECK_EQ(access(filename, F_OK), 0, "Invalid language model path"); + + RetriveStrEnumerateVocab enumerate; + lm::ngram::Config config; + config.enumerate_vocab = &enumerate; + language_model_ = lm::ngram::LoadVirtual(filename, config); + max_order_ = static_cast(language_model_)->Order(); + vocabulary_ = enumerate.vocabulary; + + if (!is_bpe_based()) { + for (auto it = vocabulary_.begin(); it != vocabulary_.end(); ++it) { + if (is_character_based() && *it != UNK_TOKEN && *it != START_TOKEN && *it != END_TOKEN + && get_utf8_str_len(*it) > 1) { + lm_type = TokenizerType::WORD; + break; // terminate after `lm_type` is set + } + } } - } } -double Scorer::get_log_cond_prob(const std::vector& words) { - lm::base::Model* model = static_cast(language_model_); - double cond_prob; - lm::ngram::State state, tmp_state, out_state; - // avoid to inserting in begin - model->NullContextWrite(&state); - for (size_t i = 0; i < words.size(); ++i) { - lm::WordIndex word_index = model->BaseVocabulary().Index(words[i]); - // encounter OOV - if (word_index == 0) { - return OOV_SCORE; +double Scorer::get_log_cond_prob(const std::vector& words) +{ + lm::base::Model* model = static_cast(language_model_); + double cond_prob; + lm::ngram::State state, tmp_state, out_state; + // avoid to inserting in begin + model->NullContextWrite(&state); + for (size_t i = 0; i < words.size(); ++i) { + lm::WordIndex word_index = 0; + if (words[i] != UNK_TOKEN) { + word_index = model->BaseVocabulary().Index(words[i]); + } + // encounter OOV + if (word_index == 0) { + return OOV_SCORE; + } + cond_prob = model->BaseScore(&state, word_index, &out_state); + tmp_state = state; + state = out_state; + out_state = tmp_state; } - cond_prob = model->BaseScore(&state, word_index, &out_state); - tmp_state = state; - state = out_state; - out_state = tmp_state; - } - // return loge prob - return cond_prob/NUM_FLT_LOGE; + // return loge prob + return cond_prob / NUM_FLT_LOGE; } -double Scorer::get_sent_log_prob(const std::vector& words) { - std::vector sentence; - if (words.size() == 0) { - for (size_t i = 0; i < max_order_; ++i) { - sentence.push_back(START_TOKEN); - } - } else { - for (size_t i = 0; i < max_order_ - 1; ++i) { - sentence.push_back(START_TOKEN); +double Scorer::get_sent_log_prob(const std::vector& words) +{ + std::vector sentence; + if (words.size() == 0) { + for (size_t i = 0; i < max_order_; ++i) { + sentence.push_back(START_TOKEN); + } + } else { + for (size_t i = 0; i < max_order_ - 1; ++i) { + sentence.push_back(START_TOKEN); + } + sentence.insert(sentence.end(), words.begin(), words.end()); } - sentence.insert(sentence.end(), words.begin(), words.end()); - } - sentence.push_back(END_TOKEN); - return get_log_prob(sentence); + sentence.push_back(END_TOKEN); + return get_log_prob(sentence); } -double Scorer::get_log_prob(const std::vector& words) { - assert(words.size() > max_order_); - double score = 0.0; - for (size_t i = 0; i < words.size() - max_order_ + 1; ++i) { - std::vector ngram(words.begin() + i, - words.begin() + i + max_order_); - score += get_log_cond_prob(ngram); - } - return score; 
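+// Scores a word sequence by sliding a window of max_order_ words across it and
+// summing the conditional log-probability of each window; e.g. with
+// max_order_ = 3 and words = {a, b, c, d}, this sums log P(c | a, b) + log P(d | b, c).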
+double Scorer::get_log_prob(const std::vector& words) +{ + assert(words.size() > max_order_); + double score = 0.0; + for (size_t i = 0; i < words.size() - max_order_ + 1; ++i) { + std::vector ngram(words.begin() + i, words.begin() + i + max_order_); + score += get_log_cond_prob(ngram); + } + return score; } -void Scorer::reset_params(float alpha, float beta) { - this->alpha = alpha; - this->beta = beta; +void Scorer::reset_params(float alpha, float beta) +{ + this->alpha = alpha; + this->beta = beta; } -std::string Scorer::vec2str(const std::vector& input) { - std::string word; - for (auto ind : input) { - word += char_list_[ind]; - } - return word; +std::string Scorer::vec2str(const std::vector& input) +{ + std::string word; + for (auto ind : input) { + word += char_list_[ind]; + } + return word; } -std::vector Scorer::split_labels(const std::vector& labels) { - if (labels.empty()) return {}; - - std::string s = vec2str(labels); - std::vector words; - if (is_character_based_) { - words = split_utf8_str(s); - } else { - words = split_str(s, " "); - } - return words; -} +std::vector Scorer::split_labels(const std::vector& labels) +{ + if (labels.empty()) + return {}; -void Scorer::set_char_map(const std::vector& char_list) { - char_list_ = char_list; - char_map_.clear(); + std::string s = vec2str(labels); + std::vector words; + if (is_character_based()) { + words = split_utf8_str(s); + } else { + words = split_str(s, " "); + } + return words; +} - for (size_t i = 0; i < char_list_.size(); i++) { - if (char_list_[i] == " ") { - SPACE_ID_ = i; +std::vector Scorer::make_ngram(PathTrie* prefix) +{ + std::vector ngram; + PathTrie* current_node = prefix; + PathTrie* new_node = nullptr; + + for (int order = 0; order < max_order_; ++order) { + std::vector prefix_vec; + std::vector prefix_steps; + + if (is_character_based() || is_bpe_based()) { + new_node = current_node->get_path_vec(prefix_vec, prefix_steps, -1, 1); + current_node = new_node; + } else { + new_node = current_node->get_path_vec(prefix_vec, prefix_steps, SPACE_ID_); + current_node = new_node->parent; // Skipping spaces + } + + // reconstruct word + std::string word = vec2str(prefix_vec); + ngram.push_back(word); + + if (new_node->character == -1) { + // No more spaces, but still need order + for (int i = 0; i < max_order_ - order - 1; ++i) { + ngram.push_back(START_TOKEN); + } + break; + } } - // The initial state of FST is state 0, hence the index of chars in - // the FST should start from 1 to avoid the conflict with the initial - // state, otherwise wrong decoding results would be given. 
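// e.g. with char_list = {" ", "a", "b"}, this yields SPACE_ID_ = 0 and
// char_map_ = {" " -> 1, "a" -> 2, "b" -> 3}.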
- char_map_[char_list_[i]] = i + 1; - } + std::reverse(ngram.begin(), ngram.end()); + return ngram; } -std::vector Scorer::make_ngram(PathTrie* prefix) { - std::vector ngram; - PathTrie* current_node = prefix; - PathTrie* new_node = nullptr; +/** + * @brief Loads FST from the given path + * + * @param lexicon_fst_path, Path to the file containing the FST + */ +void Scorer::load_lexicon_from_fst_file(const std::string& lexicon_fst_path) +{ + + auto startTime = std::chrono::high_resolution_clock::now(); + fst::FstReadOptions read_options; + // Read the FST from the file + fst::StdVectorFst* dict = fst::StdVectorFst::Read(lexicon_fst_path); + if (!dict) { + std::cerr << "Failed to read FST from file: " << lexicon_fst_path << std::endl; + exit(EXIT_FAILURE); + } - for (int order = 0; order < max_order_; order++) { - std::vector prefix_vec; - std::vector prefix_steps; + auto endTime = std::chrono::high_resolution_clock::now(); + auto duration + = std::chrono::duration_cast(endTime - startTime).count(); + // Convert duration to seconds + auto seconds = duration / 1000000.0; - if (is_character_based_) { - new_node = current_node->get_path_vec(prefix_vec, prefix_steps, -1, 1); - current_node = new_node; - } else { - new_node = current_node->get_path_vec(prefix_vec, prefix_steps, SPACE_ID_); - current_node = new_node->parent; // Skipping spaces - } + std::cout << "Total time taken for reading the FST file: " << seconds << " seconds" + << std::endl; + + this->lexicon = dict; +} - // reconstruct word - std::string word = vec2str(prefix_vec); - ngram.push_back(word); +/** + * @brief Creates FST lexicon from the LM vocabulary or from the given FST + * + * @param add_space, whether to add space in the dictionary after each word + * @param lexicon_fst_path, Path to the file containing the FST + */ +void Scorer::load_lexicon(bool add_space, const std::string& lexicon_fst_path) +{ + fst::StdVectorFst lexicon; + // For each unigram convert to ints and put in trie + int dict_size = 0; + has_lexicon_ = true; + + if (lexicon_fst_path.empty()) { + for (const auto& word : vocabulary_) { + const auto& characters = split_utf8_str(word); + bool added + = add_word_to_lexicon(characters, char_map_, add_space, SPACE_ID_ + 1, &lexicon); + dict_size += added ? 1 : 0; + } + + dict_size_ = dict_size; + + /* Simplify FST + + * This gets rid of "epsilon" transitions in the FST. + * These are transitions that don't require a string input to be taken. + * Getting rid of them is necessary to make the FST determinisitc, but + * can greatly increase the size of the FST + */ + fst::RmEpsilon(&lexicon); + fst::StdVectorFst* new_lexicon = new fst::StdVectorFst; + + /* This makes the FST deterministic, meaning for any string input there's + * only one possible state the FST could be in. It is assumed our + * dictionary is deterministic when using it. + * (lest we'd have to check for multiple transitions at each state) + */ + fst::Determinize(lexicon, new_lexicon); + + /* Finds the simplest equivalent fst. 
This is unnecessary but decreases
+     * memory usage of the dictionary
+     */
+    fst::Minimize(new_lexicon);
+    this->lexicon = new_lexicon;
-    if (new_node->character == -1) {
-      // No more spaces, but still need order
-      for (int i = 0; i < max_order_ - order - 1; i++) {
-        ngram.push_back(START_TOKEN);
-      }
-      break;
+    } else {
+        load_lexicon_from_fst_file(lexicon_fst_path);
     }
-  }
-  std::reverse(ngram.begin(), ngram.end());
-  return ngram;
-}

-void Scorer::fill_dictionary(bool add_space) {
-  fst::StdVectorFst dictionary;
-  // For each unigram convert to ints and put in trie
-  int dict_size = 0;
-  for (const auto& word : vocabulary_) {
-    bool added = add_word_to_dictionary(
-        word, char_map_, add_space, SPACE_ID_ + 1, &dictionary);
-    dict_size += added ? 1 : 0;
-  }
-
-  dict_size_ = dict_size;
-
-  /* Simplify FST
-
-   * This gets rid of "epsilon" transitions in the FST.
-   * These are transitions that don't require a string input to be taken.
-   * Getting rid of them is necessary to make the FST determinisitc, but
-   * can greatly increase the size of the FST
-   */
-  fst::RmEpsilon(&dictionary);
-  fst::StdVectorFst* new_dict = new fst::StdVectorFst;
-
-  /* This makes the FST deterministic, meaning for any string input there's
-   * only one possible state the FST could be in. It is assumed our
-   * dictionary is deterministic when using it.
-   * (lest we'd have to check for multiple transitions at each state)
-   */
-  fst::Determinize(dictionary, new_dict);
-
-  /* Finds the simplest equivalent fst. This is unnecessary but decreases
-   * memory usage of the dictionary
-   */
-  fst::Minimize(new_dict);
-  this->dictionary = new_dict;

+    // check the installed lexicon (this->lexicon) rather than the local,
+    // shadowing `lexicon`, so the emptiness test also covers an FST that was
+    // loaded from file
+    if (static_cast<fst::StdVectorFst*>(this->lexicon)->NumStates() == 0) {
+        std::cout << "Lexicon is empty" << std::endl;
+        has_lexicon_ = false;
+    }
 }
diff --git a/ctcdecode/src/scorer.h b/ctcdecode/src/scorer.h
index 5ebc719c..5e7d8816 100644
--- a/ctcdecode/src/scorer.h
+++ b/ctcdecode/src/scorer.h
@@ -1,33 +1,40 @@
 #ifndef SCORER_H_
 #define SCORER_H_

-#include
 #include
 #include
-#include

 #include "lm/enumerate_vocab.hh"
 #include "lm/virtual_interface.hh"
 #include "lm/word_index.hh"
 #include "util/string_piece.hh"

+#include "decoder_utils.h"
 #include "path_trie.h"

 const double OOV_SCORE = -1000.0;
 const std::string START_TOKEN = "<s>";
-const std::string UNK_TOKEN = "<unk>";
+const std::string UNK_TOKEN = "[UNK]";
 const std::string END_TOKEN = "</s>";

-// Implement a callback to retrive the dictionary of language model.
+enum TokenizerType { CHAR = 0, BPE = 1, WORD = 2 };
+
+static std::map<std::string, TokenizerType> StringToTokenizerType
+    = { { "character", TokenizerType::CHAR },
+        { "bpe", TokenizerType::BPE },
+        { "word", TokenizerType::WORD } };
+
+// Implement a callback to retrieve the lexicon of the language model.
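+// KenLM invokes Add() once per vocabulary entry while the model is being
+// loaded (load_lm() registers this callback via lm::ngram::Config), so after
+// loading, `vocabulary` holds every word the language model knows.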
class RetriveStrEnumerateVocab : public lm::EnumerateVocab {
 public:
-  RetriveStrEnumerateVocab() {}
+    RetriveStrEnumerateVocab() { }

-  void Add(lm::WordIndex index, const StringPiece &str) {
-    vocabulary.push_back(std::string(str.data(), str.length()));
-  }
+    void Add(lm::WordIndex index, const StringPiece& str)
+    {
+        vocabulary.push_back(std::string(str.data(), str.length()));
+    }

-  std::vector<std::string> vocabulary;
+    std::vector<std::string> vocabulary;
 };

 /* External scorer to query score for n-gram or sentence, including language
@@ -40,73 +47,89 @@ class RetriveStrEnumerateVocab : public lm::EnumerateVocab {
 */
 class Scorer {
 public:
-  Scorer(double alpha,
-         double beta,
-         const std::string &lm_path,
-         const std::vector<std::string> &vocabulary);
-  ~Scorer();
+    Scorer(double alpha,
+           double beta,
+           const std::string& lm_path,
+           const std::vector<std::string>& vocabulary,
+           const std::string& lm_type,
+           const std::string& lexicon_fst_path);
+    ~Scorer();

-  double get_log_cond_prob(const std::vector<std::string> &words);
+    double get_log_cond_prob(const std::vector<std::string>& words);

-  double get_sent_log_prob(const std::vector<std::string> &words);
+    double get_sent_log_prob(const std::vector<std::string>& words);

-  // return the max order
-  size_t get_max_order() const { return max_order_; }
+    // return the max order
+    size_t get_max_order() const { return max_order_; }

-  // return the dictionary size of language model
-  size_t get_dict_size() const { return dict_size_; }
+    // return the lexicon size of language model
+    size_t get_lexicon_size() const { return dict_size_; }

-  // retrun true if the language model is character based
-  bool is_character_based() const { return is_character_based_; }
+    // return true if the language model is character based
+    bool is_character_based() const { return lm_type == TokenizerType::CHAR; }

-  // reset params alpha & beta
-  void reset_params(float alpha, float beta);
+    // return true if the language model is bpe based
+    bool is_bpe_based() const { return lm_type == TokenizerType::BPE; }
+    // return true if the language model is word based
+    bool is_word_based() const { return lm_type == TokenizerType::WORD; }

-  // make ngram for a given prefix
-  std::vector<std::string> make_ngram(PathTrie *prefix);
+    bool has_lexicon() const { return has_lexicon_; }

-  // trransform the labels in index to the vector of words (word based lm) or
-  // the vector of characters (character based lm)
-  std::vector<std::string> split_labels(const std::vector<int> &labels);
+    std::unordered_map<std::string, int> get_char_map() { return char_map_; }

-  // language model weight
-  double alpha;
-  // word insertion weight
-  double beta;
+    std::vector<std::string> get_char_list() { return char_list_; }

-  // pointer to the dictionary of FST
-  void *dictionary;
+    // reset params alpha & beta
+    void reset_params(float alpha, float beta);

-protected:
-  // necessary setup: load language model, set char map, fill FST's dictionary
-  void setup(const std::string &lm_path,
-             const std::vector<std::string> &vocab_list);
+    // make ngram for a given prefix
+    std::vector<std::string> make_ngram(PathTrie* prefix);

-  // load language model from given path
-  void load_lm(const std::string &lm_path);
+    // transform the labels in index to the vector of words (word based lm) or
+    // the vector of characters (character based lm)
+    std::vector<std::string> split_labels(const std::vector<int>& labels);

-  // fill dictionary for FST
-  void fill_dictionary(bool add_space);
+    // language model weight
+    double alpha;
+    // word insertion weight
+    double beta;

-  // set char map
-  void set_char_map(const std::vector<std::string> &char_list);
+    // Whether the lm is character based, or bpe based, or word based
TokenizerType lm_type; - double get_log_prob(const std::vector &words); + // pointer to the lexicon of FST + void* lexicon; - // translate the vector in index to string - std::string vec2str(const std::vector &input); +protected: + // necessary setup: load language model, set char map, fill FST's lexicon + void setup(const std::string& lm_path, + const std::vector& vocab_list, + const std::string& lexicon_fst_path); -private: - void *language_model_; - bool is_character_based_; - size_t max_order_; - size_t dict_size_; + // load language model from given path + void load_lm(const std::string& lm_path); - int SPACE_ID_; - std::vector char_list_; - std::unordered_map char_map_; + // fill lexicon for FST + void load_lexicon(bool add_space, const std::string& lexicon_fst_path); - std::vector vocabulary_; + // load FST from given path + void load_lexicon_from_fst_file(const std::string& lexicon_fst_path); + + double get_log_prob(const std::vector& words); + + // translate the vector in index to string + std::string vec2str(const std::vector& input); + +private: + void* language_model_; + size_t max_order_; + size_t dict_size_; + int SPACE_ID_; + bool has_lexicon_; + std::vector char_list_; + std::unordered_map char_map_; + + std::vector vocabulary_; }; -#endif // SCORER_H_ +#endif // SCORER_H_ diff --git a/format.sh b/format.sh new file mode 100644 index 00000000..6dbbadfc --- /dev/null +++ b/format.sh @@ -0,0 +1,5 @@ +black --line-length 100 ctcdecode/ tests/ setup.py + +isort ctcdecode/ tests/ + +clang-format -i ctcdecode/**/*.cpp ctcdecode/src/*.h tools/* \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 08ed5eeb..6faa4ae4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,5 @@ -torch \ No newline at end of file +torch +black +isort +clang-format +pytest \ No newline at end of file diff --git a/setup.py b/setup.py index a89f24ab..a2987442 100644 --- a/setup.py +++ b/setup.py @@ -24,8 +24,8 @@ def download_extract(url, dl_path): # Download/Extract openfst, boost download_extract( - "https://github.com/parlance/ctcdecode/releases/download/v1.0/openfst-1.6.7.tar.gz", - "third_party/openfst-1.6.7.tar.gz", + "https://www.openfst.org/twiki/pub/FST/FstDownload/openfst-1.8.2.tar.gz", + "third_party/openfst-1.8.2.tar.gz", ) download_extract( "https://github.com/parlance/ctcdecode/releases/download/v1.0/boost_1_67_0.tar.gz", @@ -34,7 +34,11 @@ def download_extract(url, dl_path): for file in ["third_party/kenlm/setup.py", "third_party/ThreadPool/ThreadPool.h"]: if not os.path.exists(file): - warnings.warn("File `{}` does not appear to be present. Did you forget `git submodule update`?".format(file)) + warnings.warn( + "File `{}` does not appear to be present. Did you forget `git submodule update`?".format( + file + ) + ) # Does gcc compile with this header and library? 
@@ -54,7 +58,7 @@ def compile_test(header, library): return os.system(command) == 0 -compile_args = ["-O3", "-DKENLM_MAX_ORDER=6", "-std=c++14", "-fPIC"] +compile_args = ["-O3", "-DKENLM_MAX_ORDER=6", "-std=c++17", "-fPIC"] ext_libs = [] if compile_test("zlib.h", "z"): compile_args.append("-DHAVE_ZLIB") @@ -68,17 +72,27 @@ def compile_test(header, library): compile_args.append("-DHAVE_XZLIB") ext_libs.append("lzma") -third_party_libs = ["kenlm", "openfst-1.6.7/src/include", "ThreadPool", "boost_1_67_0", "utf8"] -compile_args.extend(["-DINCLUDE_KENLM", "-DKENLM_MAX_ORDER=6"]) +third_party_libs = [ + "kenlm", + "openfst-1.8.2/src/include", + "ThreadPool", + "boost_1_67_0", + "utf8", +] +compile_args.extend(["-DINCLUDE_KENLM", "-DKENLM_MAX_ORDER=6 "]) lib_sources = ( glob.glob("third_party/kenlm/util/*.cc") + glob.glob("third_party/kenlm/lm/*.cc") + glob.glob("third_party/kenlm/util/double-conversion/*.cc") - + glob.glob("third_party/openfst-1.6.7/src/lib/*.cc") + + glob.glob("third_party/openfst-1.8.2/src/lib/*.cc") ) -lib_sources = [fn for fn in lib_sources if not (fn.endswith("main.cc") or fn.endswith("test.cc"))] +lib_sources = [ + fn for fn in lib_sources if not (fn.endswith("main.cc") or fn.endswith("test.cc")) +] -third_party_includes = [os.path.realpath(os.path.join("third_party", lib)) for lib in third_party_libs] +third_party_includes = [ + os.path.realpath(os.path.join("third_party", lib)) for lib in third_party_libs +] ctc_sources = glob.glob("ctcdecode/src/*.cpp") extension = CppExtension( @@ -131,7 +145,7 @@ def _single_compile(obj): setup( name="ctcdecode", - version="1.0.3", + version="1.0.4", description="CTC Decoder for PyTorch based on Paddle Paddle's implementation", url="https://github.com/parlance/ctcdecode", author="Ryan Leary", diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt new file mode 100644 index 00000000..83c8fbb6 --- /dev/null +++ b/tests/cpp/CMakeLists.txt @@ -0,0 +1,29 @@ +cmake_minimum_required(VERSION 3.16 FATAL_ERROR) +set(CMAKE_CXX_STANDARD 17) + +# TEST EXECUTABLE + +# add google test library +include(FetchContent) +FetchContent_Declare( + googletest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG v1.14.0 +) +FetchContent_MakeAvailable(googletest) + +# Enable testing +enable_testing() + +# Create an executable for the tests +add_executable(build_fst_test ${CMAKE_SOURCE_DIR}/tests/cpp/test_build_fst.cpp ) + +# Link Google Test to the test executable +target_link_libraries(build_fst_test gtest gtest_main build_fst_lib) +target_sources(build_fst_test PRIVATE ${CMAKE_SOURCE_DIR}/tools/build_fst.cpp) +target_compile_definitions(build_fst_test PUBLIC TEST_FIXTURES_DIR="${CMAKE_SOURCE_DIR}/tests/cpp/fixtures") + + +# Add the tests to CTest +include(GoogleTest) +gtest_discover_tests(build_fst_test) \ No newline at end of file diff --git a/tests/cpp/fixtures/expected_fst.fst b/tests/cpp/fixtures/expected_fst.fst new file mode 100644 index 00000000..f7ffd99d Binary files /dev/null and b/tests/cpp/fixtures/expected_fst.fst differ diff --git a/tests/cpp/fixtures/lexicon.txt b/tests/cpp/fixtures/lexicon.txt new file mode 100644 index 00000000..3789088e --- /dev/null +++ b/tests/cpp/fixtures/lexicon.txt @@ -0,0 +1,7 @@ +1 aalborg a ##al ##b ##or ##g +1 aalesund a ##al ##es ##un ##d +1 aalii a ##al ##i ##i +1 aaliis a ##al ##i ##is +1 aals a ##al ##s +1 aalst a ##al ##st +1 aalto a ##al ##t ##o \ No newline at end of file diff --git a/tests/cpp/fixtures/vocab.txt b/tests/cpp/fixtures/vocab.txt new file mode 100644 index 
00000000..f2fe4b61 --- /dev/null +++ b/tests/cpp/fixtures/vocab.txt @@ -0,0 +1,512 @@ +[UNK] +' +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +##i +##g +##r +##e +##f +##o +##n +##d +##t +##c +##a +##s +##h +##x +##y +##u +##p +##l +##v +##z +##b +##w +##m +##k +##q +##j +th +the +##er +##in +##at +##nd +##ou +##re +##on +##en +##ed +to +of +and +##es +##ing +##is +##or +##an +in +##it +##ll +##ar +##as +that +wh +##ic +##om +##ion +be +##al +is +##le +it +##ow +##ve +##ent +ha +##ot +you +##ut +##se +re +on +he +##st +##id +##ct +##et +##ly +##ld +##ay +we +##gh +for +##ce +was +##im +st +so +this +as +sh +##ur +##ith +##ir +##ver +##ro +##ch +with +an +not +##ere +##ad +##am +##ght +se +##all +##ter +##ation +at +##ill +have +##ould +##th +his +de +con +but +##ain +su +are +##ri +##ke +##oo +##il +me +they +or +do +com +there +##ight +had +fr +##pe +ne +##est +##if +##ul +##ore +ch +##ich +##ess +##ge +##ate +ex +pro +by +all +##us +##ant +which +her +one +##pp +what +##res +##ers +go +al +from +sa +##ck +##ment +##our +she +if +kn +po +##and +##ble +will +can +##ust +li +##ra +##art +##ist +##ome +##um +him +##ity +##her +##ive +ab +know +##ard +##ally +my +##out +no +##qu +##ous +##nt +##em +would +were +up +##au +##ig +le +##ol +tw +##ind +##ort +when +tr +##ure +who +out +then +wor +pl +##el +man +un +##un +##ie +##ple +##ast +int +##ven +co +us +said +fa +ma +##ction +##ame +mo +en +now +some +ag +ar +##red +them +##pt +say +##ong +##ound +cl +like +their +qu +bec +##ence +##ies +##ook +tim +##ine +##one +has +##ose +about +been +ta +##os +##ery +##ough +your +##ood +##ther +any +two +##ink +am +##per +im +here +##pl +dis +##ud +##nder +af +##ause +##sel +##ack +##ect +more +##ase +##iv +##ree +fe +ad +##itt +part +pr +##fer +very +how +just +ye +time +##ue +these +##ance +did +##ans +##ish +see +##ice +##reat +could +think +look +other +##ide +##ry +sp +res +##ab +into +##ci +because +##age +cour +cont +lo +our +##ress +##cc +than +##ag +gr +per +##hing +##ite +##ap +##own +##ire +pe +##are +##ber +te +br +ever +##ass +##ty +##king +get +##able +##so +##ach +over +##ated +ro +where +##ary +##ition +well +##int +##ings +right +##fore +under +bo +again +##ace +##ars +##ount +only +fl +##self +##end +##ord +should +##ade +also +##iz +ah +don +##ff +##een +fir +spe +after +##way +may +##ss +pre +ind +mr +let +##ru +##co +off +act +bl +going +way +##ep +##wn +day +inter +##other +comp +new +##ved +first +##ict +dif +##ations +##ial +app +thr +comm +its +##ign +##ile +##urn +##ens +imp +##ang +case +##ied +bet +##cl +pres +every +those +peo +court +##ving +##ip +people +even +##ces +mu +##orm +##ought +ob +##ical +##ittle +cons +rem +##ious +##ked +little +##ater +good +want +back +point +##ions +ho +before +too +##od +call +great +fo +##omet +down +##ult +rep +made +##ild +##ak +##erm +##ert +three +##ise +mar +##ced +hand +des +##igh +##ix +##ath +much +min +##iss +la +sm +come +make +through +somet +differ +att +upon +sc +dist +##ward +take +##ory diff --git a/tests/cpp/test_build_fst.cpp b/tests/cpp/test_build_fst.cpp new file mode 100644 index 00000000..f1738831 --- /dev/null +++ b/tests/cpp/test_build_fst.cpp @@ -0,0 +1,28 @@ +#include +#include + +#include "build_fst.h" + +// create a sample fst from the lexicon file and +// compare it with the expected fst +TEST(BuildFstTest, TestBuildFst) +{ + std::vector lexicon_paths = { std::string(TEST_FIXTURES_DIR) + "/lexicon.txt" }; + std::string label_path = std::string(TEST_FIXTURES_DIR) + "/vocab.txt"; + 
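+    // vocab.txt lists one label per line; lexicon.txt maps each word to its
+    // token sequence as "<frequency> <word> <token> <token> ..." (see
+    // tools/README.md for the expected formats).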
std::string expected_fst_path = std::string(TEST_FIXTURES_DIR) + "/expected_fst.fst"; + std::string output_fst_path = ::testing::TempDir() + "/test_output.fst"; + int freq_threshold = 0; + + construct_fst(label_path, lexicon_paths, "", output_fst_path, freq_threshold, false); + + auto output_fst = read_fst(output_fst_path); + auto expected_fst = read_fst(expected_fst_path); + EXPECT_EQ(output_fst->NumStates(), expected_fst->NumStates()); +} + +int main(int argc, char** argv) +{ + ::testing::InitGoogleTest(&argc, argv); + + return RUN_ALL_TESTS(); +} \ No newline at end of file diff --git a/tests/test.arpa b/tests/python/test.arpa similarity index 94% rename from tests/test.arpa rename to tests/python/test.arpa index c4d2e6df..4b0f7888 100644 --- a/tests/test.arpa +++ b/tests/python/test.arpa @@ -1,7 +1,7 @@ \data\ -ngram 1=37 -ngram 2=47 +ngram 1=39 +ngram 2=52 ngram 3=11 ngram 4=6 ngram 5=4 @@ -44,6 +44,8 @@ ngram 5=4 -3.141592 foo -2.718281 bar 3.0 -6.535897 baz -0.0 +-1.687872 bugs -0.30103 +-1.687872 bunny -0.30103 \2-grams: -0.6925742 , . @@ -93,6 +95,12 @@ ngram 5=4 -15 -2 -4 however -1 -6 foo bar +-0.2922095 bunny +-10 bugs +-0.2922095 bugs +-10 bunny +-0.2922095 bugs bunny + \3-grams: -0.01916512 more . diff --git a/tests/python/test_decode.py b/tests/python/test_decode.py new file mode 100644 index 00000000..40248864 --- /dev/null +++ b/tests/python/test_decode.py @@ -0,0 +1,517 @@ +"""Test decoders.""" +from __future__ import absolute_import, division, print_function + +import os +import unittest + +import torch + +import ctcdecode + + +class TestDecoders(unittest.TestCase): + def setUp(self): + self.vocab_list = ["'", " ", "a", "b", "c", "d", "_"] + self.beam_size = 20 + self.probs_seq1 = [ + [ + 0.06390443, + 0.21124858, + 0.27323887, + 0.06870235, + 0.0361254, + 0.18184413, + 0.16493624, + ], + [ + 0.03309247, + 0.22866108, + 0.24390638, + 0.09699597, + 0.31895462, + 0.0094893, + 0.06890021, + ], + [ + 0.218104, + 0.19992557, + 0.18245131, + 0.08503348, + 0.14903535, + 0.08424043, + 0.08120984, + ], + [ + 0.12094152, + 0.19162472, + 0.01473646, + 0.28045061, + 0.24246305, + 0.05206269, + 0.09772094, + ], + [ + 0.1333387, + 0.00550838, + 0.00301669, + 0.21745861, + 0.20803985, + 0.41317442, + 0.01946335, + ], + [ + 0.16468227, + 0.1980699, + 0.1906545, + 0.18963251, + 0.19860937, + 0.04377724, + 0.01457421, + ], + ] + self.probs_seq2 = [ + [ + 0.08034842, + 0.22671944, + 0.05799633, + 0.36814645, + 0.11307441, + 0.04468023, + 0.10903471, + ], + [ + 0.09742457, + 0.12959763, + 0.09435383, + 0.21889204, + 0.15113123, + 0.10219457, + 0.20640612, + ], + [ + 0.45033529, + 0.09091417, + 0.15333208, + 0.07939558, + 0.08649316, + 0.12298585, + 0.01654384, + ], + [ + 0.02512238, + 0.22079203, + 0.19664364, + 0.11906379, + 0.07816055, + 0.22538587, + 0.13483174, + ], + [ + 0.17928453, + 0.06065261, + 0.41153005, + 0.1172041, + 0.11880313, + 0.07113197, + 0.04139363, + ], + [ + 0.15882358, + 0.1235788, + 0.23376776, + 0.20510435, + 0.00279306, + 0.05294827, + 0.22298418, + ], + ] + self.greedy_result = ["ac'bdc", "b'da"] + self.beam_search_result = ["acdc", "b'a", "a a"] + + def convert_to_string(self, tokens, vocab, seq_len): + return "".join([vocab[x] for x in tokens[0:seq_len]]) + + def test_beam_search_decoder_1(self): + probs_seq = torch.FloatTensor([self.probs_seq1]) + decoder = ctcdecode.CTCBeamDecoder( + self.vocab_list, + beam_width=self.beam_size, + blank_id=self.vocab_list.index("_"), + ) + beam_result, beam_scores, timesteps, out_seq_len = decoder.decode(probs_seq) + 
output_str = self.convert_to_string( + beam_result[0][0], self.vocab_list, out_seq_len[0][0] + ) + self.assertEqual(output_str, self.beam_search_result[0]) + + def test_beam_search_decoder_2(self): + probs_seq = torch.FloatTensor([self.probs_seq2]) + decoder = ctcdecode.CTCBeamDecoder( + self.vocab_list, + beam_width=self.beam_size, + blank_id=self.vocab_list.index("_"), + ) + beam_result, beam_scores, timesteps, out_seq_len = decoder.decode(probs_seq) + output_str = self.convert_to_string( + beam_result[0][0], self.vocab_list, out_seq_len[0][0] + ) + self.assertEqual(output_str, self.beam_search_result[1]) + + def test_beam_search_decoder_3(self): + lm_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test.arpa") + probs_seq = torch.FloatTensor([self.probs_seq2]) + + decoder = ctcdecode.CTCBeamDecoder( + self.vocab_list, + beam_width=self.beam_size, + blank_id=self.vocab_list.index("_"), + model_path=lm_path, + ) + beam_result, beam_scores, timesteps, out_seq_len = decoder.decode(probs_seq) + output_str = self.convert_to_string( + beam_result[0][0], self.vocab_list, out_seq_len[0][0] + ) + self.assertEqual(output_str, self.beam_search_result[2]) + + def test_beam_search_decoder_batch(self): + probs_seq = torch.FloatTensor([self.probs_seq1, self.probs_seq2]) + decoder = ctcdecode.CTCBeamDecoder( + self.vocab_list, + beam_width=self.beam_size, + blank_id=self.vocab_list.index("_"), + num_processes=24, + ) + beam_results, beam_scores, timesteps, out_seq_len = decoder.decode(probs_seq) + output_str1 = self.convert_to_string( + beam_results[0][0], self.vocab_list, out_seq_len[0][0] + ) + output_str2 = self.convert_to_string( + beam_results[1][0], self.vocab_list, out_seq_len[1][0] + ) + self.assertEqual(output_str1, self.beam_search_result[0]) + self.assertEqual(output_str2, self.beam_search_result[1]) + del decoder + + def test_beam_search_decoder_batch_log(self): + probs_seq = torch.FloatTensor([self.probs_seq1, self.probs_seq2]).log() + decoder = ctcdecode.CTCBeamDecoder( + self.vocab_list, + beam_width=self.beam_size, + blank_id=self.vocab_list.index("_"), + log_probs_input=True, + num_processes=24, + ) + beam_results, beam_scores, timesteps, out_seq_len = decoder.decode(probs_seq) + output_str1 = self.convert_to_string( + beam_results[0][0], self.vocab_list, out_seq_len[0][0] + ) + output_str2 = self.convert_to_string( + beam_results[1][0], self.vocab_list, out_seq_len[1][0] + ) + self.assertEqual(output_str1, self.beam_search_result[0]) + self.assertEqual(output_str2, self.beam_search_result[1]) + + def test_online_decoder_decoding(self): + lm_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test.arpa") + decoder = ctcdecode.OnlineCTCBeamDecoder( + self.vocab_list, + beam_width=self.beam_size, + blank_id=self.vocab_list.index("_"), + log_probs_input=True, + num_processes=24, + model_path=lm_path, + ) + state1 = ctcdecode.DecoderState(decoder) + state2 = ctcdecode.DecoderState(decoder) + + probs_seq = torch.FloatTensor([self.probs_seq2, self.probs_seq2]).log() + + is_eos_s = [True for _ in range(len(probs_seq))] + + beam_results, beam_scores, timesteps, out_seq_len = decoder.decode( + probs_seq, [state1, state2], is_eos_s + ) + output_str1 = self.convert_to_string( + beam_results[0][0], self.vocab_list, out_seq_len[0][0] + ) + output_str2 = self.convert_to_string( + beam_results[1][0], self.vocab_list, out_seq_len[1][0] + ) + + self.assertEqual(output_str1, self.beam_search_result[2]) + self.assertEqual(output_str2, self.beam_search_result[2]) + + def 
test_online_decoder_decoding_no_lm(self): + decoder = ctcdecode.OnlineCTCBeamDecoder( + self.vocab_list, + beam_width=self.beam_size, + blank_id=self.vocab_list.index("_"), + log_probs_input=True, + num_processes=24, + ) + state1 = ctcdecode.DecoderState(decoder) + state2 = ctcdecode.DecoderState(decoder) + + probs_seq = torch.FloatTensor([self.probs_seq1, self.probs_seq2]).log() + + is_eos_s = [True for _ in range(len(probs_seq))] + + beam_results, beam_scores, timesteps, out_seq_len = decoder.decode( + probs_seq, [state1, state2], is_eos_s + ) + output_str1 = self.convert_to_string( + beam_results[0][0], self.vocab_list, out_seq_len[0][0] + ) + output_str2 = self.convert_to_string( + beam_results[1][0], self.vocab_list, out_seq_len[1][0] + ) + + self.assertEqual(output_str1, self.beam_search_result[0]) + self.assertEqual(output_str2, self.beam_search_result[1]) + + def test_online_decoder_decoding_with_two_calls(self): + lm_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test.arpa") + decoder = ctcdecode.OnlineCTCBeamDecoder( + self.vocab_list, + beam_width=self.beam_size, + blank_id=self.vocab_list.index("_"), + log_probs_input=True, + num_processes=24, + model_path=lm_path, + ) + state1 = ctcdecode.DecoderState(decoder) + + probs_seq = torch.FloatTensor([self.probs_seq2]).log() + + beam_results, beam_scores, timesteps, out_seq_len = decoder.decode( + probs_seq[:, :2], [state1], [False] + ) + beam_results, beam_scores, timesteps, out_seq_len = decoder.decode( + probs_seq[:, 2:], [state1], [True] + ) + + output_str1 = self.convert_to_string( + beam_results[0][0], self.vocab_list, out_seq_len[0][0] + ) + self.assertEqual(output_str1, self.beam_search_result[2]) + + def test_online_decoder_decoding_with_two_calls_no_lm(self): + decoder = ctcdecode.OnlineCTCBeamDecoder( + self.vocab_list, + beam_width=self.beam_size, + blank_id=self.vocab_list.index("_"), + log_probs_input=True, + num_processes=24, + ) + state1 = ctcdecode.DecoderState(decoder) + state2 = ctcdecode.DecoderState(decoder) + + probs_seq = torch.FloatTensor([self.probs_seq1, self.probs_seq2]).log() + + beam_results, beam_scores, timesteps, out_seq_len = decoder.decode( + probs_seq[:, :2], [state1, state2], [False, False] + ) + beam_results, beam_scores, timesteps, out_seq_len = decoder.decode( + probs_seq[:, 2:], [state1, state2], [True, True] + ) + + del state1, state2 + size = beam_results.shape + output_str1 = self.convert_to_string( + beam_results[0][0], self.vocab_list, out_seq_len[0][0] + ) + output_str2 = self.convert_to_string( + beam_results[1][0], self.vocab_list, out_seq_len[1][0] + ) + + self.assertEqual(output_str1, self.beam_search_result[0]) + self.assertEqual(output_str2, self.beam_search_result[1]) + + def test_online_decoder_decoding_with_a_lot_calls_no_lm_check_size(self): + decoder = ctcdecode.OnlineCTCBeamDecoder( + self.vocab_list, + beam_width=self.beam_size, + blank_id=self.vocab_list.index("_"), + log_probs_input=True, + num_processes=24, + ) + state1 = ctcdecode.DecoderState(decoder) + + probs_seq = torch.FloatTensor([self.probs_seq1]).log() + + for i in range(1000): + beam_results, beam_scores, timesteps, out_seq_len = decoder.decode( + probs_seq, [state1], [False, False] + ) + + beam_results, beam_scores, timesteps, out_seq_len = decoder.decode( + probs_seq, [state1], [True, True] + ) + + del state1 + self.assertGreaterEqual(beam_results.shape[2], out_seq_len.max()) + + def test_hotwords(self): + SAMPLE_LABELS = [" ", "b", "g", "n", "s", "u", "y", ""] + SAMPLE_VOCAB = {c: n for n, c 
in enumerate(SAMPLE_LABELS)} + + BUGS_PROBS = torch.zeros((4, len(SAMPLE_VOCAB))) + BUGS_PROBS[0][SAMPLE_VOCAB.get("b")] = 1 + BUGS_PROBS[1][SAMPLE_VOCAB.get("u")] = 1 + BUGS_PROBS[2][SAMPLE_VOCAB.get("g")] = 1 + BUGS_PROBS[3][SAMPLE_VOCAB.get("s")] = 1 + + BUNNY_PROBS = torch.zeros((6, len(SAMPLE_VOCAB))) + BUNNY_PROBS[0][SAMPLE_VOCAB.get("b")] = 1 + BUNNY_PROBS[1][SAMPLE_VOCAB.get("u")] = 1 + BUNNY_PROBS[2][SAMPLE_VOCAB.get("n")] = 1 + BUNNY_PROBS[3][SAMPLE_VOCAB.get("")] = 1 + BUNNY_PROBS[4][SAMPLE_VOCAB.get("n")] = 1 + BUNNY_PROBS[5][SAMPLE_VOCAB.get("y")] = 1 + + BLANK_PROBS = torch.zeros((1, len(SAMPLE_VOCAB))) + BLANK_PROBS[0][SAMPLE_VOCAB.get("")] = 1 + SPACE_PROBS = torch.zeros((1, len(SAMPLE_VOCAB))) + SPACE_PROBS[0][SAMPLE_VOCAB.get(" ")] = 1 + + # make mixed version that can get fixed with LM + TEST_PROBS = torch.vstack( + [ + torch.vstack([BUGS_PROBS, BLANK_PROBS, BLANK_PROBS]) * 0.51 + + BUNNY_PROBS * 0.49, + SPACE_PROBS, + BUNNY_PROBS, + ] + ) + + # without lm and without hotwords + decoder = ctcdecode.CTCBeamDecoder(SAMPLE_LABELS, blank_id=7, beam_width=100) + beam_result, _, _, out_seq_len = decoder.decode( + torch.unsqueeze(TEST_PROBS, dim=0) + ) + output_str = self.convert_to_string( + beam_result[0][0], SAMPLE_LABELS, out_seq_len[0][0] + ) + self.assertEqual(output_str, "bugs bunny") + + # without lm and with hotwords + beam_result, _, _, out_seq_len = decoder.decode( + torch.unsqueeze(TEST_PROBS, dim=0), + hotwords=[["b", "u", "n", "n", "y"]], + hotword_weight=10, + ) + output_str = self.convert_to_string( + beam_result[0][0], SAMPLE_LABELS, out_seq_len[0][0] + ) + self.assertEqual(output_str, "bunny bunny") + + lm_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test.arpa") + + # with lm and with hotwords + lm_decoder = ctcdecode.CTCBeamDecoder( + SAMPLE_LABELS, model_path=lm_path, blank_id=7, beam_width=100 + ) + + beam_result, _, _, out_seq_len = lm_decoder.decode( + torch.unsqueeze(TEST_PROBS, dim=0), + hotwords=[["b", "u", "n", "n", "y"]], + hotword_weight=10, + ) + output_str = self.convert_to_string( + beam_result[0][0], SAMPLE_LABELS, out_seq_len[0][0] + ) + self.assertEqual(output_str, "bunny bunny") + + TEST_PROBS = torch.vstack( + [ + torch.vstack([BUGS_PROBS, BLANK_PROBS, BLANK_PROBS]) * 0.51 + + BUNNY_PROBS * 0.49, + SPACE_PROBS, + BUNNY_PROBS, + ] + ) + + beam_result, _, _, out_seq_len = lm_decoder.decode( + torch.unsqueeze(TEST_PROBS, dim=0), + hotwords=[["b", "u", "g", "s"]], + hotword_weight=10, + ) + output_str = self.convert_to_string( + beam_result[0][0], SAMPLE_LABELS, out_seq_len[0][0] + ) + self.assertEqual(output_str, "bugs bunny") + + # hotword as a phrase + beam_result, _, _, out_seq_len = decoder.decode( + torch.unsqueeze(TEST_PROBS, dim=0), + hotwords=[list("bunny bunny")], + hotword_weight=10, + ) + output_str = self.convert_to_string( + beam_result[0][0], SAMPLE_LABELS, out_seq_len[0][0] + ) + self.assertEqual(output_str, "bunny bunny") + + def test_hotwords_with_small_input(self): + probs = [ + [0.1, 0.2, 0.2, 0.1], + [0.4, 0.4, 0.1, 0.3], + [0.2, 0.3, 0.1, 0.4], + [0.3, 0.1, 0.6, 0.2], + ] + labels = ["_", "a", "b", " "] + probs = torch.Tensor(probs) + probs = torch.unsqueeze(probs.transpose(0, 1), dim=0) + + # without hotword + decoder = ctcdecode.CTCBeamDecoder(labels, blank_id=0, beam_width=100) + beam_result, _, _, out_seq_len = decoder.decode(probs) + output_str = self.convert_to_string( + beam_result[0][0], labels, out_seq_len[0][0] + ) + self.assertEqual(output_str, "a b") + + # with hotword a + beam_result, _, _, 
out_seq_len = decoder.decode( + probs, hotwords=[["a"]], hotword_weight=10 + ) + output_str = self.convert_to_string( + beam_result[0][0], labels, out_seq_len[0][0] + ) + self.assertEqual(output_str, "a a") + + # with hotword b + beam_result, _, _, out_seq_len = decoder.decode( + probs, hotwords=[["b"]], hotword_weight=10 + ) + output_str = self.convert_to_string( + beam_result[0][0], labels, out_seq_len[0][0] + ) + self.assertEqual(output_str, "b b") + + # with hotword "b b" + beam_result, _, _, out_seq_len = decoder.decode( + probs, hotwords=[list("b b")], hotword_weight=10 + ) + output_str = self.convert_to_string( + beam_result[0][0], labels, out_seq_len[0][0] + ) + self.assertEqual(output_str, "b b") + + # test for passing hotword scorer to decoder call with hotword "b b" + hotword_scorer = decoder.create_hotword_scorer( + hotwords=[list("b b")], hotword_weight=10 + ) + beam_result, _, _, out_seq_len = decoder.decode( + probs, + hotword_scorer=hotword_scorer, + ) + output_str = self.convert_to_string( + beam_result[0][0], labels, out_seq_len[0][0] + ) + decoder.delete_hotword_scorer(hotword_scorer) + self.assertEqual(output_str, "b b") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_decode.py b/tests/test_decode.py deleted file mode 100644 index 2dcb83f1..00000000 --- a/tests/test_decode.py +++ /dev/null @@ -1,215 +0,0 @@ -"""Test decoders.""" -from __future__ import absolute_import, division, print_function - -import os -import unittest - -import ctcdecode -import torch - - -class TestDecoders(unittest.TestCase): - def setUp(self): - self.vocab_list = ["'", " ", "a", "b", "c", "d", "_"] - self.beam_size = 20 - self.probs_seq1 = [ - [0.06390443, 0.21124858, 0.27323887, 0.06870235, 0.0361254, 0.18184413, 0.16493624], - [0.03309247, 0.22866108, 0.24390638, 0.09699597, 0.31895462, 0.0094893, 0.06890021], - [0.218104, 0.19992557, 0.18245131, 0.08503348, 0.14903535, 0.08424043, 0.08120984], - [0.12094152, 0.19162472, 0.01473646, 0.28045061, 0.24246305, 0.05206269, 0.09772094], - [0.1333387, 0.00550838, 0.00301669, 0.21745861, 0.20803985, 0.41317442, 0.01946335], - [0.16468227, 0.1980699, 0.1906545, 0.18963251, 0.19860937, 0.04377724, 0.01457421], - ] - self.probs_seq2 = [ - [0.08034842, 0.22671944, 0.05799633, 0.36814645, 0.11307441, 0.04468023, 0.10903471], - [0.09742457, 0.12959763, 0.09435383, 0.21889204, 0.15113123, 0.10219457, 0.20640612], - [0.45033529, 0.09091417, 0.15333208, 0.07939558, 0.08649316, 0.12298585, 0.01654384], - [0.02512238, 0.22079203, 0.19664364, 0.11906379, 0.07816055, 0.22538587, 0.13483174], - [0.17928453, 0.06065261, 0.41153005, 0.1172041, 0.11880313, 0.07113197, 0.04139363], - [0.15882358, 0.1235788, 0.23376776, 0.20510435, 0.00279306, 0.05294827, 0.22298418], - ] - self.greedy_result = ["ac'bdc", "b'da"] - self.beam_search_result = ["acdc", "b'a", "a a"] - - def convert_to_string(self, tokens, vocab, seq_len): - return "".join([vocab[x] for x in tokens[0:seq_len]]) - - def test_beam_search_decoder_1(self): - probs_seq = torch.FloatTensor([self.probs_seq1]) - decoder = ctcdecode.CTCBeamDecoder( - self.vocab_list, beam_width=self.beam_size, blank_id=self.vocab_list.index("_") - ) - beam_result, beam_scores, timesteps, out_seq_len = decoder.decode(probs_seq) - output_str = self.convert_to_string(beam_result[0][0], self.vocab_list, out_seq_len[0][0]) - self.assertEqual(output_str, self.beam_search_result[0]) - - def test_beam_search_decoder_2(self): - probs_seq = torch.FloatTensor([self.probs_seq2]) - decoder = ctcdecode.CTCBeamDecoder( - 
self.vocab_list, beam_width=self.beam_size, blank_id=self.vocab_list.index("_") - ) - beam_result, beam_scores, timesteps, out_seq_len = decoder.decode(probs_seq) - output_str = self.convert_to_string(beam_result[0][0], self.vocab_list, out_seq_len[0][0]) - self.assertEqual(output_str, self.beam_search_result[1]) - - def test_beam_search_decoder_3(self): - lm_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test.arpa") - probs_seq = torch.FloatTensor([self.probs_seq2]) - - decoder = ctcdecode.CTCBeamDecoder( - self.vocab_list, beam_width=self.beam_size, blank_id=self.vocab_list.index("_"), model_path=lm_path - ) - beam_result, beam_scores, timesteps, out_seq_len = decoder.decode(probs_seq) - output_str = self.convert_to_string(beam_result[0][0], self.vocab_list, out_seq_len[0][0]) - self.assertEqual(output_str, self.beam_search_result[2]) - - def test_beam_search_decoder_batch(self): - probs_seq = torch.FloatTensor([self.probs_seq1, self.probs_seq2]) - decoder = ctcdecode.CTCBeamDecoder( - self.vocab_list, beam_width=self.beam_size, blank_id=self.vocab_list.index("_"), num_processes=24 - ) - beam_results, beam_scores, timesteps, out_seq_len = decoder.decode(probs_seq) - output_str1 = self.convert_to_string(beam_results[0][0], self.vocab_list, out_seq_len[0][0]) - output_str2 = self.convert_to_string(beam_results[1][0], self.vocab_list, out_seq_len[1][0]) - self.assertEqual(output_str1, self.beam_search_result[0]) - self.assertEqual(output_str2, self.beam_search_result[1]) - del decoder - - def test_beam_search_decoder_batch_log(self): - probs_seq = torch.FloatTensor([self.probs_seq1, self.probs_seq2]).log() - decoder = ctcdecode.CTCBeamDecoder( - self.vocab_list, - beam_width=self.beam_size, - blank_id=self.vocab_list.index("_"), - log_probs_input=True, - num_processes=24, - ) - beam_results, beam_scores, timesteps, out_seq_len = decoder.decode(probs_seq) - output_str1 = self.convert_to_string(beam_results[0][0], self.vocab_list, out_seq_len[0][0]) - output_str2 = self.convert_to_string(beam_results[1][0], self.vocab_list, out_seq_len[1][0]) - self.assertEqual(output_str1, self.beam_search_result[0]) - self.assertEqual(output_str2, self.beam_search_result[1]) - - def test_online_decoder_decoding(self): - lm_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test.arpa") - decoder = ctcdecode.OnlineCTCBeamDecoder( - self.vocab_list, - beam_width=self.beam_size, - blank_id=self.vocab_list.index("_"), - log_probs_input=True, - num_processes=24, - model_path=lm_path, - ) - state1 = ctcdecode.DecoderState(decoder) - state2 = ctcdecode.DecoderState(decoder) - - probs_seq = torch.FloatTensor([self.probs_seq2, self.probs_seq2]).log() - - is_eos_s = [True for _ in range(len(probs_seq))] - - beam_results, beam_scores, timesteps, out_seq_len = decoder.decode(probs_seq, [state1, state2], is_eos_s) - output_str1 = self.convert_to_string(beam_results[0][0], self.vocab_list, out_seq_len[0][0]) - output_str2 = self.convert_to_string(beam_results[1][0], self.vocab_list, out_seq_len[1][0]) - - self.assertEqual(output_str1, self.beam_search_result[2]) - self.assertEqual(output_str2, self.beam_search_result[2]) - - def test_online_decoder_decoding_no_lm(self): - decoder = ctcdecode.OnlineCTCBeamDecoder( - self.vocab_list, - beam_width=self.beam_size, - blank_id=self.vocab_list.index("_"), - log_probs_input=True, - num_processes=24, - ) - state1 = ctcdecode.DecoderState(decoder) - state2 = ctcdecode.DecoderState(decoder) - - probs_seq = torch.FloatTensor([self.probs_seq1, 
self.probs_seq2]).log() - - is_eos_s = [True for _ in range(len(probs_seq))] - - beam_results, beam_scores, timesteps, out_seq_len = decoder.decode( - probs_seq, [state1, state2], is_eos_s - ) - output_str1 = self.convert_to_string(beam_results[0][0], self.vocab_list, out_seq_len[0][0]) - output_str2 = self.convert_to_string(beam_results[1][0], self.vocab_list, out_seq_len[1][0]) - - self.assertEqual(output_str1, self.beam_search_result[0]) - self.assertEqual(output_str2, self.beam_search_result[1]) - - def test_online_decoder_decoding_with_two_calls(self): - lm_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test.arpa") - decoder = ctcdecode.OnlineCTCBeamDecoder( - self.vocab_list, - beam_width=self.beam_size, - blank_id=self.vocab_list.index("_"), - log_probs_input=True, - num_processes=24, - model_path=lm_path, - ) - state1 = ctcdecode.DecoderState(decoder) - - probs_seq = torch.FloatTensor([self.probs_seq2]).log() - - beam_results, beam_scores, timesteps, out_seq_len = decoder.decode(probs_seq[:, :2], [state1], [False]) - beam_results, beam_scores, timesteps, out_seq_len = decoder.decode(probs_seq[:, 2:], [state1], [True]) - - output_str1 = self.convert_to_string(beam_results[0][0], self.vocab_list, out_seq_len[0][0]) - self.assertEqual(output_str1, self.beam_search_result[2]) - - def test_online_decoder_decoding_with_two_calls_no_lm(self): - decoder = ctcdecode.OnlineCTCBeamDecoder( - self.vocab_list, - beam_width=self.beam_size, - blank_id=self.vocab_list.index("_"), - log_probs_input=True, - num_processes=24, - ) - state1 = ctcdecode.DecoderState(decoder) - state2 = ctcdecode.DecoderState(decoder) - - probs_seq = torch.FloatTensor([self.probs_seq1, self.probs_seq2]).log() - - beam_results, beam_scores, timesteps, out_seq_len = decoder.decode( - probs_seq[:, :2], [state1, state2], [False, False] - ) - beam_results, beam_scores, timesteps, out_seq_len = decoder.decode( - probs_seq[:, 2:], [state1, state2], [True, True] - ) - - del state1, state2 - size = beam_results.shape - output_str1 = self.convert_to_string(beam_results[0][0], self.vocab_list, out_seq_len[0][0]) - output_str2 = self.convert_to_string(beam_results[1][0], self.vocab_list, out_seq_len[1][0]) - - self.assertEqual(output_str1, self.beam_search_result[0]) - self.assertEqual(output_str2, self.beam_search_result[1]) - - def test_online_decoder_decoding_with_a_lot_calls_no_lm_check_size(self): - decoder = ctcdecode.OnlineCTCBeamDecoder( - self.vocab_list, - beam_width=self.beam_size, - blank_id=self.vocab_list.index("_"), - log_probs_input=True, - num_processes=24, - ) - state1 = ctcdecode.DecoderState(decoder) - - probs_seq = torch.FloatTensor([self.probs_seq1]).log() - - for i in range(1000): - beam_results, beam_scores, timesteps, out_seq_len = decoder.decode( - probs_seq, [state1], [False, False] - ) - - beam_results, beam_scores, timesteps, out_seq_len = decoder.decode( - probs_seq, [state1], [True, True] - ) - - del state1 - self.assertGreaterEqual(beam_results.shape[2], out_seq_len.max()) - - -if __name__ == "__main__": - unittest.main() diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt new file mode 100644 index 00000000..42551b0c --- /dev/null +++ b/third_party/CMakeLists.txt @@ -0,0 +1,15 @@ +cmake_minimum_required(VERSION 3.16 FATAL_ERROR) + +add_compile_options("-g" "-O3" "-DKENLM_MAX_ORDER=6" "-std=c++17" "-fPIC" "-DINCLUDE_KENLM") + +# kenlm sources +file(GLOB KENLM_UTIL_LIB_SOURCES ${CMAKE_SOURCE_DIR}/third_party/kenlm/util/*.cc) +file(GLOB KENLM_LM_LIB_SOURCES 
${CMAKE_SOURCE_DIR}/third_party/kenlm/lm/*.cc) +file(GLOB KENLM_DOUBLE_CONV_LIB_SOURCES ${CMAKE_SOURCE_DIR}/third_party/kenlm/util/double-conversion/*.cc) +add_library(kenlm STATIC "${KENLM_UTIL_LIB_SOURCES}" "${KENLM_LM_LIB_SOURCES}" "${KENLM_DOUBLE_CONV_LIB_SOURCES}") +target_include_directories(kenlm PUBLIC ${CMAKE_SOURCE_DIR}/third_party/kenlm ${CMAKE_SOURCE_DIR}/third_party/boost_1_67_0) + +# openfst sources +file(GLOB FST_SOURCES ${CMAKE_SOURCE_DIR}/third_party/openfst-1.8.2/src/lib/*.cc) +add_library(fst STATIC "${FST_SOURCES}") +target_include_directories(fst PUBLIC ${CMAKE_SOURCE_DIR}/third_party/openfst-1.8.2/src/include ${CMAKE_SOURCE_DIR}/third_party/openfst-1.8.2/src/include/fst/script) \ No newline at end of file diff --git a/third_party/kenlm b/third_party/kenlm index 35835f1a..35f14583 160000 --- a/third_party/kenlm +++ b/third_party/kenlm @@ -1 +1 @@ -Subproject commit 35835f1ac4884126458ac89f9bf6dd9ccad561e0 +Subproject commit 35f145839eca742f2402716d17542fd0546efc9d diff --git a/tools/README.md b/tools/README.md new file mode 100644 index 00000000..5f1fcf03 --- /dev/null +++ b/tools/README.md @@ -0,0 +1,21 @@ +### Build FST tool + +This is a C++ program that builds an FST from the given lexicon files. + +Run the steps below to create the build and construct the FST. + +```bash + +cd .. +bash build.sh +./build/build_fst --vocab-path <vocab_file> --lexicon-paths <lexicon_file_1> <lexicon_file_2> --output-path <output_file> --freq-threshold 30 [Optional] --fst-path <fst_file> [Optional] + +``` + +- Vocab path - Each line in the file should contain a single label. An example file is provided in the `tests/cpp/fixtures` directory +- Lexicon paths - Path to one or more lexicon files. Each line in the file should look like this: `<frequency> <word> <token_1> <token_2> ...` +- Frequency threshold - Only words with a frequency greater than or equal to this threshold are considered while constructing the FST. (Default is -1, i.e. all words are considered) +- Fst path - If an FST file is provided, the given lexicon words will be added on top of this FST file. +- Output path - Path to the output file. Two output files will be generated: the one with the `.opt` extension contains the optimized FST and the other contains the unoptimized FST. + +For more information, run `./build/build_fst --help` \ No newline at end of file diff --git a/tools/build_fst.cpp b/tools/build_fst.cpp new file mode 100644 index 00000000..6ed58816 --- /dev/null +++ b/tools/build_fst.cpp @@ -0,0 +1,292 @@ +#include "build_fst.h" + +#include <cassert> +#include <chrono> +#include <cstdlib> +#include <fstream> +#include <iostream> +#include <sstream> + +/** + * @brief This method parses the labels file and returns a vector of labels + * + * @param vocab_path, Path to the file containing bpe labels. Each line in the file should contain a + * label. + * @return labels, A vector of labels + */ +std::vector<std::string> get_bpe_vocab(const std::string vocab_path) +{ + std::ifstream inputFile(vocab_path); + std::vector<std::string> labels; + if (!inputFile.is_open()) { + std::cerr << "Failed to open the vocabulary file."
<< std::endl; + // Abort: the character map cannot be built without labels + std::exit(1); + } + + std::string line; + while (std::getline(inputFile, line)) { + labels.push_back(line); + } + + std::cout << "Size of labels: " << labels.size() << std::endl; + + inputFile.close(); + + return labels; +} + +/** + * @brief Returns the character map for the given labels + * + * @param labels, A vector of labels + * @return char_map, A map of characters/tokens to their corresponding integer ids starting from 1 + */ +std::unordered_map<std::string, int> get_char_map(const std::vector<std::string>& labels) +{ + std::unordered_map<std::string, int> char_map; + for (size_t i = 0; i < labels.size(); ++i) { + char_map[labels[i]] = static_cast<int>(i) + 1; + } + + return char_map; +} + +/** + * @brief This method reads the FST from the given file + * + * @param input_path, The path to the file containing the FST + * @return dictionary, The FST read from the file + */ +fst::StdVectorFst* read_fst(const std::string input_path) +{ + auto startTime = std::chrono::high_resolution_clock::now(); + // Read the FST from the file + fst::StdVectorFst* dict = fst::StdVectorFst::Read(input_path); + if (!dict) { + std::cerr << "Failed to read FST from file: " << input_path << std::endl; + // Abort: a null FST would be dereferenced below + std::exit(1); + } + + auto endTime = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast<std::chrono::seconds>(endTime - startTime).count(); + std::cout << "Time taken for reading the FST: " << duration << " seconds" << std::endl; + std::cout << "Number of states in the FST: " << dict->NumStates() << std::endl; + return dict; +} + +/** + * @brief This method writes the given FST to the given file + * + * @param dictionary, The FST to be written + * @param output_path, The path to the file to which the FST is to be written + */ +void write_fst(fst::StdVectorFst* dictionary, const std::string output_path) +{ + auto startTime = std::chrono::high_resolution_clock::now(); + dictionary->Write(output_path); + auto endTime = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast<std::chrono::seconds>(endTime - startTime).count(); + std::cout << "Time taken for writing the FST to file: " << duration << " seconds" << std::endl; +} + +/** + * @brief This method optimizes the FST by removing the epsilon transitions and minimizing it + * + * @param dictionary, The FST to be optimized + */ +void optimize_fst(fst::StdVectorFst* dictionary) +{ + auto startTime = std::chrono::high_resolution_clock::now(); + + fst::RmEpsilon(dictionary); + fst::Determinize(*dictionary, dictionary); + fst::Minimize(dictionary); + + auto endTime = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast<std::chrono::seconds>(endTime - startTime).count(); + std::cout << "Time taken to optimize FST: " << duration << " seconds" << std::endl; +} + +/** + * @brief This method adds a word to the given FST + * + * @param characters, A vector of characters/tokens in the word + * @param char_map, A map of characters/tokens to their corresponding integer ids + * @param dictionary, The FST to which the word is to be added + * @param current_state, The current state in the FST from which the word is to be added + * @return is_word_added, Returns true if the word is added newly, else false + */ +bool add_word_to_fst(const std::vector<std::string>& characters, + const std::unordered_map<std::string, int>& char_map, + fst::StdVectorFst* dictionary, + fst::StdVectorFst::StateId current_state) +{ + bool is_word_added = false; + for (auto& c : characters) { + // Find the symbol ID for the character + auto int_c = char_map.find(c); + if (int_c == char_map.end()) { + // Skip tokens missing from the vocabulary instead of reading an uninitialized id + std::cerr << "Character/token not found: " << c << std::endl; + continue; + } + int symbol_id = int_c->second;
+ + // Check if the arc already exists + bool arc_exists = false; + for (fst::ArcIterator<fst::StdVectorFst> aiter(*dictionary, current_state); !aiter.Done(); + aiter.Next()) { + const fst::StdArc& arc = aiter.Value(); + if (arc.ilabel == symbol_id) { + arc_exists = true; + current_state = arc.nextstate; + break; + } + } + if (!arc_exists) { + // Add a new arc + fst::StdVectorFst::StateId next_state = dictionary->AddState(); + dictionary->AddArc(current_state, fst::StdArc(symbol_id, symbol_id, 0, next_state)); + current_state = next_state; + is_word_added = true; + } + } + + // Set the current state as final + dictionary->SetFinal(current_state, fst::StdArc::Weight::One()); + + return is_word_added; +} + +/** + * @brief This method parses the lexicon file and adds the words in it to the given FST + * + * @param lexicon_path, The path to the lexicon file + * @param dictionary, The FST to which the words are to be added + * @param char_map, A map of characters/tokens to their corresponding integer ids + * @param freq_threshold, Frequency threshold; words having frequency greater than or equal to this + * threshold will be considered + * @return word_count, The number of words added to the FST + */ +int parse_lexicon_and_add_to_fst(const std::string& lexicon_path, + fst::StdVectorFst* dictionary, + const std::unordered_map<std::string, int>& char_map, + const int freq_threshold) +{ + int word_count = 0; + std::ifstream file(lexicon_path); + if (!file) { + std::cerr << "Error opening lexicon file: " << lexicon_path << std::endl; + // Abort: nothing can be added from an unreadable lexicon + std::exit(1); + } + + std::cout << "Loading words from the lexicon path provided\n"; + + std::string line; + fst::StdVectorFst::StateId start_state; + + if (dictionary->NumStates() == 0) { + std::cout << "Setting dictionary start state\n"; + start_state = dictionary->AddState(); + assert(start_state == 0); + dictionary->SetStart(start_state); + } + start_state = dictionary->Start(); + int count = 0; + + while (std::getline(file, line)) { + std::istringstream iss(line); + std::string token; + std::vector<std::string> characters; + count += 1; + int i = 0; + bool skip = false; + while (std::getline(iss, token, ' ')) { + // Field 0 is the frequency, field 1 the word itself, the remaining fields its tokens + if (i == 0 && freq_threshold != -1) { + int freq = std::stoi(token); + if (freq < freq_threshold) { + skip = true; + } + } + if (i != 0 && i != 1 && !skip) { + characters.push_back(token); + } + ++i; + } + + if (count % 100000 == 0) { + std::cout << "Processed " << count << " records\n"; + } + + if (characters.size() > 0 && !skip) { + word_count += add_word_to_fst(characters, char_map, dictionary, start_state); + } + } + std::cout << "Constructed the FST for the given lexicon path: " << lexicon_path << std::endl; + std::cout << "Number of words in the given path: " << count << std::endl; + + return word_count; +} + +/** + * @brief This method constructs the FST from the given lexicon files + * + * @param vocab_path, Path to the file containing labels + * @param lexicon_paths, A vector of paths to lexicon files + * @param fst_path, The path to the FST file.
If empty, a new FST will be created; + * otherwise the words will be added on top of this FST + * @param output_path, The path to the file to which the FST is to be written + * @param freq_threshold, Frequency threshold; words having frequency greater than or equal to this + * threshold will be considered (Default = -1, i.e. all are considered) + * @param optimize, If true, the FST will be optimized (Default = true; two output files will be + * generated in this case, one optimized and the other unoptimized) + */ +void construct_fst(const std::string vocab_path, + const std::vector<std::string>& lexicon_paths, + const std::string fst_path, + std::string output_path, + const int freq_threshold = -1, + bool optimize = true) +{ + // Load vocabulary + std::vector<std::string> labels = get_bpe_vocab(vocab_path); + // Get the character map + std::unordered_map<std::string, int> char_map = get_char_map(labels); + + fst::StdVectorFst* dictionary; + + // Load the FST from the given file + if (!fst_path.empty()) { + std::cout << "Reading the FST from " << fst_path << std::endl; + dictionary = read_fst(fst_path); + } else { + dictionary = new fst::StdVectorFst; + } + + auto startTime = std::chrono::high_resolution_clock::now(); + + // Parse each lexicon file and add the words in it to the FST + for (const auto& lexicon_path : lexicon_paths) { + int word_count + = parse_lexicon_and_add_to_fst(lexicon_path, dictionary, char_map, freq_threshold); + std::cout << "Number of words added to the dictionary: " << word_count << std::endl; + } + + auto endTime = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast<std::chrono::seconds>(endTime - startTime).count(); + std::cout << "Time taken to create FST: " << duration << " seconds" << std::endl; + + // Write the FST to the given output file; + // an output file with the `.opt` extension will also be created if optimize is true + write_fst(dictionary, output_path); + if (optimize) { + optimize_fst(dictionary); + write_fst(dictionary, output_path + ".opt"); + } + + // Number of states in FST + std::cout << "Number of states in the FST: " << dictionary->NumStates() << std::endl; + + endTime = std::chrono::high_resolution_clock::now(); + duration = std::chrono::duration_cast<std::chrono::seconds>(endTime - startTime).count(); + std::cout << "Total time taken: " << duration << " seconds" << std::endl; + + // Delete the FST + delete dictionary; +} \ No newline at end of file diff --git a/tools/build_fst.h b/tools/build_fst.h new file mode 100644 index 00000000..43ece64b --- /dev/null +++ b/tools/build_fst.h @@ -0,0 +1,34 @@ +#ifndef BUILD_FST_H +#define BUILD_FST_H + +#include <string> +#include <unordered_map> +#include <vector> + +#include "fst/fst.h" +#include "fst/fstlib.h" + +std::vector<std::string> get_bpe_vocab(const std::string vocab_path); + +std::unordered_map<std::string, int> get_char_map(const std::vector<std::string>& labels); + +fst::StdVectorFst* read_fst(const std::string input_path); + +void write_fst(fst::StdVectorFst* dictionary, const std::string output_path); + +void optimize_fst(fst::StdVectorFst* dictionary); + +bool add_word_to_fst(const std::vector<std::string>& characters, + const std::unordered_map<std::string, int>& char_map, + fst::StdVectorFst* dictionary, + fst::StdVectorFst::StateId current_state); + +int parse_lexicon_and_add_to_fst(const std::string& lexicon_path, + fst::StdVectorFst* dictionary, + const std::unordered_map<std::string, int>& char_map, + const int freq_threshold); + +void construct_fst(const std::string vocab_path, + const std::vector<std::string>& lexicon_paths, + const std::string fst_path, + std::string output_path, + const int freq_threshold, + bool optimize); + +#endif diff --git a/tools/build_fst_main.cpp
b/tools/build_fst_main.cpp new file mode 100644 index 00000000..8c4e93d7 --- /dev/null +++ b/tools/build_fst_main.cpp @@ -0,0 +1,61 @@ +#include <cstdlib> +#include <iostream> + +#include <cxxopts.hpp> + +#include "build_fst.h" + +int main(int argc, char* argv[]) +{ + cxxopts::Options options("build_fst", "A program to build an FST from the given lexicon files"); + + std::vector<std::string> lexicon_paths; + + options.add_options()( + "v,vocab-path", "Path to the file containing labels", cxxopts::value<std::string>())( + "i,lexicon-paths", + "Path to lexicon files. Multiple paths can be provided. Each line in the file should " + "look like this: <frequency> <word> <token_1> <token_2> ...", + cxxopts::value<std::vector<std::string>>(lexicon_paths))( + "o,output-path", + "Path to the output file. Both optimized and unoptimized files get created; the " + "optimized file has a `.opt` extension", + cxxopts::value<std::string>())( + "freq-threshold", + "Words having frequency greater than or equal to this threshold will be considered " + "(Default = -1, i.e. all are considered)", + cxxopts::value<int>()->default_value("-1"))( + "fst-path", + "Path to an FST file. Default is empty. If provided, the words will be added on top of " + "this FST. (NOTE: An unoptimized FST file needs to be provided in this case, otherwise " + "the generated FST will not be proper)", + cxxopts::value<std::string>()->default_value(""))("h,help", "Print usage"); + + options.parse_positional({ "vocab-path", "lexicon-paths" }); + + auto result = options.parse(argc, argv); + + if (result.count("help")) { + std::cout << options.help() << std::endl; + exit(0); + } + + if (!lexicon_paths.empty()) { + std::cout << lexicon_paths.size() << " lexicon paths are provided:" << std::endl; + for (const auto& path : lexicon_paths) { + std::cout << path << std::endl; + } + } else { + std::cout << "No lexicon paths are provided. Exiting" << std::endl; + exit(0); + } + + std::string vocab_path = result["vocab-path"].as<std::string>(); + std::string output_path = result["output-path"].as<std::string>(); + std::string fst_path = result["fst-path"].as<std::string>(); + int freq_threshold = result["freq-threshold"].as<int>(); + + std::cout << "Freq threshold: " << freq_threshold << std::endl; + + construct_fst(vocab_path, lexicon_paths, fst_path, output_path, freq_threshold, true); + + return 0; +}
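For reference, here is a minimal end-to-end sketch of driving the tool, assuming `build.sh` has already produced `build/build_fst`. The file names and contents below are hypothetical, and the lexicon field order (frequency, then the word, then its tokens) is inferred from `parse_lexicon_and_add_to_fst` above.

```bash
# Hypothetical vocabulary file: one label per line
# (see tests/cpp/fixtures for a real example)
cat > vocab.txt <<'EOF'
a
b
c
EOF

# Hypothetical lexicon file: <frequency> <word> <token_1> <token_2> ...
cat > lexicon.txt <<'EOF'
42 cab c a b
7 ba b a
EOF

# Build the FST; with --freq-threshold 30, "ba" (frequency 7) is skipped
./build/build_fst --vocab-path vocab.txt --lexicon-paths lexicon.txt \
    --output-path words.fst --freq-threshold 30

# Two files are written: words.fst (unoptimized) and words.fst.opt (optimized)
```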