diff --git a/README.md b/README.md index b758dbe..3c64300 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,26 @@ sqlite> select simple_highlight(t1, 0, '[', ']') as text from t1 where text matc 5. simple_snippet() 实现截取 match 片段的功能,与 sqlite 自带的 snippet 功能类似,同样是增强连续 match 的词汇分到同一组的逻辑 6. jieba_query() 实现jieba分词的效果,在索引不变的情况下,可以实现更精准的匹配。可以通过 `-DSIMPLE_WITH_JIEBA=OFF ` 关掉结巴分词的功能 [#35](https://github.com/wangfenjin/simple/pull/35) 7. jieba_dict() 指定 dict 的目录,只需要调用一次,需要在调用 jieba_query() 之前指定。 +8. pinyin_dict() 支持指定自定义的 `pinyin.txt` 文件路径。调用成功后会立即切换拼音映射;如果文件格式不正确,会返回错误并保持当前映射不变。 + +### 自定义 pinyin.txt + +默认会使用内置在 so 中的 `contrib/pinyin.txt`。如果希望使用自己的拼音表,可以在查询前调用: + +```sql +select pinyin_dict('/path/to/pinyin.txt'); +``` + +`pinyin.txt` 每行格式与默认文件一致,例如: + +```text +U+3007: líng,yuán,xīng +U+3007: líng,yuán,xīng # 行尾注释也支持(前面需要空格) +``` + +注意: +- 建议在建索引和查询前先调用一次 `pinyin_dict()`。 +- 如果替换了拼音映射,已有索引中的拼音 token 不会自动重建,需要按你的业务策略重建索引。 ## 开发 diff --git a/contrib/pinyin-mini.txt b/contrib/pinyin-mini.txt new file mode 100644 index 0000000..720a792 --- /dev/null +++ b/contrib/pinyin-mini.txt @@ -0,0 +1,3 @@ +# demo custom pinyin dictionary for example.sql +U+5468: zhōu # 周 +U+4F26: lún # 伦 diff --git a/example.sql b/example.sql index f2d7cbf..84b3c01 100644 --- a/example.sql +++ b/example.sql @@ -3,6 +3,19 @@ -- load so file .load libsimple +select '自定义拼音词典示例(只包含 周/伦 的拼音):'; +-- 在本仓库里从 output/bin 运行本例时,路径可使用 ../../contrib/pinyin-mini.txt +select pinyin_dict('../../contrib/pinyin-mini.txt'); +CREATE VIRTUAL TABLE t0 USING fts5(x, tokenize = 'simple'); +insert into t0(x) values ('周杰伦'); +select ' 搜索 zhou,命中数量(预期 1):', count(*) from t0 where x match simple_query('zhou'); +select ' 搜索 lun,命中数量(预期 1):', count(*) from t0 where x match simple_query('lun'); +select ' 搜索 jie,命中数量(预期 0):', count(*) from t0 where x match simple_query('jie'); +select ' 搜索 zhou lun,命中数量(预期 1):', count(*) from t0 where x match simple_query('zhou lun'); +drop table t0; +-- 切回默认词典,保证下面示例行为不变 +select pinyin_dict('../../contrib/pinyin.txt'); + select '启用拼音分词:'; -- set tokenize to simple CREATE VIRTUAL TABLE t1 USING fts5(x, tokenize = 'simple'); diff --git a/src/entry.cc b/src/entry.cc index e203f89..c365eab 100644 --- a/src/entry.cc +++ b/src/entry.cc @@ -97,6 +97,23 @@ static void simple_query(sqlite3_context *pCtx, int nVal, sqlite3_value **apVal) sqlite3_result_null(pCtx); } +static void pinyin_dict(sqlite3_context *pCtx, int nVal, sqlite3_value **apVal) { + if (nVal >= 1) { + const char *text = (const char *)sqlite3_value_text(apVal[0]); + if (text) { + std::string err; + std::string path(text); + if (simple_tokenizer::SimpleTokenizer::set_pinyin_dict(path, err)) { + sqlite3_result_text(pCtx, path.c_str(), -1, SQLITE_TRANSIENT); + } else { + sqlite3_result_error(pCtx, err.c_str(), -1); + } + return; + } + } + sqlite3_result_null(pCtx); +} + int sqlite3_simple_init(sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi) { (void)pzErrMsg; int rc = SQLITE_OK; @@ -104,6 +121,8 @@ int sqlite3_simple_init(sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines rc = sqlite3_create_function(db, "simple_query", -1, SQLITE_UTF8 | SQLITE_DETERMINISTIC, NULL, &simple_query, NULL, NULL); + rc = sqlite3_create_function(db, "pinyin_dict", 1, SQLITE_UTF8 | SQLITE_DETERMINISTIC, NULL, &pinyin_dict, NULL, + NULL); #ifdef USE_JIEBA rc = sqlite3_create_function(db, "jieba_query", -1, SQLITE_UTF8 | SQLITE_DETERMINISTIC, NULL, &jieba_query, NULL, NULL); diff --git a/src/pinyin.cc b/src/pinyin.cc index e63ab86..47602c0 100644 --- a/src/pinyin.cc +++ b/src/pinyin.cc @@ -1,6 +1,7 @@ #include "pinyin.h" #include +#include #include #include #include @@ -12,7 +13,9 @@ CMRC_DECLARE(pinyin_text); namespace simple_tokenizer { -PinYin::PinYin() { pinyin = build_pinyin_map(); } +PinYin::PinYin() : PinYin("") {} + +PinYin::PinYin(const std::string &pinyin_file_path) { pinyin = build_pinyin_map(pinyin_file_path); } std::set PinYin::to_plain(const std::string &input) { std::set s; @@ -49,21 +52,45 @@ std::set PinYin::to_plain(const std::string &input) { } // clang-format off -std::map > PinYin::build_pinyin_map() { +std::map > PinYin::build_pinyin_map(const std::string &pinyin_file_path) { std::map > map; // clang-format on - auto fs = cmrc::pinyin_text::get_filesystem(); - auto pinyin_data = fs.open("contrib/pinyin.txt"); - std::istringstream pinyin_file(std::string(pinyin_data.begin(), pinyin_data.end())); + std::istringstream embedded_pinyin_file; + std::ifstream custom_pinyin_file; + std::istream *pinyin_file = nullptr; + if (pinyin_file_path.empty()) { + auto fs = cmrc::pinyin_text::get_filesystem(); + auto pinyin_data = fs.open("contrib/pinyin.txt"); + embedded_pinyin_file = std::istringstream(std::string(pinyin_data.begin(), pinyin_data.end())); + pinyin_file = &embedded_pinyin_file; + } else { + custom_pinyin_file.open(pinyin_file_path); + if (!custom_pinyin_file.is_open()) { + throw std::runtime_error("failed to open pinyin file: " + pinyin_file_path); + } + pinyin_file = &custom_pinyin_file; + } std::string line; char delimiter = ' '; std::string cp, py; - while (std::getline(pinyin_file, line)) { + int line_no = 0; + while (std::getline(*pinyin_file, line)) { + ++line_no; if (line.length() == 0 || line[0] == '#') continue; std::stringstream tokenStream(line); std::getline(tokenStream, cp, delimiter); std::getline(tokenStream, py, delimiter); - int codepoint = static_cast(std::stoul(cp.substr(2, cp.length() - 3), 0, 16l)); + if (cp.length() < 4 || cp.rfind("U+", 0) != 0 || cp.back() != ':' || py.empty()) { + throw std::runtime_error("invalid pinyin format at line " + std::to_string(line_no)); + } + + int codepoint = 0; + try { + codepoint = static_cast(std::stoul(cp.substr(2, cp.length() - 3), 0, 16l)); + } catch (const std::exception &) { + throw std::runtime_error("invalid pinyin codepoint at line " + std::to_string(line_no)); + } + std::set s = to_plain(py); std::vector m(s.size()); std::copy(s.begin(), s.end(), m.begin()); diff --git a/src/pinyin.h b/src/pinyin.h index 514c226..d7317bc 100644 --- a/src/pinyin.h +++ b/src/pinyin.h @@ -106,7 +106,7 @@ class PinYin { }; // clang-format on std::set to_plain(const std::string &input); - std::map > build_pinyin_map(); + std::map > build_pinyin_map(const std::string &pinyin_file_path); static int codepoint(const std::string &u); std::vector _split_pinyin(const std::string &input, int begin, int end); @@ -115,6 +115,7 @@ class PinYin { static int get_str_len(unsigned char byte); std::set split_pinyin(const std::string &input); PinYin(); + explicit PinYin(const std::string &pinyin_file_path); }; } // namespace simple_tokenizer diff --git a/src/simple_tokenizer.cc b/src/simple_tokenizer.cc index 3d7bb7d..22f2647 100644 --- a/src/simple_tokenizer.cc +++ b/src/simple_tokenizer.cc @@ -3,20 +3,48 @@ #include #include #include +#include +#include #include #include #include namespace simple_tokenizer { +namespace { +std::mutex pinyin_mutex; +std::shared_ptr global_pinyin; +} + SimpleTokenizer::SimpleTokenizer(const char **azArg, int nArg) { if (nArg >= 1) { enable_pinyin = atoi(azArg[0]) != 0; } } -PinYin *SimpleTokenizer::get_pinyin() { - static auto *py = new PinYin(); - return py; +std::shared_ptr SimpleTokenizer::get_pinyin() { + std::lock_guard lock(pinyin_mutex); + if (global_pinyin == nullptr) { + global_pinyin = std::make_shared(); + } + return global_pinyin; +} + +bool SimpleTokenizer::set_pinyin_dict(const std::string &pinyin_file_path, std::string &err) { + std::shared_ptr new_pinyin; + try { + if (pinyin_file_path.empty()) { + new_pinyin = std::make_shared(); + } else { + new_pinyin = std::make_shared(pinyin_file_path); + } + } catch (const std::exception &e) { + err = e.what(); + return false; + } + + std::lock_guard lock(pinyin_mutex); + global_pinyin = new_pinyin; + return true; } static TokenCategory from_char(char c) { @@ -159,7 +187,8 @@ int SimpleTokenizer::tokenize(void *pCtx, int flags, const char *text, int textL rc = xToken(pCtx, 0, result.c_str(), (int)result.length(), start, index); if (enable_pinyin && category == TokenCategory::OTHER && (flags & FTS5_TOKENIZE_DOCUMENT)) { - const std::vector &pys = SimpleTokenizer::get_pinyin()->get_pinyin(result); + std::shared_ptr pinyin = SimpleTokenizer::get_pinyin(); + const std::vector &pys = pinyin->get_pinyin(result); for (const std::string &s : pys) { rc = xToken(pCtx, FTS5_TOKEN_COLOCATED, s.c_str(), (int)s.length(), start, index); } diff --git a/src/simple_tokenizer.h b/src/simple_tokenizer.h index bea697a..9c99422 100644 --- a/src/simple_tokenizer.h +++ b/src/simple_tokenizer.h @@ -26,13 +26,14 @@ enum class TokenCategory { class SimpleTokenizer { private: - static PinYin *get_pinyin(); + static std::shared_ptr get_pinyin(); bool enable_pinyin = true; public: SimpleTokenizer(const char **zaArg, int nArg); int tokenize(void *pCtx, int flags, const char *text, int textLen, xTokenFn xToken) const; static std::string tokenize_query(const char *text, int textLen, int flags = 1); + static bool set_pinyin_dict(const std::string &pinyin_file_path, std::string &err); #ifdef USE_JIEBA static std::string tokenize_jieba_query(const char *text, int textLen, int flags = 1); #endif diff --git a/test/pinyin_test.cc b/test/pinyin_test.cc index 2c74e4e..2a57025 100644 --- a/test/pinyin_test.cc +++ b/test/pinyin_test.cc @@ -1,5 +1,9 @@ #include "pinyin.h" +#include +#include +#include + #include "gtest/gtest.h" using namespace simple_tokenizer; @@ -15,3 +19,35 @@ TEST(simple, pinyin_split) { for (auto r : res) std::cout << r << "\t"; std::cout << std::endl; } + +TEST(simple, pinyin_custom_file) { + std::string path = "simple_custom_pinyin_test.txt"; + std::ofstream file(path); + ASSERT_TRUE(file.is_open()); + file << "# custom pinyin file\n"; + file << "U+6770: jié # trailing comment\n"; + file.close(); + + PinYin pinyin(path); + auto res = pinyin.get_pinyin("杰"); + ASSERT_EQ(res.size(), 2); + ASSERT_EQ(res[0], "j"); + ASSERT_EQ(res[1], "jie"); + std::remove(path.c_str()); +} + +TEST(simple, pinyin_invalid_custom_file) { + std::string path = "simple_invalid_pinyin_test.txt"; + std::ofstream file(path); + ASSERT_TRUE(file.is_open()); + file << "invalid line\n"; + file.close(); + + EXPECT_THROW( + { + PinYin pinyin(path); + (void)pinyin; + }, + std::runtime_error); + std::remove(path.c_str()); +}