diff --git a/TensorVox.pro b/TensorVox.pro index dd6ec15..710baaa 100644 --- a/TensorVox.pro +++ b/TensorVox.pro @@ -25,12 +25,16 @@ SOURCES += \ VoxCommon.cpp \ attention.cpp \ batchdenoisedlg.cpp \ + bert.cpp \ + berttokenizer.cpp \ + devits.cpp \ espeakphonemizer.cpp \ ext/ByteArr.cpp \ ext/Qt-Frameless-Window-DarkStyle-master/DarkStyle.cpp \ ext/Qt-Frameless-Window-DarkStyle-master/framelesswindow/framelesswindow.cpp \ ext/Qt-Frameless-Window-DarkStyle-master/framelesswindow/windowdragger.cpp \ ext/ZCharScanner.cpp \ + ext/ZCharScannerWide.cpp \ ext/ZFile.cpp \ ext/qcustomplot.cpp \ istftnettorch.cpp \ @@ -61,6 +65,9 @@ HEADERS += \ VoxCommon.hpp \ attention.h \ batchdenoisedlg.h \ + bert.h \ + berttokenizer.h \ + devits.h \ espeakphonemizer.h \ ext/AudioFile.hpp \ ext/ByteArr.h \ @@ -76,6 +83,7 @@ HEADERS += \ ext/Qt-Frameless-Window-DarkStyle-master/framelesswindow/framelesswindow.h \ ext/Qt-Frameless-Window-DarkStyle-master/framelesswindow/windowdragger.h \ ext/ZCharScanner.h \ + ext/ZCharScannerWide.h \ ext/ZFile.h \ ext/json.hpp \ ext/qcustomplot.h \ @@ -115,7 +123,7 @@ DEFINES += _CRT_SECURE_NO_WARNINGS INCLUDEPATH += $$PWD/deps/include INCLUDEPATH += $$PWD/deps/include/libtorch INCLUDEPATH += $$PWD/ext/Qt-Frameless-Window-DarkStyle-master/framelesswindow -win32: LIBS += -L$$PWD/deps/lib/ tensorflow.lib r8bsrc64.lib rnnoise64.lib LogitechLEDLib.lib LibNumberText64.lib c10.lib torch.lib torch_cpu.lib libespeak-ng.lib +win32: LIBS += -L$$PWD/deps/lib/ tensorflow.lib r8bsrc64.lib rnnoise64.lib LogitechLEDLib.lib LibNumberText64.lib c10.lib torch.lib torch_cpu.lib libespeak-ng.lib Utf8Proc.lib win32: LIBS += Advapi32.lib User32.lib Psapi.lib diff --git a/Voice.cpp b/Voice.cpp index 425462b..37125af 100644 --- a/Voice.cpp +++ b/Voice.cpp @@ -120,14 +120,16 @@ Voice::Voice(const std::string & VoxPath, const std::string &inName, Phonemizer const int32_t Tex2MelArch = VoxInfo.Architecture.Text2Mel; - const bool IsVITS = Tex2MelArch == EText2MelModel::VITS || 
Tex2MelArch == EText2MelModel::VITSTM; + const bool IsVITS = Tex2MelArch == EText2MelModel::VITS || Tex2MelArch == EText2MelModel::DEVITS; if (Tex2MelArch == EText2MelModel::Tacotron2) MelPredictor = std::make_unique(); else if (Tex2MelArch == EText2MelModel::FastSpeech2) MelPredictor = std::make_unique(); - else if (Tex2MelArch == EText2MelModel::VITS || Tex2MelArch == EText2MelModel::VITSTM) + else if (Tex2MelArch == EText2MelModel::VITS) MelPredictor = std::make_unique(); + else if (Tex2MelArch == EText2MelModel::DEVITS) + MelPredictor = std::make_unique(); else MelPredictor = std::make_unique(); @@ -144,8 +146,12 @@ Voice::Voice(const std::string & VoxPath, const std::string &inName, Phonemizer - if (Tex2MelArch == EText2MelModel::VITSTM) + if (Tex2MelArch == EText2MelModel::DEVITS){ Moji.Initialize(VoxPath + "/moji.pt", VoxPath + "/tm_dict.txt"); + BertFE.Initialize(VoxPath + "/bert.pt", VoxPath + "/bert_vocab.txt"); + + } + const int32_t VocoderArch = VoxInfo.Architecture.Vocoder; @@ -286,20 +292,36 @@ VoxResults Voice::Vocalize(const std::string & Prompt, float Speed, int32_t Spea Mel = ((FastSpeech2*)MelPredictor.get())->DoInference(InputIDs,FloatArgs,IntArgs,SpeakerID, EmotionID); - }else + }else if (Text2MelN == EText2MelModel::VITS) { FloatArgs = {Speed}; - if (EmotionOvr.size()){ - std::vector MojiInput = Processor.GetTokenizer().Tokenize(EmotionOvr,true,true); - std::vector MojiStates = Moji.Infer(MojiInput); + TFTensor Audio = MelPredictor.get()->DoInference(InputIDs,FloatArgs,IntArgs,SpeakerID,EmotionID); + Attention = ((VITS*)MelPredictor.get())->Attention; - FloatArgs.insert(FloatArgs.end(),MojiStates.begin(),MojiStates.end()); + std::vector AudioData = Audio.Data; - } + Mel.Shape.push_back(-1); // Tell the plotter that we have no mel to plot - TFTensor Audio = MelPredictor.get()->DoInference(InputIDs,FloatArgs,IntArgs,SpeakerID,EmotionID); + // As VITS is fully E2E, we return here + + return {AudioData,Attention,Mel}; + + }else // DE-VITS + { + 
FloatArgs = {Speed}; + std::vector MojiInput = Processor.GetTokenizer().Tokenize(EmotionOvr,true,true); + TFTensor MojiStates = Moji.Infer(MojiInput); + + auto BERTOutputs = BertFE.Infer(Prompt); + + + + + TFTensor Audio = ((DEVITS*)MelPredictor.get())->DoInferenceDE(InputIDs, MojiStates, + BERTOutputs.first,FloatArgs, + IntArgs,SpeakerID,EmotionID); Attention = ((VITS*)MelPredictor.get())->Attention; std::vector AudioData = Audio.Data; @@ -309,7 +331,6 @@ VoxResults Voice::Vocalize(const std::string & Prompt, float Speed, int32_t Spea // As VITS is fully E2E, we return here return {AudioData,Attention,Mel}; - } // Vocoder inference diff --git a/Voice.h b/Voice.h index 68bfd8f..6f3f1ea 100644 --- a/Voice.h +++ b/Voice.h @@ -10,6 +10,9 @@ #include "phoneticdict.h" #include "tacotron2torch.h" #include "istftnettorch.h" +#include "devits.h" +#include "bert.h" + struct VoxResults{ std::vector Audio; TFTensor Alignment; @@ -24,6 +27,7 @@ class Voice EnglishPhoneticProcessor Processor; VoiceInfo VoxInfo; TorchMoji Moji; + BERT BertFE; diff --git a/VoxCommon.cpp b/VoxCommon.cpp index f81b532..1973ab9 100644 --- a/VoxCommon.cpp +++ b/VoxCommon.cpp @@ -4,7 +4,7 @@ using namespace nlohmann; #include #include // std::wstring_convert -const std::vector Text2MelNames = {"FastSpeech2","Tacotron2 (TF)","VITS","VITS + TorchMoji","Tacotron2 (Torch)"}; +const std::vector Text2MelNames = {"FastSpeech2","Tacotron2 (TF)","VITS","DE-VITS","Tacotron2 (Torch)"}; const std::vector VocoderNames = {"Multi-Band MelGAN","MelGAN-STFT","","iSTFTNet"}; const std::vector RepoNames = {"TensorflowTTS","Coqui-TTS","jaywalnut310","keonlee9420"}; diff --git a/VoxCommon.hpp b/VoxCommon.hpp index d3c27f8..ac42c1d 100644 --- a/VoxCommon.hpp +++ b/VoxCommon.hpp @@ -63,7 +63,7 @@ enum Enum{ FastSpeech2 = 0, Tacotron2, VITS, - VITSTM, + DEVITS, Tacotron2Torch }; @@ -167,7 +167,7 @@ namespace VoxUtil { // Copy PyTorch tensor template - TFTensor CopyTensor(at::Tensor& InTens){ + TFTensor CopyTensor(const 
at::Tensor& InTens){ D* Data = InTens.data(); std::vector Shape = InTens.sizes().vec(); diff --git a/bert.cpp b/bert.cpp new file mode 100644 index 0000000..9a43ae1 --- /dev/null +++ b/bert.cpp @@ -0,0 +1,67 @@ +#include "bert.h" +#include + +BERT::BERT() +{ + +} + /* Convenience constructor: forwards both paths to Initialize. */ +BERT::BERT(const std::string &Path, const std::string &DictPath) +{ + Initialize(Path, DictPath); +} + /* Initialize: loads the TorchScript module found at Path and creates the tokenizer from DictPath (passing true as its second argument). */ +void BERT::Initialize(const std::string &Path, const std::string &DictPath) +{ + Model = torch::jit::load(Path); + Tokenizer = std::make_unique(DictPath,true); + + +} + /* Infer: tokenizes InText with a newline appended, converts the tokens to ids, runs the model on a (1, tokens) tensor and copies the first two elements of the returned tuple into the result pair. If the model call throws std::exception, a message box with the error text is shown and the pair is returned as filled so far. */ +std::pair, TFTensor > BERT::Infer(const std::string &InText) +{ + torch::NoGradGuard no_grad; + + auto Tokens = Tokenizer->tokenize(InText + "\n"); + auto Ids = Tokenizer->convertTokensToIds(Tokens); + + std::vector InTokens(Ids.begin(),Ids.end()); + + + + auto InIDS = torch::tensor(InTokens).unsqueeze(0); // (1, tokens) + std::pair,TFTensor> BERTOutputs; + + try{ + auto Output = Model({InIDS}).toTuple(); // (hidden states, pooled) + BERTOutputs.first = VoxUtil::CopyTensor(Output.get()->elements()[0].toTensor()); + BERTOutputs.second = VoxUtil::CopyTensor(Output.get()->elements()[1].toTensor()); + + + + } + + catch (const std::exception& e) { + int msgboxID = MessageBox( + NULL, + (LPCWSTR)QString::fromStdString(e.what()).toStdWString().c_str(), + (LPCWSTR)L"Error1!!", + MB_ICONWARNING | MB_CANCELTRYCONTINUE | MB_DEFBUTTON2 + ); + /* NOTE(review): msgboxID is not read afterwards, so the button the user picks has no effect. */ + + return BERTOutputs; + + } + + + + + + + return BERTOutputs; + + + +} diff --git a/bert.h b/bert.h new file mode 100644 index 0000000..92cabeb --- /dev/null +++ b/bert.h @@ -0,0 +1,28 @@ +#ifndef BERT_H +#define BERT_H + +#include "VoxCommon.hpp" + +#include "berttokenizer.h" + +// BERT: Class for inference of TorchScript-exported BERT. 
+class BERT +{ +private: + torch::jit::script::Module Model; + std::unique_ptr Tokenizer; + +public: + BERT(); + BERT(const std::string& Path,const std::string& DictPath); + void Initialize(const std::string& Path,const std::string& DictPath); + + + // Do inference on BERT model. + // Returns 2 tensors: + // [1, tokens, channels] : Hidden states + // [1, channels] : Pooled embeddings + std::pair,TFTensor> Infer(const std::string& InText); +}; + +#endif // BERT_H diff --git a/berttokenizer.cpp b/berttokenizer.cpp new file mode 100644 index 0000000..2befdb5 --- /dev/null +++ b/berttokenizer.cpp @@ -0,0 +1,301 @@ +#include "berttokenizer.h" +const std::wstring stripChar = L" \t\n\r\v\f"; +#include "utf8proc.h" +#include "ext/ZCharScannerWide.h" + +const std::vector sDelimChar = {L" ",L"\t",L"\n",L"\r",L"\v",L"\f"}; + +static std::string normalize_nfd(const std::string& s) { + std::string ret; + char *result = (char *) utf8proc_NFD((unsigned char *)s.c_str()); + if (result) { + ret = std::string(result); + free(result); + result = NULL; + } + return ret; +} + +static bool isStripChar(const wchar_t& ch) { + return stripChar.find(ch) != std::wstring::npos; +} + +static std::wstring strip(const std::wstring& text) { + std::wstring ret = text; + if (ret.empty()) return ret; + size_t pos = 0; + while (pos < ret.size() && isStripChar(ret[pos])) pos++; + if (pos != 0) ret = ret.substr(pos, ret.size() - pos); + pos = ret.size() - 1; + while (pos != (size_t)-1 && isStripChar(ret[pos])) pos--; + return ret.substr(0, pos + 1); +} + +static std::vector split(const std::wstring& text) { + + + + std::vector result; + ZStringDelimiterWide Del(text); + + Del.AddDelimiter(L" "); + + result = Del.GetTokens(); + if (!result.size()) + result.push_back(text); // + + + + return result; +} + +static std::vector whitespaceTokenize(const std::wstring& text) { + std::wstring rtext = strip(text); + if (rtext.empty()) return std::vector(); + return split(text); +} + +static std::wstring 
convertToUnicode(const std::string& text) { + size_t i = 0; + std::wstring ret; + while (i < text.size()) { + wchar_t codepoint; + utf8proc_ssize_t forward = utf8proc_iterate((utf8proc_uint8_t *)&text[i], text.size() - i, (utf8proc_int32_t*)&codepoint); + if (forward < 0) return L""; + ret += codepoint; + i += forward; + } + return ret; +} + +static std::string convertFromUnicode(const std::wstring& wText) { + char dst[64]; + std::string ret; + for (auto ch : wText) { + utf8proc_ssize_t num = utf8proc_encode_char(ch, (utf8proc_uint8_t *)dst); + if (num <= 0) return ""; + ret += std::string(dst, dst+num); + } + return ret; +} + +static std::wstring tolower(const std::wstring& s) { + std::wstring ret(s.size(), L' '); + for (size_t i = 0; i < s.size(); i++) { + ret[i] = utf8proc_tolower(s[i]); + } + return ret; +} + +static std::shared_ptr loadVocab(const std::string& vocabFile) { + std::shared_ptr vocab(new Vocab); + size_t index = 0; + std::ifstream ifs(vocabFile, std::ifstream::in); + std::string line; + while (getline(ifs, line)) { + std::wstring token = convertToUnicode(line); + if (token.empty()) break; + token = strip(token); + (*vocab)[token] = index; + index++; + } + + return vocab; +} + +BasicTokenizer::BasicTokenizer(bool doLowerCase) + : mDoLowerCase(doLowerCase) { +} + +std::wstring BasicTokenizer::cleanText(const std::wstring& text) const { + std::wstring output; + for (const wchar_t& cp : text) { + if (cp == 0 || cp == 0xfffd || isControol(cp)) continue; + if (isWhitespace(cp)) output += L" "; + else output += cp; + } + return output; +} + +bool BasicTokenizer::isControol(const wchar_t& ch) const { + if (ch== L'\t' || ch== L'\n' || ch== L'\r') return false; + auto cat = utf8proc_category(ch); + if (cat == UTF8PROC_CATEGORY_CC || cat == UTF8PROC_CATEGORY_CF) return true; + return false; +} + +bool BasicTokenizer::isWhitespace(const wchar_t& ch) const { + if (ch== L' ' || ch== L'\t' || ch== L'\n' || ch== L'\r') return true; + auto cat = 
utf8proc_category(ch); + if (cat == UTF8PROC_CATEGORY_ZS) return true; + return false; +} + +bool BasicTokenizer::isPunctuation(const wchar_t& ch) const { + if ((ch >= 33 && ch <= 47) || (ch >= 58 && ch <= 64) || + (ch >= 91 && ch <= 96) || (ch >= 123 && ch <= 126)) return true; + auto cat = utf8proc_category(ch); + if (cat == UTF8PROC_CATEGORY_PD || cat == UTF8PROC_CATEGORY_PS + || cat == UTF8PROC_CATEGORY_PE || cat == UTF8PROC_CATEGORY_PC + || cat == UTF8PROC_CATEGORY_PO //sometimes ΒΆ belong SO + || cat == UTF8PROC_CATEGORY_PI + || cat == UTF8PROC_CATEGORY_PF) return true; + return false; +} + +bool BasicTokenizer::isChineseChar(const wchar_t& ch) const { + if ((ch >= 0x4E00 && ch <= 0x9FFF) || + (ch >= 0x3400 && ch <= 0x4DBF) || + (ch >= 0x20000 && ch <= 0x2A6DF) || + (ch >= 0x2A700 && ch <= 0x2B73F) || + (ch >= 0x2B740 && ch <= 0x2B81F) || + (ch >= 0x2B820 && ch <= 0x2CEAF) || + (ch >= 0xF900 && ch <= 0xFAFF) || + (ch >= 0x2F800 && ch <= 0x2FA1F)) + return true; + return false; +} + +std::wstring BasicTokenizer::tokenizeChineseChars(const std::wstring& text) const { + std::wstring output; + for (auto& ch : text) { + if (isChineseChar(ch)) { + output += L' '; + output += ch; + output += L' '; + } + else + output += ch; + } + return output; +} + +std::wstring BasicTokenizer::runStripAccents(const std::wstring& text) const { + //Strips accents from a piece of text. 
+ std::wstring nText; + try { + nText = convertToUnicode(normalize_nfd(convertFromUnicode(text))); + } catch (std::bad_cast& e) { + std::cerr << "bad_cast" << std::endl; + return L""; + } + + std::wstring output; + for (auto& ch : nText) { + auto cat = utf8proc_category(ch); + if (cat == UTF8PROC_CATEGORY_MN) continue; + output += ch; + } + return output; +} + +std::vector BasicTokenizer::runSplitOnPunc(const std::wstring& text) const { + size_t i = 0; + bool startNewWord = true; + std::vector output; + while (i < text.size()) { + wchar_t ch = text[i]; + if (isPunctuation(ch)) { + output.push_back(std::wstring(&ch, 1)); + startNewWord = true; + } + else { + if (startNewWord) output.push_back(std::wstring()); + startNewWord = false; + output[output.size() - 1] += ch; + } + i++; + } + return output; +} + +std::vector BasicTokenizer::tokenize(const std::string& text) const { + std::wstring nText = convertToUnicode(text); + nText = cleanText(nText); + + nText = tokenizeChineseChars(nText); + + const std::vector& origTokens = whitespaceTokenize(nText); + std::vector splitTokens; + for (std::wstring token : origTokens) { + if (mDoLowerCase) { + token = tolower(token); + token = runStripAccents(token); + } + const auto& tokens = runSplitOnPunc(token); + splitTokens.insert(splitTokens.end(), tokens.begin(), tokens.end()); + } + ZStringDelimiterWide DelRs; + std::wstring WSP = DelRs.Reassemble(L" ",splitTokens); + + + + + return whitespaceTokenize(WSP); +} + +WordpieceTokenizer::WordpieceTokenizer(const std::shared_ptr vocab, const std::wstring& unkToken, size_t maxInputCharsPerWord) + : mVocab(vocab), + mUnkToken(unkToken), + mMaxInputCharsPerWord(maxInputCharsPerWord) { +} + +std::vector WordpieceTokenizer::tokenize(const std::wstring& text) const { + std::vector outputTokens; + + for (auto& token : whitespaceTokenize(text)) { + if (token.size() > mMaxInputCharsPerWord) { + outputTokens.push_back(mUnkToken); + } + bool isBad = false; + size_t start = 0; + std::vector 
subTokens; + while (start < token.size()) { + size_t end = token.size(); + std::wstring curSubstr; + bool hasCurSubstr = false; + while (start < end) { + std::wstring substr = token.substr(start, end - start); + if (start > 0) substr = L"##" + substr; + if (mVocab->find(substr) != mVocab->end()) { + curSubstr = substr; + hasCurSubstr = true; + break; + } + end--; + } + if (!hasCurSubstr) { + isBad = true; + break; + } + subTokens.push_back(curSubstr); + start = end; + } + if (isBad) outputTokens.push_back(mUnkToken); + else outputTokens.insert(outputTokens.end(), subTokens.begin(), subTokens.end()); + } + return outputTokens; +} + +FullTokenizer::FullTokenizer(const std::string& vocabFile, bool doLowerCase) : + mVocab(loadVocab(vocabFile)), + mBasicTokenizer(BasicTokenizer(doLowerCase)), + mWordpieceTokenizer(WordpieceTokenizer(mVocab)) { + for (auto& v : *mVocab) mInvVocab[v.second] = v.first; +} + +std::vector FullTokenizer::tokenize(const std::string& text) const { + std::vector splitTokens; + for (auto& token : mBasicTokenizer.tokenize(text)) + for (auto& subToken : mWordpieceTokenizer.tokenize(token)) + splitTokens.push_back(subToken); + return splitTokens; +} + +std::vector FullTokenizer::convertTokensToIds(const std::vector& text) const { + std::vector ret(text.size()); + for (size_t i = 0; i < text.size(); i++) { + ret[i] = (*mVocab)[text[i]]; + } + return ret; +} diff --git a/berttokenizer.h b/berttokenizer.h new file mode 100644 index 0000000..58cf6ea --- /dev/null +++ b/berttokenizer.h @@ -0,0 +1,64 @@ +#ifndef BERTTOKENIZER_H +#define BERTTOKENIZER_H +// https://gist.github.com/luistung/ace4888cf5fd1bad07844021cb2c7ecf + +#include +#include +#include +#include +#include + +using Vocab = std::unordered_map; +using InvVocab = std::unordered_map; + +class BasicTokenizer { +public: + BasicTokenizer(bool doLowerCase); + std::vector tokenize(const std::string& text) const; + +private: + std::wstring cleanText(const std::wstring& text) const; + bool 
isControol(const wchar_t& ch) const; + bool isWhitespace(const wchar_t& ch) const; + bool isPunctuation(const wchar_t& ch) const; + bool isChineseChar(const wchar_t& ch) const; + std::wstring tokenizeChineseChars(const std::wstring& text) const; + bool isStripChar(const wchar_t& ch) const; + std::wstring strip(const std::wstring& text) const; + std::vector split(const std::wstring& text) const; + std::wstring runStripAccents(const std::wstring& text) const; + std::vector runSplitOnPunc(const std::wstring& text) const; + /* NOTE(review): isStripChar, strip and split are declared as members here, but berttokenizer.cpp only defines file-static functions with those names. */ + bool mDoLowerCase; +}; + /* WordpieceTokenizer: splits whitespace-separated tokens into pieces found in the shared vocabulary. */ +class WordpieceTokenizer { +public: + WordpieceTokenizer(std::shared_ptr vocab, const std::wstring& unkToken = L"[UNK]", size_t maxInputCharsPerWord=200); + std::vector tokenize(const std::wstring& text) const; + +private: + std::shared_ptr mVocab; + std::wstring mUnkToken; + size_t mMaxInputCharsPerWord; +}; + /* FullTokenizer: BasicTokenizer followed by WordpieceTokenizer, plus token -> id lookup. NOTE(review): the constructor in berttokenizer.cpp does not set mVocabFile or mDoLowerCase. */ +class FullTokenizer { +public: + FullTokenizer(const std::string& vocabFile, bool doLowerCase = true); + std::vector tokenize(const std::string& text) const; + std::vector convertTokensToIds(const std::vector& text) const; + +private: + std::shared_ptr mVocab; + InvVocab mInvVocab; + std::string mVocabFile; + bool mDoLowerCase; + BasicTokenizer mBasicTokenizer; + WordpieceTokenizer mWordpieceTokenizer; +}; + + + + +#endif // BERTTOKENIZER_H diff --git a/devits.cpp b/devits.cpp new file mode 100644 index 0000000..ea9e20f --- /dev/null +++ b/devits.cpp @@ -0,0 +1,69 @@ +#include "devits.h" +DEVITS::DEVITS() +{ + +} + /* DoInferenceDE: zero-pads InputIDs, builds the input tensors, calls the model's infer_ts method, stores the transposed attention in Attention and returns the audio tensor. ArgsInt and EmotionID are not read in this body. */ +TFTensor DEVITS::DoInferenceDE(const std::vector &InputIDs, const TFTensor &MojiIn, const TFTensor &BERTIn, const std::vector &ArgsFloat, const std::vector ArgsInt, int32_t SpeakerID, int32_t EmotionID) +{ + // without this memory consumption is 4x + torch::NoGradGuard no_grad; + + + + std::vector PaddedIDs; + + + PaddedIDs = ZeroPadVec(InputIDs); + + + std::vector inLen = { (int64_t)PaddedIDs.size() }; + + + // ZDisket: Is this really necessary? 
+ torch::TensorOptions Opts = torch::TensorOptions().requires_grad(false); + + auto InIDS = torch::tensor(PaddedIDs, Opts).unsqueeze(0); + auto InLens = torch::tensor(inLen, Opts); + auto MojiHidden = torch::tensor(MojiIn.Data).unsqueeze(0); + auto BERTHidden = torch::tensor(BERTIn.Data).reshape(BERTIn.Shape); + /* The BERT length is read from BERTIn.Shape[1], so BERTIn.Shape needs at least two entries. */ + std::vector BERTSz = {BERTIn.Shape[1]}; + auto BERTLens = torch::tensor(BERTSz); + + auto InLenScale = torch::tensor({ ArgsFloat[0]}, Opts); + /* Positional inputs for infer_ts: ids, id length, TorchMoji states, BERT states, BERT length, length scale, then the speaker id unless SpeakerID is -1. */ + + + std::vector inputs{ InIDS,InLens, MojiHidden, BERTHidden, BERTLens, InLenScale }; + + if (SpeakerID != -1){ + auto InSpkid = torch::tensor({SpeakerID},Opts); + inputs.push_back(InSpkid); + } + + + + + + // Infer + + c10::IValue Output = Model.get_method("infer_ts")(inputs); + + // Output = tuple (audio,att) + + auto OutputT = Output.toTuple(); + + // Grab audio + // [1, frames] -> [frames] + auto AuTens = OutputT.get()->elements()[0].toTensor().squeeze(); + + // Grab Attention + // [1, 1, x, y] -> [x, y] -> [y,x] -> [1, y, x] + auto AttTens = OutputT.get()->elements()[1].toTensor().squeeze().transpose(0,1).unsqueeze(0); + + Attention = VoxUtil::CopyTensor(AttTens); + + return VoxUtil::CopyTensor(AuTens); + +} diff --git a/devits.h b/devits.h new file mode 100644 index 0000000..0b7c51a --- /dev/null +++ b/devits.h @@ -0,0 +1,24 @@ +#ifndef DEVITS_H +#define DEVITS_H +#include "vits.h" + +class DEVITS : public VITS +{ +public: + DEVITS(); + + /* + Do inference on a DE-VITS model. + + -> InputIDs: Input IDs of tokens for inference + -> SpeakerID: ID of the speaker in the model to do inference on. If single speaker, always leave at 0. If multispeaker, refer to your model. + -> MojiIn: TorchMoji hidden states size [tm] + -> BERTIn: BERT hidden states size [1, n_tokens, channels] + -> ArgsFloat[0]: Length scale. 
+ + <- Returns: TFTensor with shape {frames} of audio data + */ + TFTensor DoInferenceDE(const std::vector& InputIDs, const TFTensor& MojiIn, const TFTensor& BERTIn,const std::vector& ArgsFloat,const std::vector ArgsInt, int32_t SpeakerID = 0, int32_t EmotionID = -1); +}; + +#endif // DEVITS_H diff --git a/ext/ZCharScanner.cpp b/ext/ZCharScanner.cpp index a281db8..06fcdaf 100644 --- a/ext/ZCharScanner.cpp +++ b/ext/ZCharScanner.cpp @@ -198,6 +198,13 @@ void ZStringDelimiter::AddDelimiter(const GString & in_Delim) } +void ZStringDelimiter::SetDelimiters(const std::vector &Delims) +{ + m_vDelimiters.assign(Delims.begin(),Delims.end()); + UpdateTokens(); + +} + ZStringDelimiter::~ZStringDelimiter() { } diff --git a/ext/ZCharScanner.h b/ext/ZCharScanner.h index a7fef3c..9c01ddc 100644 --- a/ext/ZCharScanner.h +++ b/ext/ZCharScanner.h @@ -5,11 +5,10 @@ #include #include -#define ZSDEL_USE_STD_STRING -#ifndef ZSDEL_USE_STD_STRING -#include "golem_string.h" -#else +#ifndef ZSDEL_USE_WSTRING #define GString std::string +#else +#define GString std::wstring #endif typedef std::vector::const_iterator TokenIterator; @@ -73,6 +72,7 @@ class ZStringDelimiter UpdateTokens(); } void AddDelimiter(const GString& in_Delim); + void SetDelimiters(const std::vector& Delims); ~ZStringDelimiter(); }; diff --git a/ext/ZCharScannerWide.cpp b/ext/ZCharScannerWide.cpp new file mode 100644 index 0000000..bcd22fb --- /dev/null +++ b/ext/ZCharScannerWide.cpp @@ -0,0 +1,210 @@ +#include "ZCharScannerWide.h" +using namespace std; +#include + +int ZStringDelimiterWide::key_search(const std::wstring& s, const std::wstring& key) +{ + int count = 0; + size_t pos = 0; + while ((pos = s.find(key, pos)) != std::wstring::npos) { + ++count; + ++pos; + } + return count; +} +void ZStringDelimiterWide::UpdateTokens() +{ + if (!m_vDelimiters.size() || m_sString == L"") + return; + + m_vTokens.clear(); + + + vector::iterator dIt = m_vDelimiters.begin(); + while (dIt != m_vDelimiters.end()) + { + std::wstring 
delimiter = *dIt; + + + DelimStr(m_sString, delimiter, true); + + + ++dIt; + } + + + +} + + +void ZStringDelimiterWide::DelimStr(const std::wstring & s, const std::wstring & delimiter, const bool & removeEmptyEntries) +{ + BarRange(0, s.length()); + for (size_t start = 0, end; start < s.length(); start = end + delimiter.length()) + { + size_t position = s.find(delimiter, start); + end = position != std::wstring::npos ? position : s.length(); + + std::wstring token = s.substr(start, end - start); + if (!removeEmptyEntries || !token.empty()) + { + if (token != s) + m_vTokens.push_back(token); + + } + Bar(position); + } + + // dadwwdawdaawdwadwd +} + +void ZStringDelimiterWide::BarRange(const int & min, const int & max) +{ +#ifdef _AFX_ALL_WARNINGS + if (PgBar) + m_pBar->SetRange32(min, max); + + +#endif +} + +void ZStringDelimiterWide::Bar(const int & pos) +{ +#ifdef _AFX_ALL_WARNINGS + if (PgBar) + m_pBar->SetPos(pos); + + +#endif +} + +ZStringDelimiterWide::ZStringDelimiterWide() +{ + m_sString = L""; + tokenIndex = 0; + PgBar = false; +} + + +bool ZStringDelimiterWide::GetFirstToken(std::wstring & in_out) +{ + if (m_vTokens.size() >= 1) { + in_out = m_vTokens[0]; + return true; + } + else { + return false; + } +} + +bool ZStringDelimiterWide::GetNextToken(std::wstring & in_sOut) +{ + if (tokenIndex > m_vTokens.size() - 1) + return false; + + in_sOut = m_vTokens[tokenIndex]; + ++tokenIndex; + + return true; +} + +std::wstring ZStringDelimiterWide::operator[](const size_t & in_index) +{ + if (in_index > m_vTokens.size()) + throw std::out_of_range("ZStringDelimiter tried to access token higher than size"); + + return m_vTokens[in_index]; + +} +std::wstring ZStringDelimiterWide::Reassemble(const std::wstring& delim, const int& nelem) +{ + std::wstring Result = L""; + TokenIterator RasIt = m_vTokens.begin(); + int r = 0; + if (nelem == -1) { + while (RasIt != m_vTokens.end()) + { + + if (r != 0) + Result.append(delim); + + Result.append(*RasIt); + + ++r; + + + ++RasIt; 
+ } + } + else { + while (RasIt != m_vTokens.end() && r < nelem) + { + + if (r != 0) + Result.append(delim); + + Result.append(*RasIt); + + ++r; + ++RasIt; + } + } + + return Result; + +} + +std::wstring ZStringDelimiterWide::Reassemble(const std::wstring & delim, const std::vector& Strs,int nelem) +{ + std::wstring Result = L""; + TokenIterator RasIt = Strs.begin(); + int r = 0; + if (nelem == -1) { + while (RasIt != Strs.end()) + { + + if (r != 0) + Result.append(delim); + + Result.append(*RasIt); + + ++r; + + + ++RasIt; + } + } + else { + while (RasIt != Strs.end() && r < nelem) + { + + if (r != 0) + Result.append(delim); + + Result.append(*RasIt); + + ++r; + ++RasIt; + } + } + + return Result; +} + +void ZStringDelimiterWide::AddDelimiter(const std::wstring & in_Delim) +{ + m_vDelimiters.push_back(in_Delim); + UpdateTokens(); + +} + +void ZStringDelimiterWide::SetDelimiters(const std::vector &Delims) +{ + m_vDelimiters.assign(Delims.begin(),Delims.end()); + UpdateTokens(); + +} + +ZStringDelimiterWide::~ZStringDelimiterWide() +{ +} diff --git a/ext/ZCharScannerWide.h b/ext/ZCharScannerWide.h new file mode 100644 index 0000000..47dcad0 --- /dev/null +++ b/ext/ZCharScannerWide.h @@ -0,0 +1,74 @@ +#pragma once + +#define GBasicCharScanner ZStringDelimiter + +#include +#include + +// We need ZCharScanner but for wstrings. I copy class, fastest way +typedef std::vector::const_iterator TokenIterator; + +// ZStringDelimiter +// ============== +// Simple class to delimit and split strings. 
+// You can use operator[] to access them +// Or you can use the itBegin() and itEnd() to get some iterators +// ================= +class ZStringDelimiterWide +{ +private: + int key_search(const std::wstring & s, const std::wstring & key); + void UpdateTokens(); + std::vector m_vTokens; + std::vector m_vDelimiters; + + std::wstring m_sString; + + void DelimStr(const std::wstring& s, const std::wstring& delimiter, const bool& removeEmptyEntries = false); + void BarRange(const int& min, const int& max); + void Bar(const int& pos); + size_t tokenIndex; +public: + ZStringDelimiterWide(); + bool PgBar; + +#ifdef _AFX_ALL_WARNINGS + CProgressCtrl* m_pBar; +#endif + + ZStringDelimiterWide(const std::wstring& in_iStr) { + m_sString = in_iStr; + PgBar = false; + + } + + bool GetFirstToken(std::wstring& in_out); + bool GetNextToken(std::wstring& in_sOut); + + // std::String alts + + size_t szTokens() { return m_vTokens.size(); } + std::wstring operator[](const size_t& in_index); + + std::wstring Reassemble(const std::wstring & delim, const int & nelem = -1); + + // Override to reassemble provided tokens. 
+ std::wstring Reassemble(const std::wstring & delim, const std::vector& Strs,int nelem = -1); + + // Get a const reference to the tokens + const std::vector& GetTokens() { return m_vTokens; } + + TokenIterator itBegin() { return m_vTokens.begin(); } + TokenIterator itEnd() { return m_vTokens.end(); } + + void SetText(const std::wstring& in_Txt) { + m_sString = in_Txt; + if (m_vDelimiters.size()) + UpdateTokens(); + } + void AddDelimiter(const std::wstring& in_Delim); + void SetDelimiters(const std::vector& Delims); + + ~ZStringDelimiterWide(); +}; + diff --git a/mainwindow.cpp b/mainwindow.cpp index baaf715..7da3084 100644 --- a/mainwindow.cpp +++ b/mainwindow.cpp @@ -1206,7 +1206,7 @@ void MainWindow::HandleIsMultiSpeaker(size_t inVid) ArchitectureInfo Inf = CurrentVoice.GetInfo().Architecture; - if (Inf.Text2Mel == EText2MelModel::FastSpeech2 || Inf.Text2Mel == EText2MelModel::VITS || Inf.Text2Mel == EText2MelModel::VITSTM) + if (Inf.Text2Mel == EText2MelModel::FastSpeech2 || Inf.Text2Mel == EText2MelModel::VITS || Inf.Text2Mel == EText2MelModel::DEVITS) { ui->grpFs2Params->show(); @@ -1269,7 +1269,7 @@ void MainWindow::HandleIsMultiEmotion(size_t inVid) Voice& CurrentVoice = *VoMan[inVid]; - const bool TorchMojiEnabled = CurrentVoice.GetInfo().Architecture.Text2Mel == EText2MelModel::VITSTM; + const bool TorchMojiEnabled = CurrentVoice.GetInfo().Architecture.Text2Mel == EText2MelModel::DEVITS; ui->lblEmotionOvr->setVisible(TorchMojiEnabled); ui->edtEmotionOvr->setVisible(TorchMojiEnabled); diff --git a/mainwindow.ui b/mainwindow.ui index 7d8d9e9..d69c0ae 100644 --- a/mainwindow.ui +++ b/mainwindow.ui @@ -7,7 +7,7 @@ 0 0 1047 - 526 + 538 @@ -324,7 +324,7 @@ 20 - 1000 + 5000 10 diff --git a/torchmoji.cpp b/torchmoji.cpp index c1f5b82..fc90531 100644 --- a/torchmoji.cpp +++ b/torchmoji.cpp @@ -65,8 +65,9 @@ void TorchMoji::Initialize(const std::string &Path, const std::string &DictPath) LoadDict(DictPath); } -std::vector TorchMoji::Infer(const std::vector &Seq) 
+TFTensor TorchMoji::Infer(const std::vector &Seq) { + torch::NoGradGuard no_grad; std::vector Input = WordsToIDs(Seq); auto InIDS = torch::tensor(Input).unsqueeze(0); // (1, TMLen) @@ -78,7 +79,7 @@ std::vector TorchMoji::Infer(const std::vector &Seq) TFTensor Tens = VoxUtil::CopyTensor(Output); - return Tens.Data; + return Tens; diff --git a/torchmoji.h b/torchmoji.h index 9bd0833..d10ddcc 100644 --- a/torchmoji.h +++ b/torchmoji.h @@ -25,8 +25,8 @@ class TorchMoji // Return hidden states of emotion state. // -> Seq: Vector of words - // <- Returns float vec of size VoxCommon::TorchMojiEmbSize containing hidden states, ready to feed into TTS model. - std::vector Infer(const std::vector& Seq); + // <- Returns float tensor of size VoxCommon::TorchMojiEmbSize containing hidden states, ready to feed into TTS model. + TFTensor Infer(const std::vector& Seq); }; #endif // TORCHMOJI_H diff --git a/vits.h b/vits.h index d2fc766..a7f1cee 100644 --- a/vits.h +++ b/vits.h @@ -12,12 +12,15 @@ class VITS : public MelGen { private: - torch::jit::script::Module Model; + + +public: + torch::jit::script::Module Model; // Most VITS model require zero-interspersed input IDs std::vector ZeroPadVec(const std::vector& InIDs); -public: + TFTensor Attention; VITS();