From 604adcacd895d10857ddf02a7919c6e4691fee31 Mon Sep 17 00:00:00 2001 From: ZDisket <30500847+ZDisket@users.noreply.github.com> Date: Tue, 28 Feb 2023 16:52:17 -0300 Subject: [PATCH 1/6] Change names --- Voice.cpp | 6 +++--- VoxCommon.cpp | 2 +- VoxCommon.hpp | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Voice.cpp b/Voice.cpp index 425462b..0d0968e 100644 --- a/Voice.cpp +++ b/Voice.cpp @@ -120,13 +120,13 @@ Voice::Voice(const std::string & VoxPath, const std::string &inName, Phonemizer const int32_t Tex2MelArch = VoxInfo.Architecture.Text2Mel; - const bool IsVITS = Tex2MelArch == EText2MelModel::VITS || Tex2MelArch == EText2MelModel::VITSTM; + const bool IsVITS = Tex2MelArch == EText2MelModel::VITS || Tex2MelArch == EText2MelModel::DEVITS; if (Tex2MelArch == EText2MelModel::Tacotron2) MelPredictor = std::make_unique(); else if (Tex2MelArch == EText2MelModel::FastSpeech2) MelPredictor = std::make_unique(); - else if (Tex2MelArch == EText2MelModel::VITS || Tex2MelArch == EText2MelModel::VITSTM) + else if (Tex2MelArch == EText2MelModel::VITS || Tex2MelArch == EText2MelModel::DEVITS) MelPredictor = std::make_unique(); else MelPredictor = std::make_unique(); @@ -144,7 +144,7 @@ Voice::Voice(const std::string & VoxPath, const std::string &inName, Phonemizer - if (Tex2MelArch == EText2MelModel::VITSTM) + if (Tex2MelArch == EText2MelModel::DEVITS) Moji.Initialize(VoxPath + "/moji.pt", VoxPath + "/tm_dict.txt"); diff --git a/VoxCommon.cpp b/VoxCommon.cpp index f81b532..1973ab9 100644 --- a/VoxCommon.cpp +++ b/VoxCommon.cpp @@ -4,7 +4,7 @@ using namespace nlohmann; #include #include // std::wstring_convert -const std::vector Text2MelNames = {"FastSpeech2","Tacotron2 (TF)","VITS","VITS + TorchMoji","Tacotron2 (Torch)"}; +const std::vector Text2MelNames = {"FastSpeech2","Tacotron2 (TF)","VITS","DE-VITS","Tacotron2 (Torch)"}; const std::vector VocoderNames = {"Multi-Band MelGAN","MelGAN-STFT","","iSTFTNet"}; const std::vector RepoNames = {"TensorflowTTS","Coqui-TTS","jaywalnut310","keonlee9420"}; diff --git a/VoxCommon.hpp b/VoxCommon.hpp index d3c27f8..e418175 100644 --- a/VoxCommon.hpp +++ b/VoxCommon.hpp @@ -63,7 +63,7 @@ enum Enum{ FastSpeech2 = 0, Tacotron2, VITS, - VITSTM, + DEVITS, Tacotron2Torch }; From 8ac877cfb21393620212b836aa6c026a9e27c9e7 Mon Sep 17 00:00:00 2001 From: ZDisket <30500847+ZDisket@users.noreply.github.com> Date: Tue, 28 Feb 2023 17:50:49 -0300 Subject: [PATCH 2/6] Begin tokenizer work - copy tokenizer - create clone class of string delimiter but wide (don't judge me for copying code) - remove boost from orig code - add utf8 lib --- TensorVox.pro | 8 +- bert.cpp | 18 +++ bert.h | 17 +++ berttokenizer.cpp | 286 +++++++++++++++++++++++++++++++++++++++ berttokenizer.h | 64 +++++++++ ext/ZCharScanner.cpp | 7 + ext/ZCharScanner.h | 8 +- ext/ZCharScannerWide.cpp | 210 ++++++++++++++++++++++++++++ ext/ZCharScannerWide.h | 74 ++++++++++ mainwindow.cpp | 4 +- 10 files changed, 689 insertions(+), 7 deletions(-) create mode 100644 bert.cpp create mode 100644 bert.h create mode 100644 berttokenizer.cpp create mode 100644 berttokenizer.h create mode 100644 ext/ZCharScannerWide.cpp create mode 100644 ext/ZCharScannerWide.h diff --git a/TensorVox.pro b/TensorVox.pro index dd6ec15..7732e1b 100644 --- a/TensorVox.pro +++ b/TensorVox.pro @@ -25,12 +25,15 @@ SOURCES += \ VoxCommon.cpp \ attention.cpp \ batchdenoisedlg.cpp \ + bert.cpp \ + berttokenizer.cpp \ espeakphonemizer.cpp \ ext/ByteArr.cpp \ ext/Qt-Frameless-Window-DarkStyle-master/DarkStyle.cpp \ ext/Qt-Frameless-Window-DarkStyle-master/framelesswindow/framelesswindow.cpp \ ext/Qt-Frameless-Window-DarkStyle-master/framelesswindow/windowdragger.cpp \ ext/ZCharScanner.cpp \ + ext/ZCharScannerWide.cpp \ ext/ZFile.cpp \ ext/qcustomplot.cpp \ istftnettorch.cpp \ @@ -61,6 +64,8 @@ HEADERS += \ VoxCommon.hpp \ attention.h \ batchdenoisedlg.h \ + bert.h \ + berttokenizer.h \ espeakphonemizer.h \ ext/AudioFile.hpp \ ext/ByteArr.h \ @@ -76,6 +81,7 @@ HEADERS += \ ext/Qt-Frameless-Window-DarkStyle-master/framelesswindow/framelesswindow.h \ ext/Qt-Frameless-Window-DarkStyle-master/framelesswindow/windowdragger.h \ ext/ZCharScanner.h \ + ext/ZCharScannerWide.h \ ext/ZFile.h \ ext/json.hpp \ ext/qcustomplot.h \ @@ -115,7 +121,7 @@ DEFINES += _CRT_SECURE_NO_WARNINGS INCLUDEPATH += $$PWD/deps/include INCLUDEPATH += $$PWD/deps/include/libtorch INCLUDEPATH += $$PWD/ext/Qt-Frameless-Window-DarkStyle-master/framelesswindow -win32: LIBS += -L$$PWD/deps/lib/ tensorflow.lib r8bsrc64.lib rnnoise64.lib LogitechLEDLib.lib LibNumberText64.lib c10.lib torch.lib torch_cpu.lib libespeak-ng.lib +win32: LIBS += -L$$PWD/deps/lib/ tensorflow.lib r8bsrc64.lib rnnoise64.lib LogitechLEDLib.lib LibNumberText64.lib c10.lib torch.lib torch_cpu.lib libespeak-ng.lib Utf8Proc.lib win32: LIBS += Advapi32.lib User32.lib Psapi.lib diff --git a/bert.cpp b/bert.cpp new file mode 100644 index 0000000..97b0258 --- /dev/null +++ b/bert.cpp @@ -0,0 +1,18 @@ +#include "bert.h" + +BERT::BERT() +{ + +} + +BERT::BERT(const std::string &Path, const std::string &DictPath) +{ + Initialize(Path, DictPath); +} + +void BERT::Initialize(const std::string &Path, const std::string &DictPath) +{ + Model = torch::jit::load(Path); + + +} diff --git a/bert.h b/bert.h new file mode 100644 index 0000000..98ab2fe --- /dev/null +++ b/bert.h @@ -0,0 +1,17 @@ +#ifndef BERT_H +#define BERT_H + +#include "VoxCommon.hpp" +// BERT: Class for inference of TorchScript-exported BERT. +class BERT +{ +private: + torch::jit::script::Module Model; + +public: + BERT(); + BERT(const std::string& Path,const std::string& DictPath); + void Initialize(const std::string& Path,const std::string& DictPath); +}; + +#endif // BERT_H diff --git a/berttokenizer.cpp b/berttokenizer.cpp new file mode 100644 index 0000000..94a3fc7 --- /dev/null +++ b/berttokenizer.cpp @@ -0,0 +1,286 @@ +#include "berttokenizer.h" +const std::wstring stripChar = L" \t\n\r\v\f"; +#include "utf8proc.h" +#include "ext/ZCharScannerWide.h" + +const std::vector sDelimChar = {L" ",L"\t",L"\n",L"\r",L"\v",L"\f"}; + +static std::string normalize_nfd(const std::string& s) { + std::string ret; + char *result = (char *) utf8proc_NFD((unsigned char *)s.c_str()); + if (result) { + ret = std::string(result); + free(result); + result = NULL; + } + return ret; +} + +static bool isStripChar(const wchar_t& ch) { + return stripChar.find(ch) != std::wstring::npos; +} + +static std::wstring strip(const std::wstring& text) { + std::wstring ret = text; + if (ret.empty()) return ret; + size_t pos = 0; + while (pos < ret.size() && isStripChar(ret[pos])) pos++; + if (pos != 0) ret = ret.substr(pos, ret.size() - pos); + pos = ret.size() - 1; + while (pos != (size_t)-1 && isStripChar(ret[pos])) pos--; + return ret.substr(0, pos + 1); +} + +static std::vector split(const std::wstring& text) { + std::vector result; + ZStringDelimiter Del(text); + Del.SetDelimiters(sDelimChar); + + result = Del.GetTokens(); + return result; +} + +static std::vector whitespaceTokenize(const std::wstring& text) { + std::wstring rtext = strip(text); + if (rtext.empty()) return std::vector(); + return split(text); +} + +static std::wstring convertToUnicode(const std::string& text) { + size_t i = 0; + std::wstring ret; + while (i < text.size()) { + wchar_t codepoint; + utf8proc_ssize_t forward = utf8proc_iterate((utf8proc_uint8_t *)&text[i], text.size() - i, (utf8proc_int32_t*)&codepoint); + if (forward < 0) return L""; + ret += codepoint; + i += forward; + } + return ret; +} + +static std::string convertFromUnicode(const std::wstring& wText) { + char dst[64]; + std::string ret; + for (auto ch : wText) { + utf8proc_ssize_t num = utf8proc_encode_char(ch, (utf8proc_uint8_t *)dst); + if (num <= 0) return ""; + ret += std::string(dst, dst+num); + } + return ret; +} + +static std::wstring tolower(const std::wstring& s) { + std::wstring ret(s.size(), L' '); + for (size_t i = 0; i < s.size(); i++) { + ret[i] = utf8proc_tolower(s[i]); + } + return ret; +} + +static std::shared_ptr loadVocab(const std::string& vocabFile) { + std::shared_ptr vocab(new Vocab); + size_t index = 0; + std::ifstream ifs(vocabFile, std::ifstream::in); + std::string line; + while (getline(ifs, line)) { + std::wstring token = convertToUnicode(line); + if (token.empty()) break; + token = strip(token); + (*vocab)[token] = index; + index++; + } + return vocab; +} + +BasicTokenizer::BasicTokenizer(bool doLowerCase) + : mDoLowerCase(doLowerCase) { +} + +std::wstring BasicTokenizer::cleanText(const std::wstring& text) const { + std::wstring output; + for (const wchar_t& cp : text) { + if (cp == 0 || cp == 0xfffd || isControol(cp)) continue; + if (isWhitespace(cp)) output += L" "; + else output += cp; + } + return output; +} + +bool BasicTokenizer::isControol(const wchar_t& ch) const { + if (ch== L'\t' || ch== L'\n' || ch== L'\r') return false; + auto cat = utf8proc_category(ch); + if (cat == UTF8PROC_CATEGORY_CC || cat == UTF8PROC_CATEGORY_CF) return true; + return false; +} + +bool BasicTokenizer::isWhitespace(const wchar_t& ch) const { + if (ch== L' ' || ch== L'\t' || ch== L'\n' || ch== L'\r') return true; + auto cat = utf8proc_category(ch); + if (cat == UTF8PROC_CATEGORY_ZS) return true; + return false; +} + +bool BasicTokenizer::isPunctuation(const wchar_t& ch) const { + if ((ch >= 33 && ch <= 47) || (ch >= 58 && ch <= 64) || + (ch >= 91 && ch <= 96) || (ch >= 123 && ch <= 126)) return true; + auto cat = utf8proc_category(ch); + if (cat == UTF8PROC_CATEGORY_PD || cat == UTF8PROC_CATEGORY_PS + || cat == UTF8PROC_CATEGORY_PE || cat == UTF8PROC_CATEGORY_PC + || cat == UTF8PROC_CATEGORY_PO //sometimes ΒΆ belong SO + || cat == UTF8PROC_CATEGORY_PI + || cat == UTF8PROC_CATEGORY_PF) return true; + return false; +} + +bool BasicTokenizer::isChineseChar(const wchar_t& ch) const { + if ((ch >= 0x4E00 && ch <= 0x9FFF) || + (ch >= 0x3400 && ch <= 0x4DBF) || + (ch >= 0x20000 && ch <= 0x2A6DF) || + (ch >= 0x2A700 && ch <= 0x2B73F) || + (ch >= 0x2B740 && ch <= 0x2B81F) || + (ch >= 0x2B820 && ch <= 0x2CEAF) || + (ch >= 0xF900 && ch <= 0xFAFF) || + (ch >= 0x2F800 && ch <= 0x2FA1F)) + return true; + return false; +} + +std::wstring BasicTokenizer::tokenizeChineseChars(const std::wstring& text) const { + std::wstring output; + for (auto& ch : text) { + if (isChineseChar(ch)) { + output += L' '; + output += ch; + output += L' '; + } + else + output += ch; + } + return output; +} + +std::wstring BasicTokenizer::runStripAccents(const std::wstring& text) const { + //Strips accents from a piece of text. + std::wstring nText; + try { + nText = convertToUnicode(normalize_nfd(convertFromUnicode(text))); + } catch (std::bad_cast& e) { + std::cerr << "bad_cast" << std::endl; + return L""; + } + + std::wstring output; + for (auto& ch : nText) { + auto cat = utf8proc_category(ch); + if (cat == UTF8PROC_CATEGORY_MN) continue; + output += ch; + } + return output; +} + +std::vector BasicTokenizer::runSplitOnPunc(const std::wstring& text) const { + size_t i = 0; + bool startNewWord = true; + std::vector output; + while (i < text.size()) { + wchar_t ch = text[i]; + if (isPunctuation(ch)) { + output.push_back(std::wstring(&ch, 1)); + startNewWord = true; + } + else { + if (startNewWord) output.push_back(std::wstring()); + startNewWord = false; + output[output.size() - 1] += ch; + } + i++; + } + return output; +} + +std::vector BasicTokenizer::tokenize(const std::string& text) const { + std::wstring nText = convertToUnicode(text); + nText = cleanText(nText); + + nText = tokenizeChineseChars(nText); + + const std::vector& origTokens = whitespaceTokenize(nText); + std::vector splitTokens; + for (std::wstring token : origTokens) { + if (mDoLowerCase) { + token = tolower(token); + token = runStripAccents(token); + } + const auto& tokens = runSplitOnPunc(token); + splitTokens.insert(splitTokens.end(), tokens.begin(), tokens.end()); + } + ZStringDelimiter DelRs; + + return whitespaceTokenize(DelRs.Reassemble(L" ",splitTokens)); +} + +WordpieceTokenizer::WordpieceTokenizer(const std::shared_ptr vocab, const std::wstring& unkToken, size_t maxInputCharsPerWord) + : mVocab(vocab), + mUnkToken(unkToken), + mMaxInputCharsPerWord(maxInputCharsPerWord) { +} + +std::vector WordpieceTokenizer::tokenize(const std::wstring& text) const { + std::vector outputTokens; + for (auto& token : whitespaceTokenize(text)) { + if (token.size() > mMaxInputCharsPerWord) { + outputTokens.push_back(mUnkToken); + } + bool isBad = false; + size_t start = 0; + std::vector subTokens; + while (start < token.size()) { + size_t end = token.size(); + std::wstring curSubstr; + bool hasCurSubstr = false; + while (start < end) { + std::wstring substr = token.substr(start, end - start); + if (start > 0) substr = L"##" + substr; + if (mVocab->find(substr) != mVocab->end()) { + curSubstr = substr; + hasCurSubstr = true; + break; + } + end--; + } + if (!hasCurSubstr) { + isBad = true; + break; + } + subTokens.push_back(curSubstr); + start = end; + } + if (isBad) outputTokens.push_back(mUnkToken); + else outputTokens.insert(outputTokens.end(), subTokens.begin(), subTokens.end()); + } + return outputTokens; +} + +FullTokenizer::FullTokenizer(const std::string& vocabFile, bool doLowerCase) : + mVocab(loadVocab(vocabFile)), + mBasicTokenizer(BasicTokenizer(doLowerCase)), + mWordpieceTokenizer(WordpieceTokenizer(mVocab)) { + for (auto& v : *mVocab) mInvVocab[v.second] = v.first; +} + +std::vector FullTokenizer::tokenize(const std::string& text) const { + std::vector splitTokens; + for (auto& token : mBasicTokenizer.tokenize(text)) + for (auto& subToken : mWordpieceTokenizer.tokenize(token)) + splitTokens.push_back(subToken); + return splitTokens; +} + +std::vector FullTokenizer::convertTokensToIds(const std::vector& text) const { + std::vector ret(text.size()); + for (size_t i = 0; i < text.size(); i++) { + ret[i] = (*mVocab)[text[i]]; + } + return ret; +} diff --git a/berttokenizer.h b/berttokenizer.h new file mode 100644 index 0000000..58cf6ea --- /dev/null +++ b/berttokenizer.h @@ -0,0 +1,64 @@ +#ifndef BERTTOKENIZER_H +#define BERTTOKENIZER_H +// https://gist.github.com/luistung/ace4888cf5fd1bad07844021cb2c7ecf + +#include +#include +#include +#include +#include + +using Vocab = std::unordered_map; +using InvVocab = std::unordered_map; + +class BasicTokenizer { +public: + BasicTokenizer(bool doLowerCase); + std::vector tokenize(const std::string& text) const; + +private: + std::wstring cleanText(const std::wstring& text) const; + bool isControol(const wchar_t& ch) const; + bool isWhitespace(const wchar_t& ch) const; + bool isPunctuation(const wchar_t& ch) const; + bool isChineseChar(const wchar_t& ch) const; + std::wstring tokenizeChineseChars(const std::wstring& text) const; + bool isStripChar(const wchar_t& ch) const; + std::wstring strip(const std::wstring& text) const; + std::vector split(const std::wstring& text) const; + std::wstring runStripAccents(const std::wstring& text) const; + std::vector runSplitOnPunc(const std::wstring& text) const; + + bool mDoLowerCase; +}; + +class WordpieceTokenizer { +public: + WordpieceTokenizer(std::shared_ptr vocab, const std::wstring& unkToken = L"[UNK]", size_t maxInputCharsPerWord=200); + std::vector tokenize(const std::wstring& text) const; + +private: + std::shared_ptr mVocab; + std::wstring mUnkToken; + size_t mMaxInputCharsPerWord; +}; + +class FullTokenizer { +public: + FullTokenizer(const std::string& vocabFile, bool doLowerCase = true); + std::vector tokenize(const std::string& text) const; + std::vector convertTokensToIds(const std::vector& text) const; + +private: + std::shared_ptr mVocab; + InvVocab mInvVocab; + std::string mVocabFile; + bool mDoLowerCase; + BasicTokenizer mBasicTokenizer; + WordpieceTokenizer mWordpieceTokenizer; +}; + + + + +#endif // BERTTOKENIZER_H diff --git a/ext/ZCharScanner.cpp b/ext/ZCharScanner.cpp index a281db8..06fcdaf 100644 --- a/ext/ZCharScanner.cpp +++ b/ext/ZCharScanner.cpp @@ -198,6 +198,13 @@ void ZStringDelimiter::AddDelimiter(const GString & in_Delim) } +void ZStringDelimiter::SetDelimiters(const std::vector &Delims) +{ + m_vDelimiters.assign(Delims.begin(),Delims.end()); + UpdateTokens(); + +} + ZStringDelimiter::~ZStringDelimiter() { } diff --git a/ext/ZCharScanner.h b/ext/ZCharScanner.h index a7fef3c..9c01ddc 100644 --- a/ext/ZCharScanner.h +++ b/ext/ZCharScanner.h @@ -5,11 +5,10 @@ #include #include -#define ZSDEL_USE_STD_STRING -#ifndef ZSDEL_USE_STD_STRING -#include "golem_string.h" -#else +#ifndef ZSDEL_USE_WSTRING #define GString std::string +#else +#define GString std::wstring #endif typedef std::vector::const_iterator TokenIterator; @@ -73,6 +72,7 @@ class ZStringDelimiter UpdateTokens(); } void AddDelimiter(const GString& in_Delim); + void SetDelimiters(const std::vector& Delims); ~ZStringDelimiter(); }; diff --git a/ext/ZCharScannerWide.cpp b/ext/ZCharScannerWide.cpp new file mode 100644 index 0000000..967918d --- /dev/null +++ b/ext/ZCharScannerWide.cpp @@ -0,0 +1,210 @@ +#include "ZCharScannerWide.h" +using namespace std; +#include + +int ZStringDelimiter::key_search(const std::wstring& s, const std::wstring& key) +{ + int count = 0; + size_t pos = 0; + while ((pos = s.find(key, pos)) != std::wstring::npos) { + ++count; + ++pos; + } + return count; +} +void ZStringDelimiter::UpdateTokens() +{ + if (!m_vDelimiters.size() || m_sString == L"") + return; + + m_vTokens.clear(); + + + vector::iterator dIt = m_vDelimiters.begin(); + while (dIt != m_vDelimiters.end()) + { + std::wstring delimiter = *dIt; + + + DelimStr(m_sString, delimiter, true); + + + ++dIt; + } + + + +} + + +void ZStringDelimiter::DelimStr(const std::wstring & s, const std::wstring & delimiter, const bool & removeEmptyEntries) +{ + BarRange(0, s.length()); + for (size_t start = 0, end; start < s.length(); start = end + delimiter.length()) + { + size_t position = s.find(delimiter, start); + end = position != std::wstring::npos ? position : s.length(); + + std::wstring token = s.substr(start, end - start); + if (!removeEmptyEntries || !token.empty()) + { + if (token != s) + m_vTokens.push_back(token); + + } + Bar(position); + } + + // dadwwdawdaawdwadwd +} + +void ZStringDelimiter::BarRange(const int & min, const int & max) +{ +#ifdef _AFX_ALL_WARNINGS + if (PgBar) + m_pBar->SetRange32(min, max); + + +#endif +} + +void ZStringDelimiter::Bar(const int & pos) +{ +#ifdef _AFX_ALL_WARNINGS + if (PgBar) + m_pBar->SetPos(pos); + + +#endif +} + +ZStringDelimiter::ZStringDelimiter() +{ + m_sString = ""; + tokenIndex = 0; + PgBar = false; +} + + +bool ZStringDelimiter::GetFirstToken(std::wstring & in_out) +{ + if (m_vTokens.size() >= 1) { + in_out = m_vTokens[0]; + return true; + } + else { + return false; + } +} + +bool ZStringDelimiter::GetNextToken(std::wstring & in_sOut) +{ + if (tokenIndex > m_vTokens.size() - 1) + return false; + + in_sOut = m_vTokens[tokenIndex]; + ++tokenIndex; + + return true; +} + +std::wstring ZStringDelimiter::operator[](const size_t & in_index) +{ + if (in_index > m_vTokens.size()) + throw std::out_of_range("ZStringDelimiter tried to access token higher than size"); + + return m_vTokens[in_index]; + +} +std::wstring ZStringDelimiter::Reassemble(const std::wstring& delim, const int& nelem) +{ + std::wstring Result = L""; + TokenIterator RasIt = m_vTokens.begin(); + int r = 0; + if (nelem == -1) { + while (RasIt != m_vTokens.end()) + { + + if (r != 0) + Result.append(delim); + + Result.append(*RasIt); + + ++r; + + + ++RasIt; + } + } + else { + while (RasIt != m_vTokens.end() && r < nelem) + { + + if (r != 0) + Result.append(delim); + + Result.append(*RasIt); + + ++r; + ++RasIt; + } + } + + return Result; + +} + +std::wstring ZStringDelimiter::Reassemble(const std::wstring & delim, const std::vector& Strs,int nelem) +{ + std::wstring Result = L""; + TokenIterator RasIt = Strs.begin(); + int r = 0; + if (nelem == -1) { + while (RasIt != Strs.end()) + { + + if (r != 0) + Result.append(delim); + + Result.append(*RasIt); + + ++r; + + + ++RasIt; + } + } + else { + while (RasIt != Strs.end() && r < nelem) + { + + if (r != 0) + Result.append(delim); + + Result.append(*RasIt); + + ++r; + ++RasIt; + } + } + + return Result; +} + +void ZStringDelimiter::AddDelimiter(const std::wstring & in_Delim) +{ + m_vDelimiters.push_back(in_Delim); + UpdateTokens(); + +} + +void ZStringDelimiter::SetDelimiters(const std::vector &Delims) +{ + m_vDelimiters.assign(Delims.begin(),Delims.end()); + UpdateTokens(); + +} + +ZStringDelimiter::~ZStringDelimiter() +{ +} diff --git a/ext/ZCharScannerWide.h b/ext/ZCharScannerWide.h new file mode 100644 index 0000000..ec7dd36 --- /dev/null +++ b/ext/ZCharScannerWide.h @@ -0,0 +1,74 @@ +#pragma once + +#define GBasicCharScanner ZStringDelimiter + +#include +#include + +// We need ZCharScanner but for wstrings. I copy class, fastest way +typedef std::vector::const_iterator TokenIterator; + +// ZStringDelimiter +// ============== +// Simple class to delimit and split strings. +// You can use operator[] to access them +// Or you can use the itBegin() and itEnd() to get some iterators +// ================= +class ZStringDelimiter +{ +private: + int key_search(const std::wstring & s, const std::wstring & key); + void UpdateTokens(); + std::vector m_vTokens; + std::vector m_vDelimiters; + + std::wstring m_sString; + + void DelimStr(const std::wstring& s, const std::wstring& delimiter, const bool& removeEmptyEntries = false); + void BarRange(const int& min, const int& max); + void Bar(const int& pos); + size_t tokenIndex; +public: + ZStringDelimiter(); + bool PgBar; + +#ifdef _AFX_ALL_WARNINGS + CProgressCtrl* m_pBar; +#endif + + ZStringDelimiter(const std::wstring& in_iStr) { + m_sString = in_iStr; + PgBar = false; + + } + + bool GetFirstToken(std::wstring& in_out); + bool GetNextToken(std::wstring& in_sOut); + + // std::String alts + + size_t szTokens() { return m_vTokens.size(); } + std::wstring operator[](const size_t& in_index); + + std::wstring Reassemble(const std::wstring & delim, const int & nelem = -1); + + // Override to reassemble provided tokens. + std::wstring Reassemble(const std::wstring & delim, const std::vector& Strs,int nelem = -1); + + // Get a const reference to the tokens + const std::vector& GetTokens() { return m_vTokens; } + + TokenIterator itBegin() { return m_vTokens.begin(); } + TokenIterator itEnd() { return m_vTokens.end(); } + + void SetText(const std::wstring& in_Txt) { + m_sString = in_Txt; + if (m_vDelimiters.size()) + UpdateTokens(); + } + void AddDelimiter(const std::wstring& in_Delim); + void SetDelimiters(const std::vector& Delims); + + ~ZStringDelimiter(); +}; + diff --git a/mainwindow.cpp b/mainwindow.cpp index baaf715..7da3084 100644 --- a/mainwindow.cpp +++ b/mainwindow.cpp @@ -1206,7 +1206,7 @@ void MainWindow::HandleIsMultiSpeaker(size_t inVid) ArchitectureInfo Inf = CurrentVoice.GetInfo().Architecture; - if (Inf.Text2Mel == EText2MelModel::FastSpeech2 || Inf.Text2Mel == EText2MelModel::VITS || Inf.Text2Mel == EText2MelModel::VITSTM) + if (Inf.Text2Mel == EText2MelModel::FastSpeech2 || Inf.Text2Mel == EText2MelModel::VITS || Inf.Text2Mel == EText2MelModel::DEVITS) { ui->grpFs2Params->show(); @@ -1269,7 +1269,7 @@ void MainWindow::HandleIsMultiEmotion(size_t inVid) Voice& CurrentVoice = *VoMan[inVid]; - const bool TorchMojiEnabled = CurrentVoice.GetInfo().Architecture.Text2Mel == EText2MelModel::VITSTM; + const bool TorchMojiEnabled = CurrentVoice.GetInfo().Architecture.Text2Mel == EText2MelModel::DEVITS; ui->lblEmotionOvr->setVisible(TorchMojiEnabled); ui->edtEmotionOvr->setVisible(TorchMojiEnabled); From cf790eccea997ca477ee04fd57aa4b80ad789e90 Mon Sep 17 00:00:00 2001 From: ZDisket <30500847+ZDisket@users.noreply.github.com> Date: Tue, 28 Feb 2023 18:24:42 -0300 Subject: [PATCH 3/6] bert, tm, wide zcharscanner fix --- TensorVox.pro | 2 ++ VoxCommon.hpp | 2 +- bert.cpp | 27 +++++++++++++++- bert.h | 11 +++++++ berttokenizer.cpp | 4 +-- devits.cpp | 70 ++++++++++++++++++++++++++++++++++++++++ devits.h | 24 ++++++++++++++ ext/ZCharScannerWide.cpp | 30 ++++++++--------- ext/ZCharScannerWide.h | 8 ++--- torchmoji.cpp | 3 +- torchmoji.h | 4 +-- vits.h | 7 ++-- 12 files changed, 164 insertions(+), 28 deletions(-) create mode 100644 devits.cpp create mode 100644 devits.h diff --git a/TensorVox.pro b/TensorVox.pro index 7732e1b..710baaa 100644 --- a/TensorVox.pro +++ b/TensorVox.pro @@ -27,6 +27,7 @@ SOURCES += \ batchdenoisedlg.cpp \ bert.cpp \ berttokenizer.cpp \ + devits.cpp \ espeakphonemizer.cpp \ ext/ByteArr.cpp \ ext/Qt-Frameless-Window-DarkStyle-master/DarkStyle.cpp \ @@ -66,6 +67,7 @@ HEADERS += \ batchdenoisedlg.h \ bert.h \ berttokenizer.h \ + devits.h \ espeakphonemizer.h \ ext/AudioFile.hpp \ ext/ByteArr.h \ diff --git a/VoxCommon.hpp b/VoxCommon.hpp index e418175..ac42c1d 100644 --- a/VoxCommon.hpp +++ b/VoxCommon.hpp @@ -167,7 +167,7 @@ namespace VoxUtil { // Copy PyTorch tensor template - TFTensor CopyTensor(at::Tensor& InTens){ + TFTensor CopyTensor(const at::Tensor& InTens){ D* Data = InTens.data(); std::vector Shape = InTens.sizes().vec(); diff --git a/bert.cpp b/bert.cpp index 97b0258..bec1f42 100644 --- a/bert.cpp +++ b/bert.cpp @@ -1,5 +1,4 @@ #include "bert.h" - BERT::BERT() { @@ -13,6 +12,32 @@ BERT::BERT(const std::string &Path, const std::string &DictPath) void BERT::Initialize(const std::string &Path, const std::string &DictPath) { Model = torch::jit::load(Path); + Tokenizer = std::make_unique(DictPath,true); + + +} + +std::pair, TFTensor > BERT::Infer(const std::string &InText) +{ + torch::NoGradGuard no_grad; + + auto Tokens = Tokenizer->tokenize(InText); + auto Ids = Tokenizer->convertTokensToIds(Tokens); + + std::vector InTokens(Ids.begin(),Ids.end()); + + auto InIDS = torch::tensor(InTokens).unsqueeze(0); // (1, tokens) + + auto Output = Model({InIDS}).toTuple(); // (hidden states, pooled) + + + std::pair,TFTensor> BERTOutputs; + BERTOutputs.first = VoxUtil::CopyTensor(Output.get()->elements()[0].toTensor()); + BERTOutputs.second = VoxUtil::CopyTensor(Output.get()->elements()[1].toTensor()); + + + return BERTOutputs; + } diff --git a/bert.h b/bert.h index 98ab2fe..92cabeb 100644 --- a/bert.h +++ b/bert.h @@ -2,16 +2,27 @@ #define BERT_H #include "VoxCommon.hpp" + +#include "berttokenizer.h" + // BERT: Class for inference of TorchScript-exported BERT. class BERT { private: torch::jit::script::Module Model; + std::unique_ptr Tokenizer; public: BERT(); BERT(const std::string& Path,const std::string& DictPath); void Initialize(const std::string& Path,const std::string& DictPath); + + + // Do inference on BERT model. + // Returns 2 tensors: + // [1, tokens, channels] : Hidden states + // [1, channels] : Pooled embeddings + std::pair,TFTensor> Infer(const std::string& InText); }; #endif // BERT_H diff --git a/berttokenizer.cpp b/berttokenizer.cpp index 94a3fc7..5cd2019 100644 --- a/berttokenizer.cpp +++ b/berttokenizer.cpp @@ -33,7 +33,7 @@ static std::wstring strip(const std::wstring& text) { static std::vector split(const std::wstring& text) { std::vector result; - ZStringDelimiter Del(text); + ZStringDelimiterWide Del(text); Del.SetDelimiters(sDelimChar); result = Del.GetTokens(); @@ -215,7 +215,7 @@ std::vector BasicTokenizer::tokenize(const std::string& text) cons const auto& tokens = runSplitOnPunc(token); splitTokens.insert(splitTokens.end(), tokens.begin(), tokens.end()); } - ZStringDelimiter DelRs; + ZStringDelimiterWide DelRs; return whitespaceTokenize(DelRs.Reassemble(L" ",splitTokens)); } diff --git a/devits.cpp b/devits.cpp new file mode 100644 index 0000000..7034259 --- /dev/null +++ b/devits.cpp @@ -0,0 +1,70 @@ +#include "devits.h" + +DEVITS::DEVITS() +{ + +} + +TFTensor DEVITS::DoInferenceDE(const std::vector &InputIDs, const TFTensor &MojiIn, const TFTensor &BERTIn, const std::vector &ArgsFloat, const std::vector ArgsInt, int32_t SpeakerID, int32_t EmotionID) +{ + // without this memory consumption is 4x + torch::NoGradGuard no_grad; + + + + std::vector PaddedIDs; + + + PaddedIDs = ZeroPadVec(InputIDs); + + + std::vector inLen = { (int64_t)PaddedIDs.size() }; + + + // ZDisket: Is this really necessary? + torch::TensorOptions Opts = torch::TensorOptions().requires_grad(false); + + auto InIDS = torch::tensor(PaddedIDs, Opts).unsqueeze(0); + auto InLens = torch::tensor(inLen, Opts); + auto MojiHidden = torch::tensor(MojiIn.Data); + auto BERTHidden = torch::tensor(BERTIn.Data).reshape(BERTIn.Shape); + + std::vector BERTSz = {BERTIn.Shape[1]}; + auto BERTLens = torch::tensor(BERTSz); + + auto InLenScale = torch::tensor({ ArgsFloat[0]}, Opts); + + + + std::vector inputs{ InIDS,InLens, MojiHidden, BERTHidden, BERTLens, InLenScale }; + + if (SpeakerID != -1){ + auto InSpkid = torch::tensor({SpeakerID},Opts); + inputs.push_back(InSpkid); + } + + + + + + // Infer + + c10::IValue Output = Model.get_method("infer_ts")(inputs); + + // Output = tuple (audio,att) + + auto OutputT = Output.toTuple(); + + // Grab audio + // [1, frames] -> [frames] + auto AuTens = OutputT.get()->elements()[0].toTensor().squeeze(); + + // Grab Attention + // [1, 1, x, y] -> [x, y] -> [y,x] -> [1, y, x] + auto AttTens = OutputT.get()->elements()[1].toTensor().squeeze().transpose(0,1).unsqueeze(0); + + Attention = VoxUtil::CopyTensor(AttTens); + + return VoxUtil::CopyTensor(AuTens); + +} diff --git a/devits.h b/devits.h new file mode 100644 index 0000000..0b7c51a --- /dev/null +++ b/devits.h @@ -0,0 +1,24 @@ +#ifndef DEVITS_H +#define DEVITS_H +#include "vits.h" + +class DEVITS : public VITS +{ +public: + DEVITS(); + + /* + Do inference on a DE-VITS model. + + -> InputIDs: Input IDs of tokens for inference + -> SpeakerID: ID of the speaker in the model to do inference on. If single speaker, always leave at 0. If multispeaker, refer to your model. + -> MojiIn: TorchMoji hidden states size [tm] + -> BERTIn: BERT hidden states size [1, n_tokens, channels] + -> ArgsFloat[0]: Length scale. + + <- Returns: TFTensor with shape {frames} of audio data + */ + TFTensor DoInferenceDE(const std::vector& InputIDs, const TFTensor& MojiIn, const TFTensor& BERTIn,const std::vector& ArgsFloat,const std::vector ArgsInt, int32_t SpeakerID = 0, int32_t EmotionID = -1); +}; + +#endif // DEVITS_H diff --git a/ext/ZCharScannerWide.cpp b/ext/ZCharScannerWide.cpp index 967918d..bcd22fb 100644 --- a/ext/ZCharScannerWide.cpp +++ b/ext/ZCharScannerWide.cpp @@ -2,7 +2,7 @@ using namespace std; #include -int ZStringDelimiter::key_search(const std::wstring& s, const std::wstring& key) +int ZStringDelimiterWide::key_search(const std::wstring& s, const std::wstring& key) { int count = 0; size_t pos = 0; @@ -12,7 +12,7 @@ int ZStringDelimiter::key_search(const std::wstring& s, const std::wstring& key) } return count; } -void ZStringDelimiter::UpdateTokens() +void ZStringDelimiterWide::UpdateTokens() { if (!m_vDelimiters.size() || m_sString == L"") return; @@ -37,7 +37,7 @@ void ZStringDelimiter::UpdateTokens() } -void ZStringDelimiter::DelimStr(const std::wstring & s, const std::wstring & delimiter, const bool & removeEmptyEntries) +void ZStringDelimiterWide::DelimStr(const std::wstring & s, const std::wstring & delimiter, const bool & removeEmptyEntries) { BarRange(0, s.length()); for (size_t start = 0, end; start < s.length(); start = end + delimiter.length()) @@ -58,7 +58,7 @@ void ZStringDelimiter::DelimStr(const std::wstring & s, const std::wstring & del // dadwwdawdaawdwadwd } -void ZStringDelimiter::BarRange(const int & min, const int & max) +void ZStringDelimiterWide::BarRange(const int & min, const int & max) { #ifdef _AFX_ALL_WARNINGS if (PgBar) @@ -68,7 +68,7 @@ void ZStringDelimiter::BarRange(const int & min, const int & max) #endif } -void ZStringDelimiter::Bar(const int & pos) +void ZStringDelimiterWide::Bar(const int & pos) { #ifdef _AFX_ALL_WARNINGS if (PgBar) @@ -78,15 +78,15 @@ void ZStringDelimiter::Bar(const int & pos) #endif } -ZStringDelimiter::ZStringDelimiter() +ZStringDelimiterWide::ZStringDelimiterWide() { - m_sString = ""; + m_sString = L""; tokenIndex = 0; PgBar = false; } -bool ZStringDelimiter::GetFirstToken(std::wstring & in_out) +bool ZStringDelimiterWide::GetFirstToken(std::wstring & in_out) { if (m_vTokens.size() >= 1) { in_out = m_vTokens[0]; @@ -97,7 +97,7 @@ bool ZStringDelimiter::GetFirstToken(std::wstring & in_out) } } -bool ZStringDelimiter::GetNextToken(std::wstring & in_sOut) +bool ZStringDelimiterWide::GetNextToken(std::wstring & in_sOut) { if (tokenIndex > m_vTokens.size() - 1) return false; @@ -108,7 +108,7 @@ bool ZStringDelimiter::GetNextToken(std::wstring & in_sOut) return true; } -std::wstring ZStringDelimiter::operator[](const size_t & in_index) +std::wstring ZStringDelimiterWide::operator[](const size_t & in_index) { if (in_index > m_vTokens.size()) throw std::out_of_range("ZStringDelimiter tried to access token higher than size"); @@ -116,7 +116,7 @@ std::wstring ZStringDelimiter::operator[](const size_t & in_index) return m_vTokens[in_index]; } -std::wstring ZStringDelimiter::Reassemble(const std::wstring& delim, const int& nelem) +std::wstring ZStringDelimiterWide::Reassemble(const std::wstring& delim, const int& nelem) { std::wstring Result = L""; TokenIterator RasIt = m_vTokens.begin(); @@ -154,7 +154,7 @@ std::wstring ZStringDelimiter::Reassemble(const std::wstring& delim, const int& } -std::wstring ZStringDelimiter::Reassemble(const std::wstring & delim, const std::vector& Strs,int nelem) +std::wstring ZStringDelimiterWide::Reassemble(const std::wstring & delim, const std::vector& Strs,int nelem) { std::wstring Result = L""; TokenIterator RasIt = Strs.begin(); @@ -191,20 +191,20 @@ std::wstring ZStringDelimiter::Reassemble(const std::wstring & delim, const std: return Result; } -void ZStringDelimiter::AddDelimiter(const std::wstring & in_Delim) +void ZStringDelimiterWide::AddDelimiter(const std::wstring & in_Delim) { m_vDelimiters.push_back(in_Delim); UpdateTokens(); } -void ZStringDelimiter::SetDelimiters(const std::vector &Delims) +void ZStringDelimiterWide::SetDelimiters(const std::vector &Delims) { m_vDelimiters.assign(Delims.begin(),Delims.end()); UpdateTokens(); } -ZStringDelimiter::~ZStringDelimiter() +ZStringDelimiterWide::~ZStringDelimiterWide() { } diff --git a/ext/ZCharScannerWide.h b/ext/ZCharScannerWide.h index ec7dd36..47dcad0 100644 --- a/ext/ZCharScannerWide.h +++ b/ext/ZCharScannerWide.h @@ -14,7 +14,7 @@ typedef std::vector::const_iterator TokenIterator; // You can use operator[] to access them // Or you can use the itBegin() and itEnd() to get some iterators // ================= -class ZStringDelimiter +class ZStringDelimiterWide { private: int key_search(const std::wstring & s, const std::wstring & key); @@ -29,14 +29,14 @@ class ZStringDelimiter void Bar(const int& pos); size_t tokenIndex; public: - ZStringDelimiter(); + ZStringDelimiterWide(); bool PgBar; #ifdef _AFX_ALL_WARNINGS CProgressCtrl* m_pBar; #endif - ZStringDelimiter(const std::wstring& in_iStr) { + ZStringDelimiterWide(const std::wstring& in_iStr) { m_sString = in_iStr; PgBar = false; @@ -69,6 +69,6 @@ class ZStringDelimiter void AddDelimiter(const std::wstring& in_Delim); void SetDelimiters(const std::vector& Delims); - ~ZStringDelimiter(); + ~ZStringDelimiterWide(); }; diff --git a/torchmoji.cpp b/torchmoji.cpp index c1f5b82..55b1161 100644 --- a/torchmoji.cpp +++ b/torchmoji.cpp @@ -67,6 +67,7 @@ void TorchMoji::Initialize(const std::string &Path, const std::string &DictPath) std::vector TorchMoji::Infer(const std::vector &Seq) { + torch::NoGradGuard no_grad; std::vector Input = WordsToIDs(Seq); auto InIDS = torch::tensor(Input).unsqueeze(0); // (1, TMLen) @@ -78,7 +79,7 @@ std::vector TorchMoji::Infer(const std::vector &Seq) TFTensor Tens = VoxUtil::CopyTensor(Output); - return Tens.Data; + return Tens; diff --git a/torchmoji.h b/torchmoji.h index 9bd0833..d10ddcc 100644 --- a/torchmoji.h +++ b/torchmoji.h @@ -25,8 +25,8 @@ class TorchMoji // Return hidden states of emotion state. // -> Seq: Vector of words - // <- Returns float vec of size VoxCommon::TorchMojiEmbSize containing hidden states, ready to feed into TTS model. - std::vector Infer(const std::vector& Seq); + // <- Returns float tensor of size VoxCommon::TorchMojiEmbSize containing hidden states, ready to feed into TTS model. + TFTensor Infer(const std::vector& Seq); }; #endif // TORCHMOJI_H diff --git a/vits.h b/vits.h index d2fc766..a7f1cee 100644 --- a/vits.h +++ b/vits.h @@ -12,12 +12,15 @@ class VITS : public MelGen { private: - torch::jit::script::Module Model; + + +public: + torch::jit::script::Module Model; // Most VITS model require zero-interspersed input IDs std::vector ZeroPadVec(const std::vector& InIDs); -public: + TFTensor Attention; VITS(); From fae4472b4616ac554c460f041fac29c7d4a0dc78 Mon Sep 17 00:00:00 2001 From: ZDisket <30500847+ZDisket@users.noreply.github.com> Date: Tue, 28 Feb 2023 18:40:52 -0300 Subject: [PATCH 4/6] final adds --- Voice.cpp | 41 +++++++++++++++++++++++++++++++---------- Voice.h | 4 ++++ torchmoji.cpp | 2 +- 3 files changed, 36 insertions(+), 11 deletions(-) diff --git a/Voice.cpp b/Voice.cpp index 0d0968e..37125af 100644 --- a/Voice.cpp +++ b/Voice.cpp @@ -126,8 +126,10 @@ Voice::Voice(const std::string & VoxPath, const std::string &inName, Phonemizer MelPredictor = std::make_unique(); else if (Tex2MelArch == EText2MelModel::FastSpeech2) MelPredictor = std::make_unique(); - else if (Tex2MelArch == EText2MelModel::VITS || Tex2MelArch == EText2MelModel::DEVITS) + else if (Tex2MelArch == EText2MelModel::VITS) MelPredictor = std::make_unique(); + else if (Tex2MelArch == EText2MelModel::DEVITS) + MelPredictor = std::make_unique(); else MelPredictor = std::make_unique(); @@ -144,8 +146,12 @@ Voice::Voice(const std::string & VoxPath, const std::string &inName, Phonemizer - if (Tex2MelArch == EText2MelModel::DEVITS) + if (Tex2MelArch == EText2MelModel::DEVITS){ Moji.Initialize(VoxPath + "/moji.pt", VoxPath + "/tm_dict.txt"); + BertFE.Initialize(VoxPath + "/bert.pt", VoxPath + "/bert_vocab.txt"); + + } + const int32_t VocoderArch = VoxInfo.Architecture.Vocoder; @@ -286,20 +292,36 @@ VoxResults Voice::Vocalize(const std::string & Prompt, float Speed, int32_t Spea Mel = ((FastSpeech2*)MelPredictor.get())->DoInference(InputIDs,FloatArgs,IntArgs,SpeakerID, EmotionID); - }else + }else if (Text2MelN == EText2MelModel::VITS) { FloatArgs = {Speed}; - if (EmotionOvr.size()){ - std::vector MojiInput = Processor.GetTokenizer().Tokenize(EmotionOvr,true,true); - std::vector MojiStates = Moji.Infer(MojiInput); + TFTensor Audio = MelPredictor.get()->DoInference(InputIDs,FloatArgs,IntArgs,SpeakerID,EmotionID); + Attention = ((VITS*)MelPredictor.get())->Attention; - FloatArgs.insert(FloatArgs.end(),MojiStates.begin(),MojiStates.end()); + std::vector AudioData = Audio.Data; - } + Mel.Shape.push_back(-1); // Tell the plotter that we have no mel to plot - TFTensor Audio = MelPredictor.get()->DoInference(InputIDs,FloatArgs,IntArgs,SpeakerID,EmotionID); + // As VITS is fully E2E, we return here + + return {AudioData,Attention,Mel}; + + }else // DE-VITS + { + FloatArgs = {Speed}; + std::vector MojiInput = Processor.GetTokenizer().Tokenize(EmotionOvr,true,true); + TFTensor MojiStates = Moji.Infer(MojiInput); + + auto BERTOutputs = BertFE.Infer(Prompt); + + + + + TFTensor Audio = ((DEVITS*)MelPredictor.get())->DoInferenceDE(InputIDs, MojiStates, + BERTOutputs.first,FloatArgs, + IntArgs,SpeakerID,EmotionID); Attention = ((VITS*)MelPredictor.get())->Attention; std::vector AudioData = Audio.Data; @@ -309,7 +331,6 @@ VoxResults Voice::Vocalize(const std::string & Prompt, float Speed, int32_t Spea // As VITS is fully E2E, we return here return {AudioData,Attention,Mel}; - } // Vocoder inference diff --git a/Voice.h b/Voice.h index 68bfd8f..6f3f1ea 100644 --- a/Voice.h +++ b/Voice.h @@ -10,6 +10,9 @@ #include "phoneticdict.h" #include "tacotron2torch.h" #include "istftnettorch.h" +#include "devits.h" +#include "bert.h" + struct VoxResults{ std::vector Audio; TFTensor Alignment; @@ -24,6 +27,7 @@ class Voice EnglishPhoneticProcessor Processor; VoiceInfo VoxInfo; TorchMoji Moji; + BERT BertFE; diff --git a/torchmoji.cpp b/torchmoji.cpp index 55b1161..fc90531 100644 --- a/torchmoji.cpp +++ b/torchmoji.cpp @@ -65,7 +65,7 @@ void TorchMoji::Initialize(const std::string &Path, const std::string &DictPath) LoadDict(DictPath); } -std::vector TorchMoji::Infer(const std::vector &Seq) +TFTensor TorchMoji::Infer(const std::vector &Seq) { torch::NoGradGuard no_grad; std::vector Input = WordsToIDs(Seq); From 3e51cf9cc6251b2ed727efd2923a7b810c0a3676 Mon Sep 17 00:00:00 2001 From: ZDisket <30500847+ZDisket@users.noreply.github.com> Date: Tue, 28 Feb 2023 19:35:17 -0300 Subject: [PATCH 5/6] final fixes to bert, model works now --- bert.cpp | 34 +++++++++++++++++++++++++++++----- berttokenizer.cpp | 19 +++++++++++++++++-- devits.cpp | 3 +-- 3 files changed, 47 insertions(+), 9 deletions(-) diff --git a/bert.cpp b/bert.cpp index bec1f42..9a43ae1 100644 --- a/bert.cpp +++ b/bert.cpp @@ -1,4 +1,6 @@ #include "bert.h" +#include + BERT::BERT() { @@ -21,19 +23,41 @@ std::pair, TFTensor > BERT::Infer(const std::string &InTe { torch::NoGradGuard no_grad; - auto Tokens = Tokenizer->tokenize(InText); + auto Tokens = Tokenizer->tokenize(InText + "\n"); auto Ids = Tokenizer->convertTokensToIds(Tokens); std::vector InTokens(Ids.begin(),Ids.end()); + + auto InIDS = torch::tensor(InTokens).unsqueeze(0); // (1, tokens) + std::pair,TFTensor> BERTOutputs; + + try{ + auto Output = Model({InIDS}).toTuple(); // (hidden states, pooled) + BERTOutputs.first = VoxUtil::CopyTensor(Output.get()->elements()[0].toTensor()); + BERTOutputs.second = VoxUtil::CopyTensor(Output.get()->elements()[1].toTensor()); + + + + } + + catch (const std::exception& e) { + int msgboxID = MessageBox( + NULL, + (LPCWSTR)QString::fromStdString(e.what()).toStdWString().c_str(), + (LPCWSTR)L"Error1!!", + MB_ICONWARNING | MB_CANCELTRYCONTINUE | MB_DEFBUTTON2 + ); + + + return BERTOutputs; + + } + - auto Output = Model({InIDS}).toTuple(); // (hidden states, pooled) - std::pair,TFTensor> BERTOutputs; - BERTOutputs.first = VoxUtil::CopyTensor(Output.get()->elements()[0].toTensor()); - BERTOutputs.second = VoxUtil::CopyTensor(Output.get()->elements()[1].toTensor()); return BERTOutputs; diff --git a/berttokenizer.cpp b/berttokenizer.cpp index 5cd2019..2befdb5 100644 --- a/berttokenizer.cpp +++ b/berttokenizer.cpp @@ -32,11 +32,20 @@ static std::wstring strip(const std::wstring& text) { } static std::vector split(const std::wstring& text) { + + + std::vector result; ZStringDelimiterWide Del(text); - Del.SetDelimiters(sDelimChar); + + Del.AddDelimiter(L" "); result = Del.GetTokens(); + if (!result.size()) + result.push_back(text); // + + + return result; } @@ -90,6 +99,7 @@ static std::shared_ptr loadVocab(const std::string& vocabFile) { (*vocab)[token] = index; index++; } + return vocab; } @@ -216,8 +226,12 @@ std::vector BasicTokenizer::tokenize(const std::string& text) cons splitTokens.insert(splitTokens.end(), tokens.begin(), tokens.end()); } ZStringDelimiterWide DelRs; + std::wstring WSP = DelRs.Reassemble(L" ",splitTokens); + - return whitespaceTokenize(DelRs.Reassemble(L" ",splitTokens)); + + + return whitespaceTokenize(WSP); } WordpieceTokenizer::WordpieceTokenizer(const std::shared_ptr vocab, const std::wstring& unkToken, size_t maxInputCharsPerWord) @@ -228,6 +242,7 @@ WordpieceTokenizer::WordpieceTokenizer(const std::shared_ptr vocab, const std::vector WordpieceTokenizer::tokenize(const std::wstring& text) const { std::vector outputTokens; + for (auto& token : whitespaceTokenize(text)) { if (token.size() > mMaxInputCharsPerWord) { outputTokens.push_back(mUnkToken); diff --git a/devits.cpp b/devits.cpp index 7034259..ea9e20f 100644 --- a/devits.cpp +++ b/devits.cpp @@ -1,5 +1,4 @@ #include "devits.h" - DEVITS::DEVITS() { @@ -26,7 +25,7 @@ TFTensor DEVITS::DoInferenceDE(const std::vector &InputIDs, cons auto InIDS = torch::tensor(PaddedIDs, Opts).unsqueeze(0); auto InLens = torch::tensor(inLen, Opts); - auto MojiHidden = torch::tensor(MojiIn.Data); + auto MojiHidden = torch::tensor(MojiIn.Data).unsqueeze(0); auto BERTHidden = torch::tensor(BERTIn.Data).reshape(BERTIn.Shape); std::vector BERTSz = {BERTIn.Shape[1]}; From dfb872a555e1c3b7004b8690dd61b700e896a76e Mon Sep 17 00:00:00 2001 From: ZDisket <30500847+ZDisket@users.noreply.github.com> Date: Tue, 28 Feb 2023 20:09:29 -0300 Subject: [PATCH 6/6] up max limit to 5k --- mainwindow.ui | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mainwindow.ui b/mainwindow.ui index 7d8d9e9..d69c0ae 100644 --- a/mainwindow.ui +++ b/mainwindow.ui @@ -7,7 +7,7 @@ 0 0 1047 - 526 + 538 @@ -324,7 +324,7 @@ 20 - 1000 + 5000 10