From 9d72facaba807555c1f0ffe1704b26a450f62dc1 Mon Sep 17 00:00:00 2001 From: yujing Date: Fri, 28 Feb 2025 17:47:09 +0800 Subject: [PATCH 01/32] backup --- llm_web_kit/model/rule_based_safety_module.py | 109 +++++++++++++++ llm_web_kit/model/source_safety_detector.py | 37 +++++ .../model/test_rule_based_safety_module.py | 128 ++++++++++++++++++ 3 files changed, 274 insertions(+) create mode 100644 llm_web_kit/model/rule_based_safety_module.py create mode 100644 llm_web_kit/model/source_safety_detector.py create mode 100644 tests/llm_web_kit/model/test_rule_based_safety_module.py diff --git a/llm_web_kit/model/rule_based_safety_module.py b/llm_web_kit/model/rule_based_safety_module.py new file mode 100644 index 00000000..31312db7 --- /dev/null +++ b/llm_web_kit/model/rule_based_safety_module.py @@ -0,0 +1,109 @@ +from typing import Any, Type + +from llm_web_kit.model.unsafe_words_detector import UnsafeWordsFilter +from llm_web_kit.model.source_safety_detector import SourceFilter + + +def check_type(arg_name: str, arg_value: Any, arg_type: Type): + """check the type of the argument and raise TypeError if the type is not + matched.""" + if not isinstance(arg_value, arg_type): + # TODO change TypeError to custom exception + raise TypeError( + 'The type of {} should be {}, but got {}'.format( + arg_name, arg_type, type(arg_value) + ) + ) + + +class RuleBasedSafetyModuleDataPack: + """The data pack for the rule-based-safety module.""" + + def __init__( + self, + content_str: str, + language: str, + language_details: str, + content_style: str, + ): + + # the content of the dataset + check_type('content_str', content_str, str) + self.content_str = content_str + + # the language of the content + check_type('language', language, str) + self.language = language + + # the details of the language + check_type('language_details', language_details, str) + self.language_details = language_details + + # the content style of the content + check_type('content_style', content_style, str) + self.content_style = content_style + + # the flag of the processed data should be remained or not + self.clean_remained = True + # the details of the clean process + self.clean_infos = {} + + def set_process_result(self, clean_remained: bool, clean_infos: dict) -> None: + """set the process result of the clean module.""" + check_type('clean_remained', clean_remained, bool) + check_type('clean_infos', clean_infos, dict) + if clean_remained is False: + self.clean_remained = False + self.clean_infos.update(clean_infos) + + def get_output(self) -> dict: + """get the output of the data pack.""" + return { + 'clean_remained': self.clean_remained, + 'clean_infos': self.clean_infos, + } + + +class RuleBasedSafetyModule: + def __init__(self, prod: bool): + # when in production mode + # the process will return immediately when the data is not safe + self.prod = prod + self.source_filter = SourceFilter() + self.unsafe_words_filter = UnsafeWordsFilter() + + def process( + self, + content_str: str, + language: str, + language_details: str, + content_style: str, + ) -> dict: + """The process of the rule based safety.""" + data_pack = RuleBasedSafetyModuleDataPack( + content_str=content_str, + language=language, + language_details=language_details, + content_style=content_style, + ) + data_pack = self.process_core(data_pack) + return data_pack.get_output() + + def process_core(self, data_pack: RuleBasedSafetyModuleDataPack) -> RuleBasedSafetyModuleDataPack: + """The core process of the rule based safety.""" + content_str = data_pack.content_str + language = data_pack.language + language_details = data_pack.language_details + content_style = data_pack.content_style + source_type_dict=self.source_filter.filter(content_str, language, content_style) + from_safe_source=source_type_dict["from_safe_source"] + from_domestic_source=source_type_dict["from_domestic_source"] + remained, process_info = self.unsafe_words_filter.filter( + content_str, language, language_details, content_style,from_safe_source,from_domestic_source + ) + data_pack.set_process_result(remained, process_info) + return data_pack + + def get_version(self): + version_str = '1.0.0' + return version_str diff --git a/llm_web_kit/model/source_safety_detector.py b/llm_web_kit/model/source_safety_detector.py new file mode 100644 index 00000000..df933b4c --- /dev/null +++ b/llm_web_kit/model/source_safety_detector.py @@ -0,0 +1,37 @@ +class SourceFilter: + def __init__(self): + pass + + def filter( + self, content_str: str, language: str, language_details: str, content_style: str + ) -> dict: + """Predict the quality score of the content and filter out score below + the threshold First, check if the language and content_style are + supported Then get the quality model and threshold, and predict the + quality score of the content Finally, return the result of whether the + content should be filtered out. + + Args: + content_str (str): the content string + language (str): the language of the content + language_details (str): the details of the language + content_style (str): the content style of the content + + Raises: + TODO use custom exception instead of + ValueError: raise ValueError if the language and content_style are not supported + + Returns: + bool: True if the content should remain, False if the content should be filtered out + """ + # if not self.check_supported(language, content_style): + # # TODO move the exception to the upper level + # raise ValueError( + # f"Unsupport language '{language}' with content_style '{content_style}'" + # ) + # else: + # model, threshold = get_quality_model(language, content_style) + # prob = model.predict_with_content(content_str, content_style) + # return prob > threshold, {'quality_prob': prob} + + return {"from_safe_source":True,"from_domestic_source":True} diff --git a/tests/llm_web_kit/model/test_rule_based_safety_module.py b/tests/llm_web_kit/model/test_rule_based_safety_module.py new file mode 100644 index 00000000..1c32251d --- /dev/null +++ b/tests/llm_web_kit/model/test_rule_based_safety_module.py @@ -0,0 +1,128 @@ +import unittest +from unittest import TestCase +from unittest.mock import patch + +# 需要根据实际模块路径调整 +from llm_web_kit.model.rule_based_safety_module import (RuleBasedSafetyModule, RuleBasedSafetyModuleDataPack, + check_type) +from llm_web_kit.model.unsafe_words_detector import UnsafeWordsFilter +from llm_web_kit.model.source_safety_detector import SourceFilter + + +class TestCheckType(TestCase): + def test_type_checking(self): + """测试类型检查工具函数.""" + # 测试类型匹配的情况 + try: + check_type('test', 'string', str) + except TypeError: + self.fail('Valid type check failed') + + # 测试类型不匹配的情况 + with self.assertRaises(TypeError) as cm: + check_type('test', 123, str) + + expected_error = ( + "The type of test should be , but got " + ) + self.assertEqual(str(cm.exception), expected_error) + + +class TestRuleBasedSafetyModuleDataPack(TestCase): + def test_init_type_checks(self): + """测试初始化时的类型检查.""" + valid_args = { + 'content_str': 'test', + 'language': 'en', + 'language_details': 'details', + 'content_style': 'article', + } + + # 测试所有参数的正确类型 + try: + RuleBasedSafetyModuleDataPack(**valid_args) + except TypeError: + self.fail('Type check failed for valid types') + + # 逐个测试每个参数的类型错误 + for param in valid_args: + invalid_args = valid_args.copy() + invalid_args[param] = 123 # 故意设置错误类型 + with self.assertRaises(TypeError): + RuleBasedSafetyModuleDataPack(**invalid_args) + + def test_set_process_result(self): + """测试设置处理结果的功能.""" + data_pack = RuleBasedSafetyModuleDataPack('test', 'en', 'details', 'article') + + # 测试正确类型 + data_pack.set_process_result(False, {'key': 'value'}) + self.assertFalse(data_pack.clean_remained) + self.assertEqual(data_pack.clean_infos, {'key': 'value'}) + + # 测试错误类型 + with self.assertRaises(TypeError): + data_pack.set_process_result('not_bool', {'key': 'value'}) + + with self.assertRaises(TypeError): + data_pack.set_process_result(False, 'not_dict') + + def test_get_output(self): + """测试输出字典的生成.""" + data_pack = RuleBasedSafetyModuleDataPack('test', 'en', 'details', 'article') + data_pack.set_process_result(False, {'info': 'test'}) + + expected_output = {'clean_remained': False, 'clean_infos': {'info': 'test'}} + self.assertDictEqual(data_pack.get_output(), expected_output) + + +class TestRuleBasedSafetyModule(TestCase): + @patch.object(UnsafeWordsFilter, 'filter') + def test_process_core(self, mock_filter): + """测试核心处理流程.""" + # 设置模拟返回值 + mock_filter.return_value = (False, {'reason': 'test'}) + + # 初始化测试对象 + clean_module = RuleBasedSafetyModule(prod=True) + data_pack = RuleBasedSafetyModuleDataPack('test', 'en', 'details', 'article') + + # 执行核心处理 + result = clean_module.process_core(data_pack) + + # 验证过滤方法被正确调用 + mock_filter.assert_called_once_with('test', 'en', 'details', 'article') + # 验证处理结果设置正确 + self.assertFalse(result.clean_remained) + self.assertEqual(result.clean_infos, {'reason': 'test'}) + + @patch.object(UnsafeWordsFilter, 'filter') + def test_process_flow(self, mock_filter): + """测试完整处理流程.""" + mock_filter.return_value = (True, {'quality': 0.95}) + + clean_module = RuleBasedSafetyModule(prod=False) + result = clean_module.process( + content_str='content', + language='en', + language_details='details', + content_style='article', + ) + + expected_result = {'clean_remained': True, 'clean_infos': {'quality': 0.95}} + self.assertDictEqual(result, expected_result) + + def test_production_mode_effect(self): + """测试生产模式的影响.""" + # 根据实际业务逻辑补充测试 + # 当前代码中prod参数未实际使用,需要根据具体实现调整 + pass + + def test_get_version(self): + """测试版本获取.""" + clean_module = RuleBasedSafetyModule(prod=True) + self.assertEqual(clean_module.get_version(), '1.0.0') + + +if __name__ == '__main__': + unittest.main() From 5b12757908ae495f73f83af89b4fe41449b21242 Mon Sep 17 00:00:00 2001 From: yujing Date: Fri, 28 Feb 2025 20:12:23 +0800 Subject: [PATCH 02/32] update test_unsafe_words_detector --- llm_web_kit/model/domain_safety_detector.py | 13 + llm_web_kit/model/rule_based_safety_module.py | 80 ++++-- llm_web_kit/model/source_safety_detector.py | 40 +-- llm_web_kit/model/unsafe_words_detector.py | 270 +++++++++++------- .../model/test_unsafe_words_detector.py | 4 +- 5 files changed, 244 insertions(+), 163 deletions(-) create mode 100644 llm_web_kit/model/domain_safety_detector.py diff --git a/llm_web_kit/model/domain_safety_detector.py b/llm_web_kit/model/domain_safety_detector.py new file mode 100644 index 00000000..15ce2799 --- /dev/null +++ b/llm_web_kit/model/domain_safety_detector.py @@ -0,0 +1,13 @@ +class DomainFilter: + def __init__(self): + pass + + def filter( + self, + content_str: str, + language: str, + url: str, + language_details: str, + content_style: str, + ) -> dict: + return True, {} diff --git a/llm_web_kit/model/rule_based_safety_module.py b/llm_web_kit/model/rule_based_safety_module.py index 31312db7..3b2206dc 100644 --- a/llm_web_kit/model/rule_based_safety_module.py +++ b/llm_web_kit/model/rule_based_safety_module.py @@ -2,6 +2,7 @@ from llm_web_kit.model.unsafe_words_detector import UnsafeWordsFilter from llm_web_kit.model.source_safety_detector import SourceFilter +from llm_web_kit.model.domain_safety_detector import DomainFilter def check_type(arg_name: str, arg_value: Any, arg_type: Type): @@ -10,7 +11,7 @@ def check_type(arg_name: str, arg_value: Any, arg_type: Type): if not isinstance(arg_value, arg_type): # TODO change TypeError to custom exception raise TypeError( - 'The type of {} should be {}, but got {}'.format( + "The type of {} should be {}, but got {}".format( arg_name, arg_type, type(arg_value) ) ) @@ -25,42 +26,52 @@ def __init__( language: str, language_details: str, content_style: str, + url: str, + dataset_name: str, ): # the content of the dataset - check_type('content_str', content_str, str) + check_type("content_str", content_str, str) self.content_str = content_str # the language of the content - check_type('language', language, str) + check_type("language", language, str) self.language = language # the details of the language - check_type('language_details', language_details, str) + check_type("language_details", language_details, str) self.language_details = language_details # the content style of the content - check_type('content_style', content_style, str) + check_type("content_style", content_style, str) self.content_style = content_style + # the url of the content + check_type("url", url, str) + self.url = url + + # the data source of the content + check_type("dataset_name", dataset_name, str) + self.dataset_name = dataset_name + # the flag of the processed data should be remained or not - self.clean_remained = True + self.safety_remained = True # the details of the clean process - self.clean_infos = {} + self.safety_infos = {} - def set_process_result(self, clean_remained: bool, clean_infos: dict) -> None: - """set the process result of the clean module.""" - check_type('clean_remained', clean_remained, bool) - check_type('clean_infos', clean_infos, dict) - if clean_remained is False: - self.clean_remained = False - self.clean_infos.update(clean_infos) + def set_process_result(self, safety_remained: bool, safety_infos: dict) -> None: + """set the process result of the rule_based_safety module.""" + check_type("safety_remained", safety_remained, bool) + check_type("safety_infos", safety_infos, dict) + if safety_remained is False: + self.safety_remained = False + self.safety_infos.update(safety_infos) def get_output(self) -> dict: """get the output of the data pack.""" return { - 'clean_remained': self.clean_remained, - 'clean_infos': self.clean_infos, + "safety_remained": self.safety_remained, + "safety_infos": self.safety_infos, } @@ -69,6 +80,7 @@ def __init__(self, prod: bool): # when in production mode # the process will return immediately when the data is not safe self.prod = prod + self.domain_filter = DomainFilter() self.source_filter = SourceFilter() self.unsafe_words_filter = UnsafeWordsFilter() @@ -89,21 +101,41 @@ def process( data_pack = self.process_core(data_pack) return data_pack.get_output() - def process_core(self, data_pack: RuleBasedSafetyModuleDataPack) -> RuleBasedSafetyModuleDataPack: + def process_core( + self, data_pack: RuleBasedSafetyModuleDataPack + ) -> RuleBasedSafetyModuleDataPack: """The core process of the rule based safety.""" content_str = data_pack.content_str language = data_pack.language language_details = data_pack.language_details content_style = data_pack.content_style - source_type_dict=self.source_filter.filter(content_str, language, content_style) - from_safe_source=source_type_dict["from_safe_source"] - from_domestic_source=source_type_dict["from_domestic_source"] - remained, process_info = self.unsafe_words_filter.filter( - content_str, language, language_details, content_style,from_safe_source,from_domestic_source + url = data_pack.url + data_source = data_pack.dataset_name + + domain_safe_remained, domain_safe_info = self.domain_filter.filter( + content_str, language, url, language_details, content_style + ) + data_pack.set_process_result(domain_safe_remained, domain_safe_info) + if not domain_safe_remained and self.prod: + return data_pack + + source_type_dict = self.source_filter.filter( + content_str, language, data_source, content_style + ) + + from_safe_source = source_type_dict["from_safe_source"] + from_domestic_source = source_type_dict["from_domestic_source"] + unsafe_words_remained, process_info = self.unsafe_words_filter.filter( + content_str, + language, + language_details, + content_style, + from_safe_source, + from_domestic_source, ) - data_pack.set_process_result(remained, process_info) + data_pack.set_process_result(unsafe_words_remained, process_info) return data_pack def get_version(self): - version_str = '1.0.0' + version_str = "1.0.0" return version_str diff --git a/llm_web_kit/model/source_safety_detector.py b/llm_web_kit/model/source_safety_detector.py index df933b4c..961cc0a4 100644 --- a/llm_web_kit/model/source_safety_detector.py +++ b/llm_web_kit/model/source_safety_detector.py @@ -3,35 +3,11 @@ def __init__(self): pass def filter( - self, content_str: str, language: str, language_details: str, content_style: str - ) -> dict: - """Predict the quality score of the content and filter out score below - the threshold First, check if the language and content_style are - supported Then get the quality model and threshold, and predict the - quality score of the content Finally, return the result of whether the - content should be filtered out. - - Args: - content_str (str): the content string - language (str): the language of the content - language_details (str): the details of the language - content_style (str): the content style of the content - - Raises: - TODO use custom exception instead of - ValueError: raise ValueError if the language and content_style are not supported - - Returns: - bool: True if the content should remain, False if the content should be filtered out - """ - # if not self.check_supported(language, content_style): - # # TODO move the exception to the upper level - # raise ValueError( - # f"Unsupport language '{language}' with content_style '{content_style}'" - # ) - # else: - # model, threshold = get_quality_model(language, content_style) - # prob = model.predict_with_content(content_str, content_style) - # return prob > threshold, {'quality_prob': prob} - - return {"from_safe_source":True,"from_domestic_source":True} + self, + content_str: str, + language: str, + data_source: str, + language_details: str, + content_style: str, + ) -> dict: + return {"from_safe_source": False, "from_domestic_source": False} diff --git a/llm_web_kit/model/unsafe_words_detector.py b/llm_web_kit/model/unsafe_words_detector.py index 8e336d9f..e67c13d6 100644 --- a/llm_web_kit/model/unsafe_words_detector.py +++ b/llm_web_kit/model/unsafe_words_detector.py @@ -1,6 +1,6 @@ import os import time -from typing import Any, Dict +from typing import Any, Tuple, Dict import ahocorasick @@ -8,72 +8,74 @@ from llm_web_kit.exception.exception import SafeModelException from llm_web_kit.input.datajson import DataJson from llm_web_kit.libs.standard_utils import json_loads -from llm_web_kit.model.basic_functions.format_check import (is_en_letter, - is_pure_en_word) +from llm_web_kit.model.basic_functions.format_check import is_en_letter, is_pure_en_word from llm_web_kit.model.resource_utils.download_assets import ( - CACHE_DIR, download_auto_file) -from llm_web_kit.model.resource_utils.singleton_resource_manager import \ - singleton_resource_manager + CACHE_DIR, + download_auto_file, +) +from llm_web_kit.model.resource_utils.singleton_resource_manager import ( + singleton_resource_manager, +) xyz_language_lst = [ - 'ar', - 'cs', - 'hu', - 'sr', - 'ru', - 'ko', - 'vi', - 'th', - 'arb', - 'arb_Arab', - 'arb_Latn', - 'ces', - 'ces_Latn', - 'hun', - 'hun_Latn', - 'srp', - 'srp_Cyrl', - 'rus', - 'rus_Cyrl', - 'kor', - 'kor_Hang', - 'vie', - 'vie_Latn', - 'tha', - 'tha_Thai', + "ar", + "cs", + "hu", + "sr", + "ru", + "ko", + "vi", + "th", + "arb", + "arb_Arab", + "arb_Latn", + "ces", + "ces_Latn", + "hun", + "hun_Latn", + "srp", + "srp_Cyrl", + "rus", + "rus_Cyrl", + "kor", + "kor_Hang", + "vie", + "vie_Latn", + "tha", + "tha_Thai", ] level_score_map = { - 'L1': 100, - 'L2': 10, - 'L3': 1, - 'L4': 0.1, + "L1": 100, + "L2": 10, + "L3": 1, + "L4": 0.1, } -def auto_download(language='zh-en'): - resource_config = load_config()['resources'] - if language == 'zh-en': - resource_name = 'unsafe_words' - elif language == 'xyz': - resource_name = 'xyz_internal_unsafe_words' +def auto_download(language="zh-en"): + resource_config = load_config()["resources"] + if language == "zh-en": + resource_name = "unsafe_words" + elif language == "xyz": + resource_name = "xyz_internal_unsafe_words" else: - raise SafeModelException(f'Unsupported language: {language}') + raise SafeModelException(f"Unsupported language: {language}") language_unsafe_words_config: Dict = resource_config[resource_name] - download_path = language_unsafe_words_config['download_path'] - md5 = language_unsafe_words_config['md5'] + download_path = language_unsafe_words_config["download_path"] + md5 = language_unsafe_words_config["md5"] local_path = os.path.join(CACHE_DIR, resource_name) unsafe_words_file_path = download_auto_file(download_path, local_path, md5) return unsafe_words_file_path -def get_ac(language='zh-en'): +def get_ac(language="zh-en"): t1 = time.time() unsafe_words_file_path = auto_download(language) t2 = time.time() print( - f'-----------------auto_download cost time: {t2-t1} , language: {language}------------------' + f"-----------------auto_download cost time: {t2-t1} , language: {language}------------------" ) - with open(unsafe_words_file_path, 'r') as f: + with open(unsafe_words_file_path, "r") as f: lines = f.readlines() # sub_word: [{ @@ -88,27 +90,27 @@ def get_ac(language='zh-en'): words = {} for line in lines: w = json_loads(line) - word = str(w.get('word') or '').lower() + word = str(w.get("word") or "").lower() if not word: continue if is_pure_en_word(word) and len(word) <= 4: continue - sub_words = word.split('&&&') + sub_words = word.split("&&&") w_info = { - 'word': word, - 'sub_words': set(sub_words), - 'type': w.get('type'), - 'level': w.get('level'), - 'language': w.get('language'), - 'applicable': w.get('applicable'), - 'unapplicable': w.get('unapplicable'), + "word": word, + "sub_words": set(sub_words), + "type": w.get("type"), + "level": w.get("level"), + "language": w.get("language"), + "applicable": w.get("applicable"), + "unapplicable": w.get("unapplicable"), } for sub_word in sub_words: lst = words.get(sub_word, []) - lst.append({'sub_word': sub_word, **w_info}) + lst.append({"sub_word": sub_word, **w_info}) words[sub_word] = lst ac = ahocorasick.Automaton() @@ -139,7 +141,7 @@ def is_word_standalone(sub_word, end_pos): # 遍历所有匹配的子词及其结束位置pos for pos, w_info_lst in ac.iter(content): for w_info in w_info_lst: - sub_word = w_info['sub_word'] + sub_word = w_info["sub_word"] if is_word_standalone(sub_word, pos): all_sub_words.add(sub_word) all_w_info_lst.append(w_info) @@ -147,26 +149,26 @@ def is_word_standalone(sub_word, end_pos): unsafe_words = {} for w_info in all_w_info_lst: # 检查该词的所有子词是否均被匹配到 - if all_sub_words.issuperset(w_info['sub_words']): - if w_info['word'] not in unsafe_words: - unsafe_words[w_info['word']] = { - 'word': w_info['word'], - 'type': w_info['type'], - 'level': w_info['level'], - 'language': w_info['language'], - 'count': 0.0, + if all_sub_words.issuperset(w_info["sub_words"]): + if w_info["word"] not in unsafe_words: + unsafe_words[w_info["word"]] = { + "word": w_info["word"], + "type": w_info["type"], + "level": w_info["level"], + "language": w_info["language"], + "count": 0.0, } - unsafe_words[w_info['word']]['count'] += 1.0 / len(w_info['sub_words']) + unsafe_words[w_info["word"]]["count"] += 1.0 / len(w_info["sub_words"]) return list(unsafe_words.values()) class UnsafeWordChecker: - def __init__(self, language='zh-en') -> None: + def __init__(self, language="zh-en") -> None: t1 = time.time() self.ac = get_ac(language) t2 = time.time() print( - f'---------------UnsafeWordChecker init time: {t2-t1} , language: {language}-----------------' + f"---------------UnsafeWordChecker init time: {t2-t1} , language: {language}-----------------" ) def check_unsafe_words(self, content_str: str) -> list: @@ -174,13 +176,13 @@ def check_unsafe_words(self, content_str: str) -> list: return unsafe_words_list -def get_unsafe_words_checker(language='zh-en') -> UnsafeWordChecker: +def get_unsafe_words_checker(language="zh-en") -> UnsafeWordChecker: if not singleton_resource_manager.has_name(language): singleton_resource_manager.set_resource(language, UnsafeWordChecker(language)) return singleton_resource_manager.get_resource(language) -def decide_unsafe_word_by_data_checker( +def decide_data_unsafe_word_by_data_checker( data_dict: dict, unsafeWordChecker: UnsafeWordChecker ) -> str: data_obj = DataJson(data_dict) @@ -188,12 +190,28 @@ def decide_unsafe_word_by_data_checker( unsafe_words_list = unsafeWordChecker.check_unsafe_words(content_str=content_str) unsafe_word_levels = [] for w in unsafe_words_list: - _, level, _ = w['word'], w['level'], w['count'] + _, level, _ = w["word"], w["level"], w["count"] # "涉政|观测|L4|带头人" unsafe_word_levels.append(level) unsafe_word_levels = list(set(unsafe_word_levels)) - unsafe_word_min_level = min(unsafe_word_levels + ['NF']) + unsafe_word_min_level = min(unsafe_word_levels + ["NF"]) + + return unsafe_word_min_level + + +def decide_content_unsafe_word_by_data_checker( + content_str: str, unsafeWordChecker: UnsafeWordChecker +) -> str: + unsafe_words_list = unsafeWordChecker.check_unsafe_words(content_str=content_str) + unsafe_word_levels = [] + for w in unsafe_words_list: + _, level, _ = w["word"], w["level"], w["count"] + # "涉政|观测|L4|带头人" + unsafe_word_levels.append(level) + + unsafe_word_levels = list(set(unsafe_word_levels)) + unsafe_word_min_level = min(unsafe_word_levels + ["NF"]) return unsafe_word_min_level @@ -202,24 +220,24 @@ def unsafe_words_filter( data_dict: Dict[str, Any], language: str, content_style: str ) -> str: if language in xyz_language_lst: - language = 'xyz' + language = "xyz" elif language in [ - 'zh', - 'en', - 'yue', - 'zho', - 'eng', - 'zho_Hans', - 'zho_Hant', - 'yue_Hant', - 'eng_Latn', + "zh", + "en", + "yue", + "zho", + "eng", + "zho_Hans", + "zho_Hant", + "yue_Hant", + "eng_Latn", ]: - language = 'zh-en' + language = "zh-en" else: - raise SafeModelException(f'Unsupported language: {language}') + raise SafeModelException(f"Unsupported language: {language}") unsafeWordChecker = get_unsafe_words_checker(language) - unsafe_word_min_level = decide_unsafe_word_by_data_checker( + unsafe_word_min_level = decide_data_unsafe_word_by_data_checker( data_dict, unsafeWordChecker ) @@ -233,29 +251,71 @@ def unsafe_words_filter_overall( from_safe_source, from_domestic_source, ): - unsafe_word_min_level = unsafe_words_filter(data_dict, language, content_style) - if language in xyz_language_lst: - language = 'xyz' + language = "xyz" elif language in [ - 'zh', - 'en', - 'yue', - 'zho', - 'eng', - 'zho_Hans', - 'zho_Hant', - 'yue_Hant', - 'eng_Latn', + "zh", + "en", + "yue", + "zho", + "eng", + "zho_Hans", + "zho_Hant", + "yue_Hant", + "eng_Latn", ]: - language = 'zh-en' + language = "zh-en" else: - raise SafeModelException(f'Unsupported language: {language}') + raise SafeModelException(f"Unsupported language: {language}") if from_safe_source: - return {'hit_unsafe_words': False} + return {"hit_unsafe_words": False} if from_domestic_source: - unsafe_range = ('L1',) + unsafe_range = ("L1",) else: - unsafe_range = ('L1', 'L2') - hit = (unsafe_word_min_level in unsafe_range) - return {'hit_unsafe_words': hit} + unsafe_range = ("L1", "L2") + unsafe_word_min_level = unsafe_words_filter(data_dict, language, content_style) + hit = unsafe_word_min_level in unsafe_range + return {"hit_unsafe_words": hit} + + +class UnsafeWordsFilter: + def __init__(self): + pass + + def filter( + self, + content_str: str, + language: str, + language_details: str, + content_style: str, + from_safe_source: bool, + from_domestic_source: bool, + ) -> Tuple[bool, Dict[str, Any]]: + if language in xyz_language_lst: + language = "xyz" + elif language in [ + "zh", + "en", + "yue", + "zho", + "eng", + "zho_Hans", + "zho_Hant", + "yue_Hant", + "eng_Latn", + ]: + language = "zh-en" + else: + raise SafeModelException(f"Unsupported language: {language}") + + if from_safe_source: + return True, {"hit_unsafe_words": False} + if from_domestic_source: + unsafe_range = ("L1",) + else: + unsafe_range = ("L1", "L2") + unsafe_word_min_level = decide_content_unsafe_word_by_data_checker( + content_str, get_unsafe_words_checker(language) + ) + hit = unsafe_word_min_level in unsafe_range + return not hit, {"hit_unsafe_words": hit} diff --git a/tests/llm_web_kit/model/test_unsafe_words_detector.py b/tests/llm_web_kit/model/test_unsafe_words_detector.py index a26c608a..61af44de 100644 --- a/tests/llm_web_kit/model/test_unsafe_words_detector.py +++ b/tests/llm_web_kit/model/test_unsafe_words_detector.py @@ -3,7 +3,7 @@ from llm_web_kit.exception.exception import SafeModelException from llm_web_kit.model.unsafe_words_detector import ( - UnsafeWordChecker, auto_download, decide_unsafe_word_by_data_checker, + UnsafeWordChecker, auto_download, decide_data_unsafe_word_by_data_checker, get_ac, get_unsafe_words, get_unsafe_words_checker, unsafe_words_filter, unsafe_words_filter_overall) @@ -46,7 +46,7 @@ def test_decide_unsafe_word_by_data_checker(self, mock_get_checker): mock_get_checker.return_value = mock_checker data_dict = {'content': 'Some content with unsafe elements.'} - result = decide_unsafe_word_by_data_checker(data_dict, mock_checker) + result = decide_data_unsafe_word_by_data_checker(data_dict, mock_checker) self.assertEqual(result, 'L2') def test_standalone_word_detection(self): From c9853cabf3e6e4bf5a5e61fc1c909e06b05f227b Mon Sep 17 00:00:00 2001 From: yujing Date: Fri, 28 Feb 2025 20:20:07 +0800 Subject: [PATCH 03/32] update test_unsafe_words_detector --- llm_web_kit/model/unsafe_words_detector.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/llm_web_kit/model/unsafe_words_detector.py b/llm_web_kit/model/unsafe_words_detector.py index e67c13d6..2d0f3e0a 100644 --- a/llm_web_kit/model/unsafe_words_detector.py +++ b/llm_web_kit/model/unsafe_words_detector.py @@ -251,22 +251,6 @@ def unsafe_words_filter_overall( from_safe_source, from_domestic_source, ): - if language in xyz_language_lst: - language = "xyz" - elif language in [ - "zh", - "en", - "yue", - "zho", - "eng", - "zho_Hans", - "zho_Hant", - "yue_Hant", - "eng_Latn", - ]: - language = "zh-en" - else: - raise SafeModelException(f"Unsupported language: {language}") if from_safe_source: return {"hit_unsafe_words": False} if from_domestic_source: From 6e23df86a431cea274ed9a0fb5e28c546b42d5d2 Mon Sep 17 00:00:00 2001 From: yujing Date: Fri, 28 Feb 2025 20:25:17 +0800 Subject: [PATCH 04/32] update test_unsafe_words_detector --- tests/test_domain_safety_detector.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 tests/test_domain_safety_detector.py diff --git a/tests/test_domain_safety_detector.py b/tests/test_domain_safety_detector.py new file mode 100644 index 00000000..a606102a --- /dev/null +++ b/tests/test_domain_safety_detector.py @@ -0,0 +1,22 @@ +import unittest + +from llm_web_kit.model.domain_safety_detector import DomainFilter + +class TestDomainFilter(unittest.TestCase): + def setUp(self): + self.filter = DomainFilter() + + # 测试基础过滤逻辑 + def test_filter_basic_case(self): + result = self.filter.filter( + content_str="Valid content", + language="en", + url="https://example.com", + language_details="formal", + content_style="professional" + ) + self.assertTrue(result[0]) # 预期允许通过 + self.assertEqual(result[1], {}) # 预期无附加信息 + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 4102ca104da9c50f6c9c747ddb1086419810ec2d Mon Sep 17 00:00:00 2001 From: yujing Date: Fri, 28 Feb 2025 20:26:53 +0800 Subject: [PATCH 05/32] add test_domain_safety_detector --- tests/{ => llm_web_kit/model}/test_domain_safety_detector.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{ => llm_web_kit/model}/test_domain_safety_detector.py (100%) diff --git a/tests/test_domain_safety_detector.py b/tests/llm_web_kit/model/test_domain_safety_detector.py similarity index 100% rename from tests/test_domain_safety_detector.py rename to tests/llm_web_kit/model/test_domain_safety_detector.py From c93186d6955e8f93f775ebd56de4d34ab9f5d921 Mon Sep 17 00:00:00 2001 From: yujing Date: Fri, 28 Feb 2025 20:45:43 +0800 Subject: [PATCH 06/32] add test_domain_safety_detector --- llm_web_kit/model/source_safety_detector.py | 2 +- .../model/test_domain_safety_detector.py | 16 +++++++------ .../model/test_source_safety_detector.py | 24 +++++++++++++++++++ 3 files changed, 34 insertions(+), 8 deletions(-) create mode 100644 tests/llm_web_kit/model/test_source_safety_detector.py diff --git a/llm_web_kit/model/source_safety_detector.py b/llm_web_kit/model/source_safety_detector.py index 961cc0a4..7be51f09 100644 --- a/llm_web_kit/model/source_safety_detector.py +++ b/llm_web_kit/model/source_safety_detector.py @@ -10,4 +10,4 @@ def filter( language_details: str, content_style: str, ) -> dict: - return {"from_safe_source": False, "from_domestic_source": False} + return {'from_safe_source': False, 'from_domestic_source': False} diff --git a/tests/llm_web_kit/model/test_domain_safety_detector.py b/tests/llm_web_kit/model/test_domain_safety_detector.py index a606102a..d716dddb 100644 --- a/tests/llm_web_kit/model/test_domain_safety_detector.py +++ b/tests/llm_web_kit/model/test_domain_safety_detector.py @@ -2,6 +2,7 @@ from llm_web_kit.model.domain_safety_detector import DomainFilter + class TestDomainFilter(unittest.TestCase): def setUp(self): self.filter = DomainFilter() @@ -9,14 +10,15 @@ def setUp(self): # 测试基础过滤逻辑 def test_filter_basic_case(self): result = self.filter.filter( - content_str="Valid content", - language="en", - url="https://example.com", - language_details="formal", - content_style="professional" + content_str='Valid content', + language='en', + url='https://example.com', + language_details='formal', + content_style='professional' ) self.assertTrue(result[0]) # 预期允许通过 self.assertEqual(result[1], {}) # 预期无附加信息 -if __name__ == "__main__": - unittest.main() \ No newline at end of file + +if __name__ == '__main__': + unittest.main() diff --git a/tests/llm_web_kit/model/test_source_safety_detector.py b/tests/llm_web_kit/model/test_source_safety_detector.py new file mode 100644 index 00000000..355cd277 --- /dev/null +++ b/tests/llm_web_kit/model/test_source_safety_detector.py @@ -0,0 +1,24 @@ +import unittest + +from llm_web_kit.model.source_safety_detector import SourceFilter + + +class TestSourceFilter(unittest.TestCase): + def setUp(self): + self.filter = SourceFilter() + + # 测试非法来源 + def test_unsafe_source(self): + result = self.filter.filter( + content_str='Unverified data', + language='en', + data_source='http://unknown-site.com', + language_details='informal', + content_style='user-generated' + ) + self.assertFalse(result['from_safe_source']) + self.assertFalse(result['from_domestic_source']) + + +if __name__ == '__main__': + unittest.main() From 4af08dad3f12f2693480619ca913cfc646e5f872 Mon Sep 17 00:00:00 2001 From: qiujiantao Date: Fri, 28 Feb 2025 21:13:58 +0800 Subject: [PATCH 07/32] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E6=96=87?= =?UTF-8?q?=E4=BB=B6=E9=94=81=E5=AE=9A=E6=9C=BA=E5=88=B6=EF=BC=8C=E7=A1=AE?= =?UTF-8?q?=E4=BF=9D=E9=94=81=E6=96=87=E4=BB=B6=E5=9C=A8=E5=BC=82=E5=B8=B8?= =?UTF-8?q?=E6=83=85=E5=86=B5=E4=B8=8B=E8=A2=AB=E6=AD=A3=E7=A1=AE=E5=88=A0?= =?UTF-8?q?=E9=99=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- llm_web_kit/model/resource_utils/download_assets.py | 3 ++- tests/llm_web_kit/model/resource_utils/test_unzip_ext.py | 3 --- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/llm_web_kit/model/resource_utils/download_assets.py b/llm_web_kit/model/resource_utils/download_assets.py index 42397f18..efa07619 100644 --- a/llm_web_kit/model/resource_utils/download_assets.py +++ b/llm_web_kit/model/resource_utils/download_assets.py @@ -155,9 +155,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): try: if self._fd: os.close(self._fd) - os.remove(self.lock_path) except OSError: pass + finally: + try_remove(self.lock_path) def verify_file_checksum( diff --git a/tests/llm_web_kit/model/resource_utils/test_unzip_ext.py b/tests/llm_web_kit/model/resource_utils/test_unzip_ext.py index 7df3a14d..39267ab1 100644 --- a/tests/llm_web_kit/model/resource_utils/test_unzip_ext.py +++ b/tests/llm_web_kit/model/resource_utils/test_unzip_ext.py @@ -28,14 +28,11 @@ def test_unzip_local_file(): with open(os.path.join(target_dir, 'test2.txt')) as f: assert f.read() == 'This is another test file' - os.remove(zip_path + '.lock') - unzip_local_file(zip_path, target_dir, exist_ok=True) with open(os.path.join(target_dir, 'test1.txt')) as f: assert f.read() == 'This is a test file' with open(os.path.join(target_dir, 'test2.txt')) as f: assert f.read() == 'This is another test file' - os.remove(zip_path + '.lock') try: unzip_local_file(zip_path, target_dir, exist_ok=False) except Exception as e: From 0a6f21d1d2a8397b9ba44ab6964d9b2308e8f186 Mon Sep 17 00:00:00 2001 From: qiujiantao Date: Mon, 3 Mar 2025 13:42:02 +0800 Subject: [PATCH 08/32] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E5=9F=BA?= =?UTF-8?q?=E4=BA=8E=E6=A8=A1=E5=9E=8B=E7=9A=84=E5=AE=89=E5=85=A8=E6=A8=A1?= =?UTF-8?q?=E5=9D=97=EF=BC=8C=E6=94=AF=E6=8C=81=E5=86=85=E5=AE=B9=E5=AE=89?= =?UTF-8?q?=E5=85=A8=E6=A3=80=E6=B5=8B=E5=92=8C=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/llm_web_kit/model/model_based_safety.md | 0 .../model/model_based_safety_module.py | 181 ++++++++++++++++++ llm_web_kit/model/policical.py | 2 +- 3 files changed, 182 insertions(+), 1 deletion(-) create mode 100644 docs/llm_web_kit/model/model_based_safety.md create mode 100644 llm_web_kit/model/model_based_safety_module.py diff --git a/docs/llm_web_kit/model/model_based_safety.md b/docs/llm_web_kit/model/model_based_safety.md new file mode 100644 index 00000000..e69de29b diff --git a/llm_web_kit/model/model_based_safety_module.py b/llm_web_kit/model/model_based_safety_module.py new file mode 100644 index 00000000..1edb96d1 --- /dev/null +++ b/llm_web_kit/model/model_based_safety_module.py @@ -0,0 +1,181 @@ +from typing import List, Tuple, Any, Type, TypeVar +from llm_web_kit.model.policical import PoliticalDetector, decide_political_by_prob +from llm_web_kit.model.porn_detector import BertModel as EnPornBertModel +from llm_web_kit.exception.exception import ModelInputException + +I = TypeVar("I") # input type +B = TypeVar("B") # batch type + + +def check_type(arg_name: str, arg_value: Any, arg_type: Type): + """check the type of the argument and raise TypeError if the type is not + matched.""" + if not isinstance(arg_value, arg_type): + # TODO change TypeError to custom exception + raise TypeError( + "The type of {} should be {}, but got {}".format( + arg_name, arg_type, type(arg_value) + ) + ) + + +class ModelBasedSafetyDataPack: + """The data pack for the model based safety module.""" + + def __init__(self, content_str: str, langurage: str, langurage_details: str): + + self._dict = {} + # the content of the dataset + check_type("content_str", content_str, str) + self._dict["content_str"] = content_str + + # the language of the content + check_type("langurage", langurage, str) + self._dict["langurage"] = langurage + + # the details of the language + check_type("langurage_details", langurage_details, str) + self._dict["langurage_details"] = langurage_details + + # the flag of the processed data should be remained or not + self._dict["model_based_safety_remained"] = True + + # the details of the model based safety process + self._dict["model_based_safety_infos"] = {} + + @classmethod + def from_dict(cls, data: dict): + new_data_pack = cls( + content_str=data["content_str"], + langurage=data["langurage"], + langurage_details=data["langurage_details"], + ) + new_data_pack._dict.update(data) + return new_data_pack + + def as_dict(self) -> dict: + return self._dict + + def set_process_result( + self, model_based_safety_remained: bool, model_based_safety_infos: dict + ) -> None: + """set the process result of the model based safety module.""" + check_type("model_based_safety_remained", model_based_safety_remained, bool) + check_type("model_based_safety_infos", model_based_safety_infos, dict) + if model_based_safety_remained is False: + self._dict["model_based_safety_remained"] = False + self._dict["model_based_safety_infos"].update(model_based_safety_infos) + + def get_output(self) -> dict: + """get the output of the data pack.""" + return { + "model_based_safety_remained": self._dict["model_based_safety_remained"], + "model_based_safety_infos": self._dict["model_based_safety_infos"], + } + + +class ContentStrBatchModel: + def __init__(self, model_config: dict): + self.model_config = model_config + + def check_support(self, data_pack: ModelBasedSafetyDataPack) -> bool: + raise NotImplementedError + + def preprocess(self, data_pack: ModelBasedSafetyDataPack) -> Tuple[dict, I]: + if not self.check_support(data_pack): + # use class name + model_name = self.__class__.__name__ + raise ModelInputException( + f"The data pack is not supported for {model_name}." + ) + return data_pack.as_dict(), data_pack._dict["content_str"] + + def collate_fn(self, lst: List[Tuple[dict, I]]) -> Tuple[List[dict], B]: + infos, batch = zip(*lst) + return list(infos), list(batch) + + def inference(self, batch: B) -> List[dict]: + """(batch: B) -> results""" + raise NotImplementedError() + + def postprocess(self, info: dict, result: dict) -> ModelBasedSafetyDataPack: + """(info: dict, result: dict) -> output""" + # return {**info, **result} + raise NotImplementedError() + + def process_one_core( + self, data_pack: ModelBasedSafetyDataPack + ) -> ModelBasedSafetyDataPack: + info, batch = self.preprocess(data_pack) + batch = self.collate_fn([(info, batch)])[1] + results = self.inference(batch) + return self.postprocess(info, results[0]) + + def process_one( + self, content_str: str, langurage: str, langurage_details: str + ) -> dict: + data_pack = ModelBasedSafetyDataPack(content_str, langurage, langurage_details) + return self.process_one_core(data_pack).get_output() + + +class ZhEnPoliticalModel(ContentStrBatchModel): + def __init__(self, model_config: dict): + super().__init__(model_config) + self.political_detect = PoliticalDetector( + model_path=model_config["model_path"], + ) + self.threshold = model_config["threshold"] + + def check_support(self, data_pack: ModelBasedSafetyDataPack) -> bool: + return data_pack.langurage in ["zh", "en"] + + def inference(self, batch: List[str]) -> List[dict]: + result_list = [] + for content_str in batch: + predictions, probabilities = self.political_detect.predict(content_str) + normal_score = decide_political_by_prob(predictions, probabilities) + result_list.append( + { + "political_prob": normal_score, + "political_info": { + "predictions": predictions, + "probabilities": probabilities, + }, + } + ) + return result_list + + def postprocess(self, info: dict, result: dict) -> dict: + remained = result["political_prob"] > self.threshold + datapack = ModelBasedSafetyDataPack.from_dict(info) + datapack.set_process_result( + model_based_safety_remained=remained, model_based_safety_infos=result + ) + return datapack + + +class EnPronModel(ContentStrBatchModel): + def __init__(self, model_config: dict): + super().__init__(model_config) + self.model = EnPornBertModel(model_config["model_path"]) + self.threshold = model_config["threshold"] + + def check_support(self, data_pack: ModelBasedSafetyDataPack) -> bool: + return data_pack.langurage == "en" + + def inference(self, batch: List[str]) -> List[dict]: + result_list = [] + for content_str in batch: + prob = self.model.predict(content_str) + result_list.append(prob) + return result_list + + def postprocess(self, info: dict, result: dict) -> dict: + porn_prob = list(result[0].values())[0] + remained = porn_prob < self.threshold + datapack = ModelBasedSafetyDataPack.from_dict(info) + datapack.set_process_result( + model_based_safety_remained=remained, + model_based_safety_infos={"porn_prob": porn_prob}, + ) + return datapack diff --git a/llm_web_kit/model/policical.py b/llm_web_kit/model/policical.py index a46e933d..e458da70 100644 --- a/llm_web_kit/model/policical.py +++ b/llm_web_kit/model/policical.py @@ -80,7 +80,7 @@ def get_singleton_political_detect() -> PoliticalDetector: def decide_political_by_prob(predictions: Tuple[str], probabilities: Tuple[float]) -> float: idx = predictions.index('__label__normal') normal_score = probabilities[idx] - return normal_score + return float(normal_score) def decide_political_func(content_str: str, political_detect: PoliticalDetector) -> float: From d201836810998852fbcb2b0df8a9ce30b7403c48 Mon Sep 17 00:00:00 2001 From: yujing Date: Mon, 3 Mar 2025 15:47:01 +0800 Subject: [PATCH 09/32] test rule-based-safety-module --- llm_web_kit/model/rule_based_safety_module.py | 36 +++++++++------- .../model/test_rule_based_safety_module.py | 43 ++++++++++++------- 2 files changed, 48 insertions(+), 31 deletions(-) diff --git a/llm_web_kit/model/rule_based_safety_module.py b/llm_web_kit/model/rule_based_safety_module.py index 3b2206dc..cb1c695d 100644 --- a/llm_web_kit/model/rule_based_safety_module.py +++ b/llm_web_kit/model/rule_based_safety_module.py @@ -1,8 +1,8 @@ from typing import Any, Type -from llm_web_kit.model.unsafe_words_detector import UnsafeWordsFilter -from llm_web_kit.model.source_safety_detector import SourceFilter from llm_web_kit.model.domain_safety_detector import DomainFilter +from llm_web_kit.model.source_safety_detector import SourceFilter +from llm_web_kit.model.unsafe_words_detector import UnsafeWordsFilter def check_type(arg_name: str, arg_value: Any, arg_type: Type): @@ -11,7 +11,7 @@ def check_type(arg_name: str, arg_value: Any, arg_type: Type): if not isinstance(arg_value, arg_type): # TODO change TypeError to custom exception raise TypeError( - "The type of {} should be {}, but got {}".format( + 'The type of {} should be {}, but got {}'.format( arg_name, arg_type, type(arg_value) ) ) @@ -31,27 +31,27 @@ def __init__( ): # the content of the dataset - check_type("content_str", content_str, str) + check_type('content_str', content_str, str) self.content_str = content_str # the language of the content - check_type("language", language, str) + check_type('language', language, str) self.language = language # the details of the language - check_type("language_details", language_details, str) + check_type('language_details', language_details, str) self.language_details = language_details # the content style of the content - check_type("content_style", content_style, str) + check_type('content_style', content_style, str) self.content_style = content_style # the url of the content - check_type("url", url, str) + check_type('url', url, str) self.url = url # the data source of the content - check_type("dataset_name", dataset_name, str) + check_type('dataset_name', dataset_name, str) self.dataset_name = dataset_name # the flag of the processed data should be remained or not @@ -61,8 +61,8 @@ def __init__( def set_process_result(self, safety_remained: bool, safety_infos: dict) -> None: """set the process result of the rule_based_safety module.""" - check_type("safety_remained", safety_remained, bool) - check_type("safety_infos", safety_infos, dict) + check_type('safety_remained', safety_remained, bool) + check_type('safety_infos', safety_infos, dict) if safety_remained is False: self.safety_remained = False self.safety_infos.update(safety_infos) @@ -70,8 +70,8 @@ def set_process_result(self, safety_remained: bool, safety_infos: dict) -> None: def get_output(self) -> dict: """get the output of the data pack.""" return { - "safety_remained": self.safety_remained, - "safety_infos": self.safety_infos, + 'safety_remained': self.safety_remained, + 'safety_infos': self.safety_infos, } @@ -90,6 +90,8 @@ def process( language: str, language_details: str, content_style: str, + url: str, + dataset_name: str, ) -> dict: """The process of the rule based safety.""" data_pack = RuleBasedSafetyModuleDataPack( @@ -97,6 +99,8 @@ def process( language=language, language_details=language_details, content_style=content_style, + url=url, + dataset_name=dataset_name, ) data_pack = self.process_core(data_pack) return data_pack.get_output() @@ -123,8 +127,8 @@ def process_core( content_str, language, data_source, content_style ) - from_safe_source = source_type_dict["from_safe_source"] - from_domestic_source = source_type_dict["from_domestic_source"] + from_safe_source = source_type_dict['from_safe_source'] + from_domestic_source = source_type_dict['from_domestic_source'] unsafe_words_remained, process_info = self.unsafe_words_filter.filter( content_str, language, @@ -137,5 +141,5 @@ def process_core( return data_pack def get_version(self): - version_str = "1.0.0" + version_str = '1.0.0' return version_str diff --git a/tests/llm_web_kit/model/test_rule_based_safety_module.py b/tests/llm_web_kit/model/test_rule_based_safety_module.py index 1c32251d..05aa517f 100644 --- a/tests/llm_web_kit/model/test_rule_based_safety_module.py +++ b/tests/llm_web_kit/model/test_rule_based_safety_module.py @@ -3,10 +3,10 @@ from unittest.mock import patch # 需要根据实际模块路径调整 -from llm_web_kit.model.rule_based_safety_module import (RuleBasedSafetyModule, RuleBasedSafetyModuleDataPack, - check_type) -from llm_web_kit.model.unsafe_words_detector import UnsafeWordsFilter +from llm_web_kit.model.rule_based_safety_module import ( + RuleBasedSafetyModule, RuleBasedSafetyModuleDataPack, check_type) from llm_web_kit.model.source_safety_detector import SourceFilter +from llm_web_kit.model.unsafe_words_detector import UnsafeWordsFilter class TestCheckType(TestCase): @@ -36,6 +36,8 @@ def test_init_type_checks(self): 'language': 'en', 'language_details': 'details', 'content_style': 'article', + 'url': 'http://test.com', + 'dataset_name': 'test_dataset', } # 测试所有参数的正确类型 @@ -53,7 +55,7 @@ def test_init_type_checks(self): def test_set_process_result(self): """测试设置处理结果的功能.""" - data_pack = RuleBasedSafetyModuleDataPack('test', 'en', 'details', 'article') + data_pack = RuleBasedSafetyModuleDataPack('test', 'en', 'details', 'article','http://test.com', 'test_dataset') # 测试正确类型 data_pack.set_process_result(False, {'key': 'value'}) @@ -69,37 +71,46 @@ def test_set_process_result(self): def test_get_output(self): """测试输出字典的生成.""" - data_pack = RuleBasedSafetyModuleDataPack('test', 'en', 'details', 'article') + data_pack = RuleBasedSafetyModuleDataPack('test', 'en', 'details', 'article','http://test.com', 'test_dataset') data_pack.set_process_result(False, {'info': 'test'}) - expected_output = {'clean_remained': False, 'clean_infos': {'info': 'test'}} + expected_output = {'safety_remained': False, 'safety_infos': {'info': 'test'}} self.assertDictEqual(data_pack.get_output(), expected_output) class TestRuleBasedSafetyModule(TestCase): - @patch.object(UnsafeWordsFilter, 'filter') - def test_process_core(self, mock_filter): + @patch.object('llm_web_kit.model.rule_based_safety_module.DomainFilter', 'filter') + @patch.object('llm_web_kit.model.rule_based_safety_module.SourceFilter', 'filter') + @patch.object('llm_web_kit.model.rule_based_safety_module.UnsafeWordsFilter', 'filter') + def test_process_core(self, mock_unsafe_words_filter, mock_source_filter, mock_domain_filter): """测试核心处理流程.""" # 设置模拟返回值 - mock_filter.return_value = (False, {'reason': 'test'}) + mock_source_filter.return_value = {'from_safe_source': False, 'from_domestic_source': False} + mock_domain_filter.return_value = (True, {}) + mock_unsafe_words_filter.return_value = (False, {'reason': 'test'}) + # 初始化测试对象 clean_module = RuleBasedSafetyModule(prod=True) - data_pack = RuleBasedSafetyModuleDataPack('test', 'en', 'details', 'article') + data_pack = RuleBasedSafetyModuleDataPack('test', 'en', 'details', 'article','http://test.com', 'test_dataset') # 执行核心处理 result = clean_module.process_core(data_pack) # 验证过滤方法被正确调用 - mock_filter.assert_called_once_with('test', 'en', 'details', 'article') + mock_unsafe_words_filter.assert_called_once_with('test', 'en', 'details', 'article', False, False) # 验证处理结果设置正确 self.assertFalse(result.clean_remained) self.assertEqual(result.clean_infos, {'reason': 'test'}) - @patch.object(UnsafeWordsFilter, 'filter') - def test_process_flow(self, mock_filter): + @patch.object('llm_web_kit.model.rule_based_safety_module.DomainFilter', 'filter') + @patch.object('llm_web_kit.model.rule_based_safety_module.SourceFilter', 'filter') + @patch.object('llm_web_kit.model.rule_based_safety_module.UnsafeWordsFilter', 'filter') + def test_process_flow(self, mock_unsafe_words_filter, mock_source_filter, mock_domain_filter): """测试完整处理流程.""" - mock_filter.return_value = (True, {'quality': 0.95}) + mock_source_filter.return_value = {'from_safe_source': False, 'from_domestic_source': False} + mock_domain_filter.return_value = (True, {}) + mock_unsafe_words_filter.return_value = (True, {}) clean_module = RuleBasedSafetyModule(prod=False) result = clean_module.process( @@ -107,9 +118,11 @@ def test_process_flow(self, mock_filter): language='en', language_details='details', content_style='article', + url='http://test.com', + dataset_name='test_dataset', ) - expected_result = {'clean_remained': True, 'clean_infos': {'quality': 0.95}} + expected_result = {'clean_remained': True, 'clean_infos': {}} self.assertDictEqual(result, expected_result) def test_production_mode_effect(self): From 9040a7fbeb8447f5f8e7a0d3b01ccd470abcc4c4 Mon Sep 17 00:00:00 2001 From: yujing Date: Mon, 3 Mar 2025 15:51:40 +0800 Subject: [PATCH 10/32] test rule-based-safety-module --- tests/llm_web_kit/model/test_rule_based_safety_module.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/llm_web_kit/model/test_rule_based_safety_module.py b/tests/llm_web_kit/model/test_rule_based_safety_module.py index 05aa517f..1dc0a3a8 100644 --- a/tests/llm_web_kit/model/test_rule_based_safety_module.py +++ b/tests/llm_web_kit/model/test_rule_based_safety_module.py @@ -5,8 +5,6 @@ # 需要根据实际模块路径调整 from llm_web_kit.model.rule_based_safety_module import ( RuleBasedSafetyModule, RuleBasedSafetyModuleDataPack, check_type) -from llm_web_kit.model.source_safety_detector import SourceFilter -from llm_web_kit.model.unsafe_words_detector import UnsafeWordsFilter class TestCheckType(TestCase): @@ -89,7 +87,6 @@ def test_process_core(self, mock_unsafe_words_filter, mock_source_filter, mock_d mock_domain_filter.return_value = (True, {}) mock_unsafe_words_filter.return_value = (False, {'reason': 'test'}) - # 初始化测试对象 clean_module = RuleBasedSafetyModule(prod=True) data_pack = RuleBasedSafetyModuleDataPack('test', 'en', 'details', 'article','http://test.com', 'test_dataset') From 47a8f0695690561d39bec69ed86d5b9a93fb1390 Mon Sep 17 00:00:00 2001 From: yujing Date: Mon, 3 Mar 2025 20:05:02 +0800 Subject: [PATCH 11/32] test rule-based-safety-module --- .../model/model_based_safety_module.py | 23 ++++++++++--------- .../model/test_model_based_safe_model.py | 0 .../model/test_rule_based_safety_module.py | 14 +++++------ 3 files changed, 19 insertions(+), 18 deletions(-) create mode 100644 tests/llm_web_kit/model/test_model_based_safe_model.py diff --git a/llm_web_kit/model/model_based_safety_module.py b/llm_web_kit/model/model_based_safety_module.py index 1edb96d1..108b6d6d 100644 --- a/llm_web_kit/model/model_based_safety_module.py +++ b/llm_web_kit/model/model_based_safety_module.py @@ -1,6 +1,7 @@ from typing import List, Tuple, Any, Type, TypeVar from llm_web_kit.model.policical import PoliticalDetector, decide_political_by_prob from llm_web_kit.model.porn_detector import BertModel as EnPornBertModel +from llm_web_kit.model.porn_detector import XlmrModel as ZhPornXlmrModel from llm_web_kit.exception.exception import ModelInputException I = TypeVar("I") # input type @@ -22,7 +23,7 @@ def check_type(arg_name: str, arg_value: Any, arg_type: Type): class ModelBasedSafetyDataPack: """The data pack for the model based safety module.""" - def __init__(self, content_str: str, langurage: str, langurage_details: str): + def __init__(self, content_str: str, language: str, language_details: str): self._dict = {} # the content of the dataset @@ -30,12 +31,12 @@ def __init__(self, content_str: str, langurage: str, langurage_details: str): self._dict["content_str"] = content_str # the language of the content - check_type("langurage", langurage, str) - self._dict["langurage"] = langurage + check_type("language", language, str) + self._dict["language"] = language # the details of the language - check_type("langurage_details", langurage_details, str) - self._dict["langurage_details"] = langurage_details + check_type("language_details", language_details, str) + self._dict["language_details"] = language_details # the flag of the processed data should be remained or not self._dict["model_based_safety_remained"] = True @@ -47,8 +48,8 @@ def __init__(self, content_str: str, langurage: str, langurage_details: str): def from_dict(cls, data: dict): new_data_pack = cls( content_str=data["content_str"], - langurage=data["langurage"], - langurage_details=data["langurage_details"], + language=data["language"], + language_details=data["language_details"], ) new_data_pack._dict.update(data) return new_data_pack @@ -112,9 +113,9 @@ def process_one_core( return self.postprocess(info, results[0]) def process_one( - self, content_str: str, langurage: str, langurage_details: str + self, content_str: str, language: str, language_details: str ) -> dict: - data_pack = ModelBasedSafetyDataPack(content_str, langurage, langurage_details) + data_pack = ModelBasedSafetyDataPack(content_str, language, language_details) return self.process_one_core(data_pack).get_output() @@ -127,7 +128,7 @@ def __init__(self, model_config: dict): self.threshold = model_config["threshold"] def check_support(self, data_pack: ModelBasedSafetyDataPack) -> bool: - return data_pack.langurage in ["zh", "en"] + return data_pack.language in ["zh", "en"] def inference(self, batch: List[str]) -> List[dict]: result_list = [] @@ -161,7 +162,7 @@ def __init__(self, model_config: dict): self.threshold = model_config["threshold"] def check_support(self, data_pack: ModelBasedSafetyDataPack) -> bool: - return data_pack.langurage == "en" + return data_pack.language == "en" def inference(self, batch: List[str]) -> List[dict]: result_list = [] diff --git a/tests/llm_web_kit/model/test_model_based_safe_model.py b/tests/llm_web_kit/model/test_model_based_safe_model.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/llm_web_kit/model/test_rule_based_safety_module.py b/tests/llm_web_kit/model/test_rule_based_safety_module.py index 1dc0a3a8..0c86b1c2 100644 --- a/tests/llm_web_kit/model/test_rule_based_safety_module.py +++ b/tests/llm_web_kit/model/test_rule_based_safety_module.py @@ -77,9 +77,9 @@ def test_get_output(self): class TestRuleBasedSafetyModule(TestCase): - @patch.object('llm_web_kit.model.rule_based_safety_module.DomainFilter', 'filter') - @patch.object('llm_web_kit.model.rule_based_safety_module.SourceFilter', 'filter') - @patch.object('llm_web_kit.model.rule_based_safety_module.UnsafeWordsFilter', 'filter') + @patch.object(RuleBasedSafetyModule.domain_filter, 'filter') + @patch.object(RuleBasedSafetyModule.source_filter, 'filter') + @patch.object(RuleBasedSafetyModule.unsafe_words_filter, 'filter') def test_process_core(self, mock_unsafe_words_filter, mock_source_filter, mock_domain_filter): """测试核心处理流程.""" # 设置模拟返回值 @@ -88,7 +88,7 @@ def test_process_core(self, mock_unsafe_words_filter, mock_source_filter, mock_d mock_unsafe_words_filter.return_value = (False, {'reason': 'test'}) # 初始化测试对象 - clean_module = RuleBasedSafetyModule(prod=True) + clean_module = RuleBasedSafetyModule(prod=True) data_pack = RuleBasedSafetyModuleDataPack('test', 'en', 'details', 'article','http://test.com', 'test_dataset') # 执行核心处理 @@ -100,9 +100,9 @@ def test_process_core(self, mock_unsafe_words_filter, mock_source_filter, mock_d self.assertFalse(result.clean_remained) self.assertEqual(result.clean_infos, {'reason': 'test'}) - @patch.object('llm_web_kit.model.rule_based_safety_module.DomainFilter', 'filter') - @patch.object('llm_web_kit.model.rule_based_safety_module.SourceFilter', 'filter') - @patch.object('llm_web_kit.model.rule_based_safety_module.UnsafeWordsFilter', 'filter') + @patch.object(RuleBasedSafetyModule.domain_filter, 'filter') + @patch.object(RuleBasedSafetyModule.source_filter, 'filter') + @patch.object(RuleBasedSafetyModule.unsafe_words_filter, 'filter') def test_process_flow(self, mock_unsafe_words_filter, mock_source_filter, mock_domain_filter): """测试完整处理流程.""" mock_source_filter.return_value = {'from_safe_source': False, 'from_domestic_source': False} From abd23209d4c9f0e3f10dd503555a52d3d8793261 Mon Sep 17 00:00:00 2001 From: yujing Date: Mon, 3 Mar 2025 20:14:09 +0800 Subject: [PATCH 12/32] test rule-based-safety-module --- .../model/test_rule_based_safety_module.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/llm_web_kit/model/test_rule_based_safety_module.py b/tests/llm_web_kit/model/test_rule_based_safety_module.py index 0c86b1c2..e626b9eb 100644 --- a/tests/llm_web_kit/model/test_rule_based_safety_module.py +++ b/tests/llm_web_kit/model/test_rule_based_safety_module.py @@ -5,7 +5,9 @@ # 需要根据实际模块路径调整 from llm_web_kit.model.rule_based_safety_module import ( RuleBasedSafetyModule, RuleBasedSafetyModuleDataPack, check_type) - +from llm_web_kit.model.domain_safety_detector import DomainFilter +from llm_web_kit.model.source_safety_detector import SourceFilter +from llm_web_kit.model.unsafe_words_detector import UnsafeWordsFilter class TestCheckType(TestCase): def test_type_checking(self): @@ -77,9 +79,9 @@ def test_get_output(self): class TestRuleBasedSafetyModule(TestCase): - @patch.object(RuleBasedSafetyModule.domain_filter, 'filter') - @patch.object(RuleBasedSafetyModule.source_filter, 'filter') - @patch.object(RuleBasedSafetyModule.unsafe_words_filter, 'filter') + @patch.object(DomainFilter.filter) + @patch.object(SourceFilter.filter) + @patch.object(UnsafeWordsFilter.filter) def test_process_core(self, mock_unsafe_words_filter, mock_source_filter, mock_domain_filter): """测试核心处理流程.""" # 设置模拟返回值 @@ -88,7 +90,7 @@ def test_process_core(self, mock_unsafe_words_filter, mock_source_filter, mock_d mock_unsafe_words_filter.return_value = (False, {'reason': 'test'}) # 初始化测试对象 - clean_module = RuleBasedSafetyModule(prod=True) + clean_module = RuleBasedSafetyModule(prod=True) data_pack = RuleBasedSafetyModuleDataPack('test', 'en', 'details', 'article','http://test.com', 'test_dataset') # 执行核心处理 @@ -100,9 +102,9 @@ def test_process_core(self, mock_unsafe_words_filter, mock_source_filter, mock_d self.assertFalse(result.clean_remained) self.assertEqual(result.clean_infos, {'reason': 'test'}) - @patch.object(RuleBasedSafetyModule.domain_filter, 'filter') - @patch.object(RuleBasedSafetyModule.source_filter, 'filter') - @patch.object(RuleBasedSafetyModule.unsafe_words_filter, 'filter') + @patch.object(DomainFilter.filter) + @patch.object(SourceFilter.filter) + @patch.object(UnsafeWordsFilter.filter) def test_process_flow(self, mock_unsafe_words_filter, mock_source_filter, mock_domain_filter): """测试完整处理流程.""" mock_source_filter.return_value = {'from_safe_source': False, 'from_domestic_source': False} From 992b4d518e6d6ae186482acc14b521e7a20de2c6 Mon Sep 17 00:00:00 2001 From: yujing Date: Mon, 3 Mar 2025 20:25:39 +0800 Subject: [PATCH 13/32] test rule-based-safety-module --- .../model/test_rule_based_safety_module.py | 37 ++++++++++--------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/tests/llm_web_kit/model/test_rule_based_safety_module.py b/tests/llm_web_kit/model/test_rule_based_safety_module.py index e626b9eb..983018cf 100644 --- a/tests/llm_web_kit/model/test_rule_based_safety_module.py +++ b/tests/llm_web_kit/model/test_rule_based_safety_module.py @@ -2,13 +2,14 @@ from unittest import TestCase from unittest.mock import patch +from llm_web_kit.model.domain_safety_detector import DomainFilter # 需要根据实际模块路径调整 from llm_web_kit.model.rule_based_safety_module import ( RuleBasedSafetyModule, RuleBasedSafetyModuleDataPack, check_type) -from llm_web_kit.model.domain_safety_detector import DomainFilter from llm_web_kit.model.source_safety_detector import SourceFilter from llm_web_kit.model.unsafe_words_detector import UnsafeWordsFilter + class TestCheckType(TestCase): def test_type_checking(self): """测试类型检查工具函数.""" @@ -59,8 +60,8 @@ def test_set_process_result(self): # 测试正确类型 data_pack.set_process_result(False, {'key': 'value'}) - self.assertFalse(data_pack.clean_remained) - self.assertEqual(data_pack.clean_infos, {'key': 'value'}) + self.assertFalse(data_pack.safety_remained) + self.assertEqual(data_pack.safety_infos, {'key': 'value'}) # 测试错误类型 with self.assertRaises(TypeError): @@ -79,9 +80,9 @@ def test_get_output(self): class TestRuleBasedSafetyModule(TestCase): - @patch.object(DomainFilter.filter) - @patch.object(SourceFilter.filter) - @patch.object(UnsafeWordsFilter.filter) + @patch.object(DomainFilter,'filter') + @patch.object(SourceFilter,'filter') + @patch.object(UnsafeWordsFilter,'filter') def test_process_core(self, mock_unsafe_words_filter, mock_source_filter, mock_domain_filter): """测试核心处理流程.""" # 设置模拟返回值 @@ -90,29 +91,29 @@ def test_process_core(self, mock_unsafe_words_filter, mock_source_filter, mock_d mock_unsafe_words_filter.return_value = (False, {'reason': 'test'}) # 初始化测试对象 - clean_module = RuleBasedSafetyModule(prod=True) + safety_module = RuleBasedSafetyModule(prod=True) data_pack = RuleBasedSafetyModuleDataPack('test', 'en', 'details', 'article','http://test.com', 'test_dataset') # 执行核心处理 - result = clean_module.process_core(data_pack) + result = safety_module.process_core(data_pack) # 验证过滤方法被正确调用 mock_unsafe_words_filter.assert_called_once_with('test', 'en', 'details', 'article', False, False) # 验证处理结果设置正确 - self.assertFalse(result.clean_remained) - self.assertEqual(result.clean_infos, {'reason': 'test'}) + self.assertFalse(result.safety_remained) + self.assertEqual(result.safety_infos, {'reason': 'test'}) - @patch.object(DomainFilter.filter) - @patch.object(SourceFilter.filter) - @patch.object(UnsafeWordsFilter.filter) + @patch.object(DomainFilter,'filter') + @patch.object(SourceFilter,'filter') + @patch.object(UnsafeWordsFilter,'filter') def test_process_flow(self, mock_unsafe_words_filter, mock_source_filter, mock_domain_filter): """测试完整处理流程.""" mock_source_filter.return_value = {'from_safe_source': False, 'from_domestic_source': False} mock_domain_filter.return_value = (True, {}) mock_unsafe_words_filter.return_value = (True, {}) - clean_module = RuleBasedSafetyModule(prod=False) - result = clean_module.process( + safety_module = RuleBasedSafetyModule(prod=False) + result = safety_module.process( content_str='content', language='en', language_details='details', @@ -121,7 +122,7 @@ def test_process_flow(self, mock_unsafe_words_filter, mock_source_filter, mock_d dataset_name='test_dataset', ) - expected_result = {'clean_remained': True, 'clean_infos': {}} + expected_result = {'safety_remained': True, 'safety_infos': {}} self.assertDictEqual(result, expected_result) def test_production_mode_effect(self): @@ -132,8 +133,8 @@ def test_production_mode_effect(self): def test_get_version(self): """测试版本获取.""" - clean_module = RuleBasedSafetyModule(prod=True) - self.assertEqual(clean_module.get_version(), '1.0.0') + safety_module = RuleBasedSafetyModule(prod=True) + self.assertEqual(safety_module.get_version(), '1.0.0') if __name__ == '__main__': From 928feff296d6b7a479835604431288454cbc732a Mon Sep 17 00:00:00 2001 From: yujing Date: Mon, 3 Mar 2025 20:32:39 +0800 Subject: [PATCH 14/32] test rule-based-safety-module --- tests/llm_web_kit/model/test_unsafe_words_detector.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/llm_web_kit/model/test_unsafe_words_detector.py b/tests/llm_web_kit/model/test_unsafe_words_detector.py index 61af44de..16729fcc 100644 --- a/tests/llm_web_kit/model/test_unsafe_words_detector.py +++ b/tests/llm_web_kit/model/test_unsafe_words_detector.py @@ -159,15 +159,6 @@ def test_unsafe_words_filter_overall(self, mock_filter): self.assertIsInstance(result, dict) self.assertTrue(result['hit_unsafe_words']) - with self.assertRaises(SafeModelException): - result = unsafe_words_filter_overall( - data_dict, - language='unknown', - content_style='text', - from_safe_source=False, - from_domestic_source=False, - ) - @patch('llm_web_kit.model.unsafe_words_detector.load_config') @patch('llm_web_kit.model.unsafe_words_detector.download_auto_file') def test_auto_download(self, mock_download_auto_file, mock_load_config): From 972a39a086f92bfc6f9dea6c9bec0fffcb46ccda Mon Sep 17 00:00:00 2001 From: yujing Date: Tue, 4 Mar 2025 20:24:55 +0800 Subject: [PATCH 15/32] model_based_safety_module --- .../model/model_based_safety_module.py | 9 +- llm_web_kit/model/porn_detector.py | 194 ++++++++++++++---- 2 files changed, 159 insertions(+), 44 deletions(-) diff --git a/llm_web_kit/model/model_based_safety_module.py b/llm_web_kit/model/model_based_safety_module.py index 108b6d6d..f48a318a 100644 --- a/llm_web_kit/model/model_based_safety_module.py +++ b/llm_web_kit/model/model_based_safety_module.py @@ -155,7 +155,7 @@ def postprocess(self, info: dict, result: dict) -> dict: return datapack -class EnPronModel(ContentStrBatchModel): +class EnPornModel(ContentStrBatchModel): def __init__(self, model_config: dict): super().__init__(model_config) self.model = EnPornBertModel(model_config["model_path"]) @@ -180,3 +180,10 @@ def postprocess(self, info: dict, result: dict) -> dict: model_based_safety_infos={"porn_prob": porn_prob}, ) return datapack + + +class ZhPornModel(EnPornModel): + def __init__(self, model_config: dict): + self.model_config = model_config + self.model = ZhPornXlmrModel(model_config["model_path"]) + self.threshold = model_config["threshold"] diff --git a/llm_web_kit/model/porn_detector.py b/llm_web_kit/model/porn_detector.py index ae3df185..c0c94652 100644 --- a/llm_web_kit/model/porn_detector.py +++ b/llm_web_kit/model/porn_detector.py @@ -8,68 +8,77 @@ from llm_web_kit.config.cfg_reader import load_config from llm_web_kit.libs.logger import mylogger as logger from llm_web_kit.model.resource_utils.download_assets import ( - CACHE_DIR, download_auto_file) -from llm_web_kit.model.resource_utils.unzip_ext import (get_unzip_dir, - unzip_local_file) + CACHE_DIR, + download_auto_file, +) +from llm_web_kit.model.resource_utils.unzip_ext import get_unzip_dir, unzip_local_file -class BertModel(): +class BertModel: def __init__(self, model_path: str = None) -> None: if not model_path: model_path = self.auto_download() - self.model = AutoModelForSequenceClassification.from_pretrained(os.path.join(model_path, 'porn_classifier/classifier_hf')) - with open(os.path.join(model_path, 'porn_classifier/extra_parameters.json')) as reader: + self.model = AutoModelForSequenceClassification.from_pretrained( + os.path.join(model_path, "porn_classifier/classifier_hf") + ) + with open( + os.path.join(model_path, "porn_classifier/extra_parameters.json") + ) as reader: model_config = json.load(reader) - self.cls_index = int(model_config.get('cls_index', 1)) - self.use_sigmoid = bool(model_config.get('use_sigmoid', False)) - self.max_tokens = int(model_config.get('max_tokens', 512)) - self.remain_tail = min(self.max_tokens - 1, int(model_config.get('remain_tail', -1))) - self.device = model_config.get('device', 'cpu') + self.cls_index = int(model_config.get("cls_index", 1)) + self.use_sigmoid = bool(model_config.get("use_sigmoid", False)) + self.max_tokens = int(model_config.get("max_tokens", 512)) + self.remain_tail = min( + self.max_tokens - 1, int(model_config.get("remain_tail", -1)) + ) + self.device = model_config.get("device", "cpu") self.model.eval() self.model.to(self.device, dtype=torch.float16) - if hasattr(self.model, 'to_bettertransformer'): + if hasattr(self.model, "to_bettertransformer"): self.model = self.model.to_bettertransformer() - self.tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_path, 'porn_classifier/classifier_hf')) + self.tokenizer = AutoTokenizer.from_pretrained( + os.path.join(model_path, "porn_classifier/classifier_hf") + ) self.tokenizer_config = { - 'padding': True, - 'truncation': self.remain_tail <= 0, - 'max_length': self.max_tokens if self.remain_tail <= 0 else None, - 'return_tensors': 'pt' if self.remain_tail <= 0 else None, + "padding": True, + "truncation": self.remain_tail <= 0, + "max_length": self.max_tokens if self.remain_tail <= 0 else None, + "return_tensors": "pt" if self.remain_tail <= 0 else None, } - self.output_prefix = str(model_config.get('output_prefix', '')).rstrip('_') - self.output_postfix = str(model_config.get('output_postfix', '')).lstrip('_') + self.output_prefix = str(model_config.get("output_prefix", "")).rstrip("_") + self.output_postfix = str(model_config.get("output_postfix", "")).lstrip("_") - self.model_name = str(model_config.get('model_name', 'porn-23w44')) + self.model_name = str(model_config.get("model_name", "porn-23w44")) def auto_download(self) -> str: """Default download the 23w44.zip model.""" - resource_name = 'porn-23w44' - resource_config = load_config()['resources'] + resource_name = "porn-23w44" + resource_config = load_config()["resources"] porn_23w44_config: Dict = resource_config[resource_name] - porn_23w44_s3 = porn_23w44_config['download_path'] - porn_23w44_md5 = porn_23w44_config.get('md5', '') + porn_23w44_s3 = porn_23w44_config["download_path"] + porn_23w44_md5 = porn_23w44_config.get("md5", "") # get the zip path calculated by the s3 path - zip_path = os.path.join(CACHE_DIR, f'{resource_name}.zip') + zip_path = os.path.join(CACHE_DIR, f"{resource_name}.zip") # the unzip path is calculated by the zip path unzip_path = get_unzip_dir(zip_path) - logger.info(f'try to make unzip_path: {unzip_path}') + logger.info(f"try to make unzip_path: {unzip_path}") # if the unzip path does not exist, download the zip file and unzip it if not os.path.exists(unzip_path): - logger.info(f'unzip_path: {unzip_path} does not exist') - logger.info(f'try to unzip from zip_path: {zip_path}') + logger.info(f"unzip_path: {unzip_path} does not exist") + logger.info(f"try to unzip from zip_path: {zip_path}") if not os.path.exists(zip_path): - logger.info(f'zip_path: {zip_path} does not exist') - logger.info(f'downloading {porn_23w44_s3}') + logger.info(f"zip_path: {zip_path} does not exist") + logger.info(f"downloading {porn_23w44_s3}") zip_path = download_auto_file(porn_23w44_s3, zip_path, porn_23w44_md5) - logger.info(f'unzipping {zip_path}') + logger.info(f"unzipping {zip_path}") unzip_path = unzip_local_file(zip_path, unzip_path) else: - logger.info(f'unzip_path: {unzip_path} exist') + logger.info(f"unzip_path: {unzip_path} exist") return unzip_path def pre_process(self, samples: Union[List[str], str]) -> Dict: @@ -81,40 +90,54 @@ def pre_process(self, samples: Union[List[str], str]) -> Dict: processed_inputs = [] # 对每个输入进行处理 - for tokens_id in inputs['input_ids']: + for tokens_id in inputs["input_ids"]: # 通过sep_token_id找到tokens的长度 length = tokens_id.index(self.tokenizer.sep_token_id) + 1 # 如果tokens的长度小于等于max_tokens,则直接在尾部补0,不需要截断 if length <= self.max_tokens: - tokens = tokens_id[:length] + [self.tokenizer.pad_token_id] * (self.max_tokens - length) + tokens = tokens_id[:length] + [self.tokenizer.pad_token_id] * ( + self.max_tokens - length + ) attn = [1] * length + [0] * (self.max_tokens - length) # 如果tokens的长度大于max_tokens,则需要取头部max_tokens-remain_tail个tokens和尾部remain_tail个tokens else: head_length = self.max_tokens - self.remain_tail tail_length = self.remain_tail - tokens = tokens_id[:head_length] + tokens_id[length - tail_length : length] + tokens = ( + tokens_id[:head_length] + + tokens_id[length - tail_length : length] + ) attn = [1] * self.max_tokens # 将处理后的tokens添加到新的inputs列表中 - processed_inputs.append({'input_ids': torch.tensor(tokens), 'attention_mask': torch.tensor(attn)}) + processed_inputs.append( + { + "input_ids": torch.tensor(tokens), + "attention_mask": torch.tensor(attn), + } + ) # 将所有inputs整合成一个batch inputs = { - 'input_ids': torch.cat([inp['input_ids'].unsqueeze(0) for inp in processed_inputs]), - 'attention_mask': torch.cat([inp['attention_mask'].unsqueeze(0) for inp in processed_inputs]), + "input_ids": torch.cat( + [inp["input_ids"].unsqueeze(0) for inp in processed_inputs] + ), + "attention_mask": torch.cat( + [inp["attention_mask"].unsqueeze(0) for inp in processed_inputs] + ), } inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()} - return {'inputs': inputs} + return {"inputs": inputs} def get_output_key(self, f: str): prefix = self.output_prefix if self.output_prefix else self.model_name - postfix = f'_{self.output_postfix}' if self.output_postfix else '' - return f'{prefix}_{f}{postfix}' + postfix = f"_{self.output_postfix}" if self.output_postfix else "" + return f"{prefix}_{f}{postfix}" def predict(self, texts: Union[List[str], str]): inputs_dict = self.pre_process(texts) with torch.no_grad(): - logits = self.model(**inputs_dict['inputs']).logits + logits = self.model(**inputs_dict["inputs"]).logits if self.use_sigmoid: probs = torch.sigmoid(logits) @@ -126,7 +149,92 @@ def predict(self, texts: Union[List[str], str]): outputs = [] for prob in pos_prob: prob = round(float(prob), 6) - output = {self.get_output_key('prob'): prob} + output = {self.get_output_key("prob"): prob} + outputs.append(output) + + return outputs + + +class XlmrModel(BertModel): + def __init__(self, model_path: str = None) -> None: + if not model_path: + model_path = self.auto_download() + self.model = AutoModelForSequenceClassification.from_pretrained( + os.path.join(model_path, "porn_classifier/classifier_hf") + ) + with open( + os.path.join(model_path, "porn_classifier/extra_parameters.json") + ) as reader: + model_config = json.load(reader) + + self.clip = bool(model_config.get("clip", False)) + self.max_tokens = int(model_config.get("max_tokens", 300)) + self.remain_tail = min( + self.max_tokens - 1, int(model_config.get("remain_tail", -1)) + ) + self.device = model_config.get("device", "cpu") + + self.model.eval() + self.model.to(self.device, dtype=torch.float16) + + if hasattr(self.model, "to_bettertransformer"): + self.model = self.model.to_bettertransformer() + + self.tokenizer = AutoTokenizer.from_pretrained( + os.path.join(model_path, "porn_classifier/classifier_hf") + ) + self.tokenizer_config = { + "padding": True, + "truncation": self.remain_tail <= 0, + "max_length": self.max_tokens if self.remain_tail <= 0 else None, + "return_tensors": "pt" if self.remain_tail <= 0 else None, + } + + self.output_prefix = str(model_config.get("output_prefix", "")).rstrip("_") + self.output_postfix = str(model_config.get("output_postfix", "")).lstrip("_") + + self.model_name = str(model_config.get("model_name", "porn-24m5")) + + def auto_download(self) -> str: + """Default download the 23w44.zip model.""" + resource_name = "porn-24m5" + resource_config = load_config()["resources"] + porn_24m5_config: Dict = resource_config[resource_name] + porn_24m5_s3 = porn_24m5_config["download_path"] + porn_24m5_md5 = porn_24m5_config.get("md5", "") + # get the zip path calculated by the s3 path + zip_path = os.path.join(CACHE_DIR, f"{resource_name}.zip") + # the unzip path is calculated by the zip path + unzip_path = get_unzip_dir(zip_path) + logger.info(f"try to make unzip_path: {unzip_path}") + # if the unzip path does not exist, download the zip file and unzip it + if not os.path.exists(unzip_path): + logger.info(f"unzip_path: {unzip_path} does not exist") + logger.info(f"try to unzip from zip_path: {zip_path}") + if not os.path.exists(zip_path): + logger.info(f"zip_path: {zip_path} does not exist") + logger.info(f"downloading {porn_24m5_s3}") + zip_path = download_auto_file(porn_24m5_s3, zip_path, porn_24m5_md5) + logger.info(f"unzipping {zip_path}") + unzip_path = unzip_local_file(zip_path, unzip_path) + else: + logger.info(f"unzip_path: {unzip_path} exist") + return unzip_path + + def predict(self, texts: Union[List[str], str]): + inputs_dict = self.pre_process(texts) + with torch.no_grad(): + logits = self.model(**inputs_dict["inputs"]).logits + + if self.clip: + probs = logits.detach().cpu().numpy().clip(min=0, max=1) + else: + probs = logits.detach().cpu().numpy() + + outputs = [] + for prob in probs: + prob = round(float(prob[0]), 6) + output = {self.get_output_key("prob"): prob} outputs.append(output) return outputs From ed7e141c9f5511d96ccac6bc60384f8afe15c05a Mon Sep 17 00:00:00 2001 From: yujing Date: Wed, 5 Mar 2025 18:59:21 +0800 Subject: [PATCH 16/32] model_based_safety_module --- llm_web_kit/model/porn_detector.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/llm_web_kit/model/porn_detector.py b/llm_web_kit/model/porn_detector.py index c0c94652..9f23c92b 100644 --- a/llm_web_kit/model/porn_detector.py +++ b/llm_web_kit/model/porn_detector.py @@ -37,9 +37,6 @@ def __init__(self, model_path: str = None) -> None: self.model.eval() self.model.to(self.device, dtype=torch.float16) - if hasattr(self.model, "to_bettertransformer"): - self.model = self.model.to_bettertransformer() - self.tokenizer = AutoTokenizer.from_pretrained( os.path.join(model_path, "porn_classifier/classifier_hf") ) @@ -177,9 +174,6 @@ def __init__(self, model_path: str = None) -> None: self.model.eval() self.model.to(self.device, dtype=torch.float16) - if hasattr(self.model, "to_bettertransformer"): - self.model = self.model.to_bettertransformer() - self.tokenizer = AutoTokenizer.from_pretrained( os.path.join(model_path, "porn_classifier/classifier_hf") ) From 3463dfc0b04cadebaacb3b98e9bcd31f0952c8af Mon Sep 17 00:00:00 2001 From: yujing Date: Wed, 12 Mar 2025 17:38:14 +0800 Subject: [PATCH 17/32] part merge https://github.com/yogacc33/llm-webkit-mirror/blob/feature/model_api/ --- llm_web_kit/model/model_impl.py | 248 +++++++++++++++++++++++++++ llm_web_kit/model/model_interface.py | 148 ++++++++++++++++ 2 files changed, 396 insertions(+) create mode 100644 llm_web_kit/model/model_impl.py create mode 100644 llm_web_kit/model/model_interface.py diff --git a/llm_web_kit/model/model_impl.py b/llm_web_kit/model/model_impl.py new file mode 100644 index 00000000..02066c60 --- /dev/null +++ b/llm_web_kit/model/model_impl.py @@ -0,0 +1,248 @@ +from abc import abstractmethod +from enum import Enum +from typing import Dict, List, Type + +from llm_web_kit.model.model_interface import (BatchProcessConfig, + ModelPredictor, ModelResource, + PoliticalRequest, + PoliticalResponse, PornRequest, + PornResponse, + ResourceRequirement, + ResourceType) +from llm_web_kit.model.policical import (get_singleton_political_detect, + update_political_by_str) +from llm_web_kit.model.porn_detector import BertModel as PornBertModel + + +class ModelType(Enum): + """模型类型枚举.""" + + POLITICAL = 'political' # 涉政模型 + PORN = 'porn' # 色情模型 + + +class DeviceType(Enum): + """设备类型枚举.""" + + CPU = 'cpu' + GPU = 'gpu' + + +class BaseModelResource(ModelResource): + """基础模型资源类.""" + + def __init__(self, model_path: str): + self.model_path = model_path + self.model = None + + def initialize(self) -> None: + self.model = self._load_model() + + @abstractmethod + def _load_model(self): + pass + + def cleanup(self) -> None: + if self.model: + self._cleanup_model() + self.model = None + + def _cleanup_model(self): + pass + + +class BasePredictor(ModelPredictor): + """基础预测器类.""" + + def __init__(self, cpu_model_path: str, gpu_model_path: str): + self.cpu_model = self._create_cpu_model(cpu_model_path) + self.gpu_model = self._create_gpu_model(gpu_model_path) + + # 初始化模型 + if self.cpu_model: + self.cpu_model.initialize() + if self.gpu_model: + self.gpu_model.initialize() + + self.language_map = { + 'zh': DeviceType.CPU, + 'en': DeviceType.CPU, + # 其他语言映射到GPU + } + + @abstractmethod + def _create_cpu_model(self, model_path: str) -> ModelResource: + pass + + @abstractmethod + def _create_gpu_model(self, model_path: str) -> ModelResource: + pass + + def get_model_info(self) -> Dict[str, BatchProcessConfig]: + info = {} + if self.cpu_model: + info[DeviceType.CPU.value] = self.cpu_model.get_batch_config() + if self.gpu_model: + info[DeviceType.GPU.value] = self.gpu_model.get_batch_config() + return info + + +# 涉政模型实现 +class PoliticalCPUModel(BaseModelResource): + """涉政检测CPU模型.""" + + def _load_model(self): + try: + model = get_singleton_political_detect() + if model is None: + raise RuntimeError('Failed to load political model') + return model + except Exception as e: + raise RuntimeError(f'Failed to load political CPU model: {e}') + + def get_batch_config(self) -> BatchProcessConfig: + return BatchProcessConfig(max_batch_size=128, optimal_batch_size=64, min_batch_size=8) + + def predict_batch(self, contents: List[str]) -> List[float]: + if not self.model: + raise RuntimeError('Model not initialized') + try: + # 批量处理 + results = [] + for content in contents: + result = update_political_by_str(content) + results.append(result['political_prob']) + return results + except Exception as e: + raise RuntimeError(f'Prediction failed: {e}') + + +class PoliticalPredictorImpl(BasePredictor): + """涉政检测预测器实现.""" + + def _create_cpu_model(self, model_path: str) -> ModelResource: + return PoliticalCPUModel(model_path) + + def _create_gpu_model(self, model_path: str) -> ModelResource: + return None + + def get_resource_requirement(self, language: str) -> ResourceRequirement: + """获取资源需求.""" + # 涉政模型对中英文使用CPU,其他语言使用默认资源 + if language in ['zh', 'en']: + return ResourceRequirement(resource_type=ResourceType.CPU) + return ResourceRequirement() + + def predict_batch(self, requests: List[PoliticalRequest]) -> List[PoliticalResponse]: + """批量预测接口.""" + responses = [None] * len(requests) + + try: + # 收集所有请求内容 + batch_contents = [] + valid_indices = [] + + for idx, req in enumerate(requests): + try: + # 验证语言支持 + if req.language not in ['zh', 'en']: + continue + batch_contents.append(req.content) + valid_indices.append(idx) + except Exception as e: + print(f'Skip invalid request at index {idx}: {e}') + + if batch_contents: + # 批量处理 + probs = self.cpu_model.predict_batch(batch_contents) + # 填充结果 + for idx, prob in zip(valid_indices, probs): + responses[idx] = PoliticalResponse(probability=prob) + + except Exception as e: + raise RuntimeError(f'Political prediction failed: {e}') + + return responses + + +# 色情模型实现 +class PornGPUModel(BaseModelResource): + """色情检测GPU模型.""" + + def _load_model(self): + try: + return PornBertModel('') + except Exception as e: + raise RuntimeError(f'Failed to load porn GPU model: {e}') + + def get_batch_config(self) -> BatchProcessConfig: + return BatchProcessConfig(max_batch_size=128, optimal_batch_size=64, min_batch_size=8) + + def predict_batch(self, contents: List[str]) -> List[float]: + if not self.model: + raise RuntimeError('Model not initialized') + try: + # 色情模型本身支持批处理 + results = self.model.predict(contents) + return [result[self.model.get_output_key('prob')] for result in results] + except Exception as e: + raise RuntimeError(f'Prediction failed: {e}') + + +class PornPredictorImpl(BasePredictor): + """色情检测预测器实现.""" + + def _create_cpu_model(self, model_path: str) -> ModelResource: + return None + + def _create_gpu_model(self, model_path: str) -> ModelResource: + return PornGPUModel(model_path='') + + def get_resource_requirement(self, language: str) -> ResourceRequirement: + """获取资源需求.""" + # 色情模型统一使用GPU + return ResourceRequirement(resource_type=ResourceType.GPU) + + def predict_batch(self, requests: List[PornRequest]) -> List[PornResponse]: + """批量预测接口.""" + responses = [None] * len(requests) + + try: + # 收集所有中英文请求 + batch_contents = [] + valid_indices = [] + + for idx, req in enumerate(requests): + if req.language in ['zh', 'en']: + batch_contents.append(req.content) + valid_indices.append(idx) + + if batch_contents: + # 批量处理 + probs = self.gpu_model.predict_batch(batch_contents) + # 填充结果 + for idx, prob in zip(valid_indices, probs): + responses[idx] = PornResponse(probability=prob) + + except Exception as e: + raise RuntimeError(f'Porn prediction failed: {e}') + + return responses + + +# 模型工厂 +class ModelFactory: + """模型工厂类.""" + + _predictor_registry: Dict[ModelType, Type[BasePredictor]] = { + ModelType.POLITICAL: PoliticalPredictorImpl, + ModelType.PORN: PornPredictorImpl, + } + + @classmethod + def create_predictor(cls, model_type: ModelType, cpu_model_path: str, gpu_model_path: str) -> BasePredictor: + """创建预测器实例.""" + predictor_class = cls._predictor_registry.get(model_type) + if not predictor_class: + raise ValueError(f'No predictor registered for type: {model_type}') + return predictor_class(cpu_model_path, gpu_model_path) \ No newline at end of file diff --git a/llm_web_kit/model/model_interface.py b/llm_web_kit/model/model_interface.py new file mode 100644 index 00000000..23210af2 --- /dev/null +++ b/llm_web_kit/model/model_interface.py @@ -0,0 +1,148 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, List + + +@dataclass +class ModelRequest: + """通用模型请求基类.""" + + content: str + language: str + extra_params: Dict[str, Any] = None + + +@dataclass +class ModelResponse: + """通用模型响应基类.""" + + probability: float + details: Dict[str, Any] = None + + +@dataclass +class PoliticalRequest(ModelRequest): + """涉政检测请求.""" + + pass + + +@dataclass +class PoliticalResponse(ModelResponse): + """涉政检测响应.""" + + pass + + +@dataclass +class PornRequest(ModelRequest): + """色情检测请求.""" + + pass + + +@dataclass +class PornResponse(ModelResponse): + """色情检测响应.""" + + pass + + +@dataclass +class BatchProcessConfig: + """批处理配置.""" + + max_batch_size: int + optimal_batch_size: int + min_batch_size: int + + +class ResourceType(Enum): + """资源类型枚举.""" + + CPU = 'cpu_only' + GPU = 'num_gpus' + DEFAULT = 'default' + + +class ResourceRequirement: + """资源需求配置.""" + + def __init__(self, resource_type: ResourceType = ResourceType.DEFAULT, num_cpus: int = 1, memory: int = 4 << 30): + self.resource_type = resource_type + self.num_cpus = num_cpus + self.memory = memory + + def to_ray_resources(self) -> Dict: + """转换为Ray资源配置.""" + resources = { + 'num_cpus': self.num_cpus, + 'memory': self.memory, + } + + # 根据资源类型设置正确的资源配置 + if self.resource_type == ResourceType.CPU: + resources['resources'] = {'cpu_only': 1} + elif self.resource_type == ResourceType.GPU: + # 使用 num_gpus 而不是在 resources 字典中设置 + resources['num_gpus'] = 0.25 + + return resources + + +class ModelResource(ABC): + """模型资源接口.""" + + @abstractmethod + def initialize(self) -> None: + """初始化模型资源.""" + pass + + @abstractmethod + def get_batch_config(self) -> BatchProcessConfig: + """获取模型的批处理配置.""" + pass + + @abstractmethod + def predict_batch(self, contents: List[str]) -> List[float]: + """批量预测.""" + pass + + @abstractmethod + def cleanup(self) -> None: + """清理资源.""" + pass + + +class ModelPredictor(ABC): + """通用预测器接口.""" + + @abstractmethod + def get_model_info(self) -> Dict[str, BatchProcessConfig]: + """获取模型信息.""" + pass + + @abstractmethod + def get_resource_requirement(self, language: str) -> ResourceRequirement: + """获取资源需求.""" + pass + + @abstractmethod + def predict_batch(self, requests: List[ModelRequest]) -> List[ModelResponse]: + """批量预测接口 - 同步版本.""" + pass + + +class PoliticalPredictor(ModelPredictor): + """涉政预测器接口.""" + + def predict_batch(self, requests: List[PoliticalRequest]) -> List[PoliticalResponse]: + pass + + +class PornPredictor(ModelPredictor): + """色情预测器接口.""" + + def predict_batch(self, requests: List[PornRequest]) -> List[PornResponse]: + pass \ No newline at end of file From a5a0d8f83a6b4053ef761d18ce6063b546fb8bc8 Mon Sep 17 00:00:00 2001 From: qiujiantao Date: Wed, 12 Mar 2025 21:35:13 +0800 Subject: [PATCH 18/32] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=20ModelRuntime?= =?UTF-8?q?Exception=20=E5=BC=82=E5=B8=B8=E5=A4=84=E7=90=86=EF=BC=8C?= =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=A8=A1=E5=9E=8B=E8=B5=84=E6=BA=90=E7=AE=A1?= =?UTF-8?q?=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- llm_web_kit/exception/exception.jsonc | 4 + llm_web_kit/exception/exception.py | 8 ++ llm_web_kit/model/model_impl.py | 166 ++++++++++++-------------- llm_web_kit/model/model_interface.py | 12 +- 4 files changed, 93 insertions(+), 97 deletions(-) diff --git a/llm_web_kit/exception/exception.jsonc b/llm_web_kit/exception/exception.jsonc index 24600d60..47df28d8 100644 --- a/llm_web_kit/exception/exception.jsonc +++ b/llm_web_kit/exception/exception.jsonc @@ -142,6 +142,10 @@ "CleanModelException": { "code": 46000000, "message": "Clean model exception" + }, + "ModelRuntimeException": { + "code": 47000000, + "message": "Model runtime exception" } } } diff --git a/llm_web_kit/exception/exception.py b/llm_web_kit/exception/exception.py index c3f4f5d1..3aa84572 100644 --- a/llm_web_kit/exception/exception.py +++ b/llm_web_kit/exception/exception.py @@ -336,6 +336,14 @@ def __init__(self, custom_message: str | None = None, error_code: int | None = N super().__init__(custom_message, error_code) +class ModelRuntimeException(ModelBaseException): + """Exception raised for model input data format.""" + def __init__(self, custom_message: str | None = None, error_code: int | None = None): + if error_code is None: + error_code = ErrorMsg.get_error_code('Model', 'ModelRuntimeException') + super().__init__(custom_message, error_code) + + class ModelOutputException(ModelBaseException): """Exception raised for model output data format.""" def __init__(self, custom_message: str | None = None, error_code: int | None = None): diff --git a/llm_web_kit/model/model_impl.py b/llm_web_kit/model/model_impl.py index 02066c60..9dc30ddf 100644 --- a/llm_web_kit/model/model_impl.py +++ b/llm_web_kit/model/model_impl.py @@ -2,13 +2,14 @@ from enum import Enum from typing import Dict, List, Type +from llm_web_kit.exception.exception import (ModelInitException, + ModelInputException, + ModelRuntimeException) from llm_web_kit.model.model_interface import (BatchProcessConfig, ModelPredictor, ModelResource, PoliticalRequest, PoliticalResponse, PornRequest, - PornResponse, - ResourceRequirement, - ResourceType) + PornResponse) from llm_web_kit.model.policical import (get_singleton_political_detect, update_political_by_str) from llm_web_kit.model.porn_detector import BertModel as PornBertModel @@ -31,8 +32,7 @@ class DeviceType(Enum): class BaseModelResource(ModelResource): """基础模型资源类.""" - def __init__(self, model_path: str): - self.model_path = model_path + def __init__(self): self.model = None def initialize(self) -> None: @@ -54,37 +54,19 @@ def _cleanup_model(self): class BasePredictor(ModelPredictor): """基础预测器类.""" - def __init__(self, cpu_model_path: str, gpu_model_path: str): - self.cpu_model = self._create_cpu_model(cpu_model_path) - self.gpu_model = self._create_gpu_model(gpu_model_path) + def __init__(self, language: str): + self.language = language + self.model = self._create_model(language) # 初始化模型 - if self.cpu_model: - self.cpu_model.initialize() - if self.gpu_model: - self.gpu_model.initialize() - - self.language_map = { - 'zh': DeviceType.CPU, - 'en': DeviceType.CPU, - # 其他语言映射到GPU - } - - @abstractmethod - def _create_cpu_model(self, model_path: str) -> ModelResource: - pass + self.model.initialize() @abstractmethod - def _create_gpu_model(self, model_path: str) -> ModelResource: + def _create_model(self, language) -> ModelResource: pass - def get_model_info(self) -> Dict[str, BatchProcessConfig]: - info = {} - if self.cpu_model: - info[DeviceType.CPU.value] = self.cpu_model.get_batch_config() - if self.gpu_model: - info[DeviceType.GPU.value] = self.gpu_model.get_batch_config() - return info + def get_resource_requirement(self): + return self.model.get_resource_requirement() # 涉政模型实现 @@ -101,7 +83,9 @@ def _load_model(self): raise RuntimeError(f'Failed to load political CPU model: {e}') def get_batch_config(self) -> BatchProcessConfig: - return BatchProcessConfig(max_batch_size=128, optimal_batch_size=64, min_batch_size=8) + return BatchProcessConfig( + max_batch_size=128, optimal_batch_size=64, min_batch_size=8 + ) def predict_batch(self, contents: List[str]) -> List[float]: if not self.model: @@ -120,54 +104,44 @@ def predict_batch(self, contents: List[str]) -> List[float]: class PoliticalPredictorImpl(BasePredictor): """涉政检测预测器实现.""" - def _create_cpu_model(self, model_path: str) -> ModelResource: - return PoliticalCPUModel(model_path) - - def _create_gpu_model(self, model_path: str) -> ModelResource: - return None + def _create_model(self, language: str) -> ModelResource: - def get_resource_requirement(self, language: str) -> ResourceRequirement: - """获取资源需求.""" - # 涉政模型对中英文使用CPU,其他语言使用默认资源 if language in ['zh', 'en']: - return ResourceRequirement(resource_type=ResourceType.CPU) - return ResourceRequirement() - - def predict_batch(self, requests: List[PoliticalRequest]) -> List[PoliticalResponse]: + return PoliticalCPUModel() + raise ModelInitException( + f'Poltical model does not support language: {language}' + ) + + def predict_batch( + self, requests: List[PoliticalRequest] + ) -> List[PoliticalResponse]: """批量预测接口.""" - responses = [None] * len(requests) try: # 收集所有请求内容 batch_contents = [] - valid_indices = [] - - for idx, req in enumerate(requests): - try: - # 验证语言支持 - if req.language not in ['zh', 'en']: - continue - batch_contents.append(req.content) - valid_indices.append(idx) - except Exception as e: - print(f'Skip invalid request at index {idx}: {e}') + + for req in requests: + # 验证语言支持 + if req.language != self.language: + raise ModelInputException( + f'Language mismatch: {req.language} vs {self.language}' + ) + batch_contents.append(req.content) if batch_contents: # 批量处理 - probs = self.cpu_model.predict_batch(batch_contents) - # 填充结果 - for idx, prob in zip(valid_indices, probs): - responses[idx] = PoliticalResponse(probability=prob) - + probs = self.model.predict_batch(batch_contents) + responses = [PoliticalResponse(probability=prob) for prob in probs] except Exception as e: - raise RuntimeError(f'Political prediction failed: {e}') + raise ModelRuntimeException(f'Political prediction failed: {e}') return responses # 色情模型实现 -class PornGPUModel(BaseModelResource): - """色情检测GPU模型.""" +class PornEnGPUModel(BaseModelResource): + """英文色情检测GPU模型.""" def _load_model(self): try: @@ -176,7 +150,9 @@ def _load_model(self): raise RuntimeError(f'Failed to load porn GPU model: {e}') def get_batch_config(self) -> BatchProcessConfig: - return BatchProcessConfig(max_batch_size=128, optimal_batch_size=64, min_batch_size=8) + return BatchProcessConfig( + max_batch_size=128, optimal_batch_size=64, min_batch_size=8 + ) def predict_batch(self, contents: List[str]) -> List[float]: if not self.model: @@ -189,44 +165,52 @@ def predict_batch(self, contents: List[str]) -> List[float]: raise RuntimeError(f'Prediction failed: {e}') -class PornPredictorImpl(BasePredictor): - """色情检测预测器实现.""" +class PornZhGPUModel(BaseModelResource): + """中文色情检测GPU模型.""" + + def _load_model(self): + raise NotImplementedError('TODO') + + def get_batch_config(self) -> BatchProcessConfig: + raise NotImplementedError('TODO') + return BatchProcessConfig( + max_batch_size=128, optimal_batch_size=64, min_batch_size=8 + ) - def _create_cpu_model(self, model_path: str) -> ModelResource: - return None + def predict_batch(self, contents: List[str]) -> List[float]: + raise NotImplementedError('TODO') - def _create_gpu_model(self, model_path: str) -> ModelResource: - return PornGPUModel(model_path='') - def get_resource_requirement(self, language: str) -> ResourceRequirement: - """获取资源需求.""" - # 色情模型统一使用GPU - return ResourceRequirement(resource_type=ResourceType.GPU) +class PornPredictorImpl(BasePredictor): + """色情检测预测器实现.""" + + def _create_models(self, language: str) -> ModelResource: + if language == 'en': + return PornEnGPUModel() + elif language == 'zh': + return PornZhGPUModel() + raise ModelInitException(f'Porn model does not support language: {language}') def predict_batch(self, requests: List[PornRequest]) -> List[PornResponse]: """批量预测接口.""" - responses = [None] * len(requests) - try: - # 收集所有中英文请求 + # 收集所有请求内容 batch_contents = [] - valid_indices = [] - for idx, req in enumerate(requests): - if req.language in ['zh', 'en']: - batch_contents.append(req.content) - valid_indices.append(idx) + for req in requests: + # 验证语言支持 + if req.language != self.language: + raise ModelInputException( + f'Language mismatch: {req.language} vs {self.language}' + ) + batch_contents.append(req.content) if batch_contents: # 批量处理 - probs = self.gpu_model.predict_batch(batch_contents) - # 填充结果 - for idx, prob in zip(valid_indices, probs): - responses[idx] = PornResponse(probability=prob) - + probs = self.model.predict_batch(batch_contents) + responses = [PornResponse(probability=prob) for prob in probs] except Exception as e: - raise RuntimeError(f'Porn prediction failed: {e}') - + raise ModelRuntimeException(f'Porn prediction failed: {e}') return responses @@ -240,9 +224,9 @@ class ModelFactory: } @classmethod - def create_predictor(cls, model_type: ModelType, cpu_model_path: str, gpu_model_path: str) -> BasePredictor: + def create_predictor(cls, model_type: ModelType, language: str) -> BasePredictor: """创建预测器实例.""" predictor_class = cls._predictor_registry.get(model_type) if not predictor_class: raise ValueError(f'No predictor registered for type: {model_type}') - return predictor_class(cpu_model_path, gpu_model_path) \ No newline at end of file + return predictor_class(language=language) diff --git a/llm_web_kit/model/model_interface.py b/llm_web_kit/model/model_interface.py index 23210af2..dc7ce031 100644 --- a/llm_web_kit/model/model_interface.py +++ b/llm_web_kit/model/model_interface.py @@ -114,15 +114,15 @@ def cleanup(self) -> None: """清理资源.""" pass + @abstractmethod + def get_resource_requirement(self) -> ResourceRequirement: + """获取资源需求.""" + pass + class ModelPredictor(ABC): """通用预测器接口.""" - @abstractmethod - def get_model_info(self) -> Dict[str, BatchProcessConfig]: - """获取模型信息.""" - pass - @abstractmethod def get_resource_requirement(self, language: str) -> ResourceRequirement: """获取资源需求.""" @@ -145,4 +145,4 @@ class PornPredictor(ModelPredictor): """色情预测器接口.""" def predict_batch(self, requests: List[PornRequest]) -> List[PornResponse]: - pass \ No newline at end of file + pass From 222db58f306c16707e0cf5112299bfebec043925 Mon Sep 17 00:00:00 2001 From: qiujiantao Date: Thu, 13 Mar 2025 11:49:33 +0800 Subject: [PATCH 19/32] =?UTF-8?q?refactor:=20=E4=BC=98=E5=8C=96=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E5=8A=A0=E8=BD=BD=E5=92=8C=E8=B5=84=E6=BA=90=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=EF=BC=8C=E8=B0=83=E6=95=B4=E7=B1=BB=E5=B1=9E=E6=80=A7?= =?UTF-8?q?=E4=BB=A5=E5=A2=9E=E5=BC=BA=E5=8F=AF=E8=AF=BB=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- llm_web_kit/model/model_impl.py | 89 +++++++++++++++++---- llm_web_kit/model/model_interface.py | 68 +++++++++++----- llm_web_kit/model/porn_detector.py | 113 ++++++++++++++------------- 3 files changed, 179 insertions(+), 91 deletions(-) diff --git a/llm_web_kit/model/model_impl.py b/llm_web_kit/model/model_impl.py index 9dc30ddf..0d008a32 100644 --- a/llm_web_kit/model/model_impl.py +++ b/llm_web_kit/model/model_impl.py @@ -7,12 +7,12 @@ ModelRuntimeException) from llm_web_kit.model.model_interface import (BatchProcessConfig, ModelPredictor, ModelResource, - PoliticalRequest, + ModelResponse, PoliticalRequest, PoliticalResponse, PornRequest, - PornResponse) + PornResponse, + ResourceRequirement) from llm_web_kit.model.policical import (get_singleton_political_detect, update_political_by_str) -from llm_web_kit.model.porn_detector import BertModel as PornBertModel class ModelType(Enum): @@ -42,6 +42,10 @@ def initialize(self) -> None: def _load_model(self): pass + @abstractmethod + def convert_result_to_response(self, result: dict) -> ModelResponse: + pass + def cleanup(self) -> None: if self.model: self._cleanup_model() @@ -82,12 +86,15 @@ def _load_model(self): except Exception as e: raise RuntimeError(f'Failed to load political CPU model: {e}') + def get_resource_requirement(self): + return ResourceRequirement(num_cpus=1, memory_GB=4, num_gpus=0) + def get_batch_config(self) -> BatchProcessConfig: return BatchProcessConfig( max_batch_size=128, optimal_batch_size=64, min_batch_size=8 ) - def predict_batch(self, contents: List[str]) -> List[float]: + def predict_batch(self, contents: List[str]) -> List[dict]: if not self.model: raise RuntimeError('Model not initialized') try: @@ -95,11 +102,19 @@ def predict_batch(self, contents: List[str]) -> List[float]: results = [] for content in contents: result = update_political_by_str(content) - results.append(result['political_prob']) + results.append(result) + return results except Exception as e: raise RuntimeError(f'Prediction failed: {e}') + def convert_result_to_response(self, result: dict) -> ModelResponse: + # raise NotImplementedError + # TODO convert result to response ensure the threshold + return PoliticalResponse( + remained=result['political_prob'] > 0.5, details=result + ) + class PoliticalPredictorImpl(BasePredictor): """涉政检测预测器实现.""" @@ -132,7 +147,7 @@ def predict_batch( if batch_contents: # 批量处理 probs = self.model.predict_batch(batch_contents) - responses = [PoliticalResponse(probability=prob) for prob in probs] + responses = [self.model.convert_result_to_response(prob) for prob in probs] except Exception as e: raise ModelRuntimeException(f'Political prediction failed: {e}') @@ -145,46 +160,87 @@ class PornEnGPUModel(BaseModelResource): def _load_model(self): try: - return PornBertModel('') + from llm_web_kit.model.porn_detector import \ + BertModel as PornEnModel + + return PornEnModel() except Exception as e: - raise RuntimeError(f'Failed to load porn GPU model: {e}') + raise ModelInitException(f'Failed to init the en pron model: {e}') + + def get_resource_requirement(self): + # S2 cluster has 128 CPUs, 1TB memory, 8 GPUs + # so we can use 16 CPUs, 64GB memory, 1 GPU for this model + return ResourceRequirement(num_cpus=16, memory_GB=64, num_gpus=1) def get_batch_config(self) -> BatchProcessConfig: return BatchProcessConfig( max_batch_size=128, optimal_batch_size=64, min_batch_size=8 ) - def predict_batch(self, contents: List[str]) -> List[float]: + def predict_batch(self, contents: List[str]) -> List[dict]: if not self.model: raise RuntimeError('Model not initialized') try: # 色情模型本身支持批处理 results = self.model.predict(contents) - return [result[self.model.get_output_key('prob')] for result in results] + return [ + {'porn_prob': result[self.model.get_output_key('prob')]} + for result in results + ] except Exception as e: raise RuntimeError(f'Prediction failed: {e}') + def convert_result_to_response(self, result: dict) -> ModelResponse: + # raise NotImplementedError + # TODO convert result to response ensure the threshold + return PornResponse(remained=result['porn_prob'] < 0.5, details=result) + class PornZhGPUModel(BaseModelResource): """中文色情检测GPU模型.""" def _load_model(self): - raise NotImplementedError('TODO') + try: + from llm_web_kit.model.porn_detector import \ + XlmrModel as PronZhModel + + return PronZhModel() + except Exception as e: + raise ModelInitException(f'Failed to init the zh porn model: {e}') + + def get_resource_requirement(self): + # S2 cluster has 128 CPUs, 1TB memory, 8 GPUs + # so we can use 16 CPUs, 64GB memory, 1 GPU for this model + return ResourceRequirement(num_cpus=16, memory_GB=64, num_gpus=1) def get_batch_config(self) -> BatchProcessConfig: - raise NotImplementedError('TODO') return BatchProcessConfig( max_batch_size=128, optimal_batch_size=64, min_batch_size=8 ) - def predict_batch(self, contents: List[str]) -> List[float]: - raise NotImplementedError('TODO') + def predict_batch(self, contents: List[str]) -> List[dict]: + if not self.model: + raise RuntimeError('Model not initialized') + try: + # 色情模型本身支持批处理 + results = self.model.predict(contents) + return [ + {'porn_prob': result[self.model.get_output_key('prob')]} + for result in results + ] + except Exception as e: + raise RuntimeError(f'Prediction failed: {e}') + + def convert_result_to_response(self, result: dict) -> ModelResponse: + # raise NotImplementedError + # TODO convert result to response ensure the threshold + return PornResponse(remained=result['porn_prob'] < 0.5, details=result) class PornPredictorImpl(BasePredictor): """色情检测预测器实现.""" - def _create_models(self, language: str) -> ModelResource: + def _create_model(self, language: str) -> ModelResource: if language == 'en': return PornEnGPUModel() elif language == 'zh': @@ -208,7 +264,7 @@ def predict_batch(self, requests: List[PornRequest]) -> List[PornResponse]: if batch_contents: # 批量处理 probs = self.model.predict_batch(batch_contents) - responses = [PornResponse(probability=prob) for prob in probs] + responses = [self.model.convert_result_to_response(prob) for prob in probs] except Exception as e: raise ModelRuntimeException(f'Porn prediction failed: {e}') return responses @@ -227,6 +283,7 @@ class ModelFactory: def create_predictor(cls, model_type: ModelType, language: str) -> BasePredictor: """创建预测器实例.""" predictor_class = cls._predictor_registry.get(model_type) + print(predictor_class) if not predictor_class: raise ValueError(f'No predictor registered for type: {model_type}') return predictor_class(language=language) diff --git a/llm_web_kit/model/model_interface.py b/llm_web_kit/model/model_interface.py index dc7ce031..009dce5a 100644 --- a/llm_web_kit/model/model_interface.py +++ b/llm_web_kit/model/model_interface.py @@ -17,7 +17,7 @@ class ModelRequest: class ModelResponse: """通用模型响应基类.""" - probability: float + remained: bool details: Dict[str, Any] = None @@ -66,27 +66,53 @@ class ResourceType(Enum): DEFAULT = 'default' -class ResourceRequirement: - """资源需求配置.""" +# class ResourceRequirement: +# """资源需求配置.""" + +# def __init__(self, resource_type: ResourceType = ResourceType.DEFAULT, num_cpus: int = 1, memory: int = 4 << 30): +# self.resource_type = resource_type +# self.num_cpus = num_cpus +# self.memory = memory + +# def to_ray_resources(self) -> Dict: +# """转换为Ray资源配置.""" +# resources = { +# 'num_cpus': self.num_cpus, +# 'memory': self.memory, +# } + +# # 根据资源类型设置正确的资源配置 +# if self.resource_type == ResourceType.CPU: +# resources['resources'] = {'cpu_only': 1} +# elif self.resource_type == ResourceType.GPU: +# # 使用 num_gpus 而不是在 resources 字典中设置 +# resources['num_gpus'] = 0.25 - def __init__(self, resource_type: ResourceType = ResourceType.DEFAULT, num_cpus: int = 1, memory: int = 4 << 30): - self.resource_type = resource_type +# return resources + + +class ResourceRequirement: + def __init__(self, num_cpus: float, memory_GB: float, num_gpus: float = 0.0): self.num_cpus = num_cpus - self.memory = memory + self.memory_GB = memory_GB + self.num_gpus = num_gpus def to_ray_resources(self) -> Dict: - """转换为Ray资源配置.""" - resources = { - 'num_cpus': self.num_cpus, - 'memory': self.memory, - } - - # 根据资源类型设置正确的资源配置 - if self.resource_type == ResourceType.CPU: - resources['resources'] = {'cpu_only': 1} - elif self.resource_type == ResourceType.GPU: - # 使用 num_gpus 而不是在 resources 字典中设置 - resources['num_gpus'] = 0.25 + if self.num_gpus > 0: + resources = { + 'num_cpus': self.num_cpus, + 'memory': self.memory_GB * 2**30, + 'num_gpus': self.num_gpus, + } + else: + # prefer to use CPU on CPU only node + # we set dummy resource "cpu_only" on CPU only node + # so set resources.cpu_only = 1 to ensure the task can be scheduled on CPU only node + resources = { + 'num_cpus': self.num_cpus, + 'memory': self.memory_GB * 2**30, + 'resources': {'cpu_only': 1}, + } return resources @@ -105,7 +131,7 @@ def get_batch_config(self) -> BatchProcessConfig: pass @abstractmethod - def predict_batch(self, contents: List[str]) -> List[float]: + def predict_batch(self, contents: List[str]) -> List[dict]: """批量预测.""" pass @@ -137,7 +163,9 @@ def predict_batch(self, requests: List[ModelRequest]) -> List[ModelResponse]: class PoliticalPredictor(ModelPredictor): """涉政预测器接口.""" - def predict_batch(self, requests: List[PoliticalRequest]) -> List[PoliticalResponse]: + def predict_batch( + self, requests: List[PoliticalRequest] + ) -> List[PoliticalResponse]: pass diff --git a/llm_web_kit/model/porn_detector.py b/llm_web_kit/model/porn_detector.py index bb57962e..8475afb3 100644 --- a/llm_web_kit/model/porn_detector.py +++ b/llm_web_kit/model/porn_detector.py @@ -45,41 +45,41 @@ def __init__(self, model_path: str = None) -> None: os.path.join(model_path, 'porn_classifier/classifier_hf') ) self.tokenizer_config = { - "padding": True, - "truncation": self.remain_tail <= 0, - "max_length": self.max_tokens if self.remain_tail <= 0 else None, - "return_tensors": "pt" if self.remain_tail <= 0 else None, + 'padding': True, + 'truncation': self.remain_tail <= 0, + 'max_length': self.max_tokens if self.remain_tail <= 0 else None, + 'return_tensors': 'pt' if self.remain_tail <= 0 else None, } - self.output_prefix = str(model_config.get("output_prefix", "")).rstrip("_") - self.output_postfix = str(model_config.get("output_postfix", "")).lstrip("_") + self.output_prefix = str(model_config.get('output_prefix', '')).rstrip('_') + self.output_postfix = str(model_config.get('output_postfix', '')).lstrip('_') - self.model_name = str(model_config.get("model_name", "porn-23w44")) + self.model_name = str(model_config.get('model_name', 'porn-23w44')) def auto_download(self) -> str: """Default download the 23w44.zip model.""" - resource_name = "porn-23w44" - resource_config = load_config()["resources"] + resource_name = 'porn-23w44' + resource_config = load_config()['resources'] porn_23w44_config: Dict = resource_config[resource_name] - porn_23w44_s3 = porn_23w44_config["download_path"] - porn_23w44_md5 = porn_23w44_config.get("md5", "") + porn_23w44_s3 = porn_23w44_config['download_path'] + porn_23w44_md5 = porn_23w44_config.get('md5', '') # get the zip path calculated by the s3 path - zip_path = os.path.join(CACHE_DIR, f"{resource_name}.zip") + zip_path = os.path.join(CACHE_DIR, f'{resource_name}.zip') # the unzip path is calculated by the zip path unzip_path = get_unzip_dir(zip_path) - logger.info(f"try to make unzip_path: {unzip_path}") + logger.info(f'try to make unzip_path: {unzip_path}') # if the unzip path does not exist, download the zip file and unzip it if not os.path.exists(unzip_path): - logger.info(f"unzip_path: {unzip_path} does not exist") - logger.info(f"try to unzip from zip_path: {zip_path}") + logger.info(f'unzip_path: {unzip_path} does not exist') + logger.info(f'try to unzip from zip_path: {zip_path}') if not os.path.exists(zip_path): - logger.info(f"zip_path: {zip_path} does not exist") - logger.info(f"downloading {porn_23w44_s3}") + logger.info(f'zip_path: {zip_path} does not exist') + logger.info(f'downloading {porn_23w44_s3}') zip_path = download_auto_file(porn_23w44_s3, zip_path, porn_23w44_md5) - logger.info(f"unzipping {zip_path}") + logger.info(f'unzipping {zip_path}') unzip_path = unzip_local_file(zip_path, unzip_path) else: - logger.info(f"unzip_path: {unzip_path} exist") + logger.info(f'unzip_path: {unzip_path} exist') return unzip_path def pre_process(self, samples: Union[List[str], str]) -> Dict: @@ -91,7 +91,7 @@ def pre_process(self, samples: Union[List[str], str]) -> Dict: processed_inputs = [] # 对每个输入进行处理 - for tokens_id in inputs["input_ids"]: + for tokens_id in inputs['input_ids']: # 通过sep_token_id找到tokens的长度 length = tokens_id.index(self.tokenizer.sep_token_id) + 1 # 如果tokens的长度小于等于max_tokens,则直接在尾部补0,不需要截断 @@ -128,17 +128,17 @@ def pre_process(self, samples: Union[List[str], str]) -> Dict: ), } inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()} - return {"inputs": inputs} + return {'inputs': inputs} def get_output_key(self, f: str): prefix = self.output_prefix if self.output_prefix else self.model_name - postfix = f"_{self.output_postfix}" if self.output_postfix else "" - return f"{prefix}_{f}{postfix}" + postfix = f'_{self.output_postfix}' if self.output_postfix else '' + return f'{prefix}_{f}{postfix}' def predict(self, texts: Union[List[str], str]): inputs_dict = self.pre_process(texts) with torch.no_grad(): - logits = self.model(**inputs_dict["inputs"]).logits + logits = self.model(**inputs_dict['inputs']).logits if self.use_sigmoid: probs = torch.sigmoid(logits) @@ -150,7 +150,7 @@ def predict(self, texts: Union[List[str], str]): outputs = [] for prob in pos_prob: prob = round(float(prob), 6) - output = {self.get_output_key("prob"): prob} + output = {self.get_output_key('prob'): prob} outputs.append(output) return outputs @@ -160,69 +160,72 @@ class XlmrModel(BertModel): def __init__(self, model_path: str = None) -> None: if not model_path: model_path = self.auto_download() - self.model = AutoModelForSequenceClassification.from_pretrained( - os.path.join(model_path, "porn_classifier/classifier_hf") + + transformers_module = import_transformer() + + self.model = transformers_module.AutoModelForSequenceClassification.from_pretrained( + os.path.join(model_path, 'porn_classifier/classifier_hf') ) with open( - os.path.join(model_path, "porn_classifier/extra_parameters.json") + os.path.join(model_path, 'porn_classifier/extra_parameters.json') ) as reader: model_config = json.load(reader) - self.clip = bool(model_config.get("clip", False)) - self.max_tokens = int(model_config.get("max_tokens", 300)) + self.clip = bool(model_config.get('clip', False)) + self.max_tokens = int(model_config.get('max_tokens', 300)) self.remain_tail = min( - self.max_tokens - 1, int(model_config.get("remain_tail", -1)) + self.max_tokens - 1, int(model_config.get('remain_tail', -1)) ) - self.device = model_config.get("device", "cpu") + self.device = model_config.get('device', 'cpu') self.model.eval() self.model.to(self.device, dtype=torch.float16) - self.tokenizer = AutoTokenizer.from_pretrained( - os.path.join(model_path, "porn_classifier/classifier_hf") + self.tokenizer = transformers_module.AutoTokenizer.from_pretrained( + os.path.join(model_path, 'porn_classifier/classifier_hf') ) self.tokenizer_config = { - "padding": True, - "truncation": self.remain_tail <= 0, - "max_length": self.max_tokens if self.remain_tail <= 0 else None, - "return_tensors": "pt" if self.remain_tail <= 0 else None, + 'padding': True, + 'truncation': self.remain_tail <= 0, + 'max_length': self.max_tokens if self.remain_tail <= 0 else None, + 'return_tensors': 'pt' if self.remain_tail <= 0 else None, } - self.output_prefix = str(model_config.get("output_prefix", "")).rstrip("_") - self.output_postfix = str(model_config.get("output_postfix", "")).lstrip("_") + self.output_prefix = str(model_config.get('output_prefix', '')).rstrip('_') + self.output_postfix = str(model_config.get('output_postfix', '')).lstrip('_') - self.model_name = str(model_config.get("model_name", "porn-24m5")) + self.model_name = str(model_config.get('model_name', 'porn-24m5')) def auto_download(self) -> str: """Default download the 23w44.zip model.""" - resource_name = "porn-24m5" - resource_config = load_config()["resources"] + resource_name = 'porn-24m5' + resource_config = load_config()['resources'] porn_24m5_config: Dict = resource_config[resource_name] - porn_24m5_s3 = porn_24m5_config["download_path"] - porn_24m5_md5 = porn_24m5_config.get("md5", "") + porn_24m5_s3 = porn_24m5_config['download_path'] + porn_24m5_md5 = porn_24m5_config.get('md5', '') # get the zip path calculated by the s3 path - zip_path = os.path.join(CACHE_DIR, f"{resource_name}.zip") + zip_path = os.path.join(CACHE_DIR, f'{resource_name}.zip') # the unzip path is calculated by the zip path unzip_path = get_unzip_dir(zip_path) - logger.info(f"try to make unzip_path: {unzip_path}") + logger.info(f'try to make unzip_path: {unzip_path}') # if the unzip path does not exist, download the zip file and unzip it if not os.path.exists(unzip_path): - logger.info(f"unzip_path: {unzip_path} does not exist") - logger.info(f"try to unzip from zip_path: {zip_path}") + logger.info(f'unzip_path: {unzip_path} does not exist') + logger.info(f'try to unzip from zip_path: {zip_path}') if not os.path.exists(zip_path): - logger.info(f"zip_path: {zip_path} does not exist") - logger.info(f"downloading {porn_24m5_s3}") + logger.info(f'zip_path: {zip_path} does not exist') + logger.info(f'downloading {porn_24m5_s3}') zip_path = download_auto_file(porn_24m5_s3, zip_path, porn_24m5_md5) - logger.info(f"unzipping {zip_path}") + logger.info(f'unzipping {zip_path}') unzip_path = unzip_local_file(zip_path, unzip_path) else: - logger.info(f"unzip_path: {unzip_path} exist") + logger.info(f'unzip_path: {unzip_path} exist') return unzip_path def predict(self, texts: Union[List[str], str]): inputs_dict = self.pre_process(texts) with torch.no_grad(): - logits = self.model(**inputs_dict["inputs"]).logits + logits = self.model(**inputs_dict['inputs']).logits if self.clip: probs = logits.detach().cpu().numpy().clip(min=0, max=1) @@ -232,7 +235,7 @@ def predict(self, texts: Union[List[str], str]): outputs = [] for prob in probs: prob = round(float(prob[0]), 6) - output = {self.get_output_key("prob"): prob} + output = {self.get_output_key('prob'): prob} outputs.append(output) return outputs From b6ff7f23f87c4f5f18a01801e7d21e6525b1a952 Mon Sep 17 00:00:00 2001 From: yujing Date: Thu, 13 Mar 2025 17:57:54 +0800 Subject: [PATCH 20/32] add top readme of models --- docs/llm_web_kit/model/model_interface.md | 0 docs/llm_web_kit/model/readme.md | 16 ++++++++++++++++ .../model/rule_based_safety_module.md | 0 llm_web_kit/model/model_impl.py | 6 +++--- llm_web_kit/model/model_interface.py | 2 +- 5 files changed, 20 insertions(+), 4 deletions(-) create mode 100644 docs/llm_web_kit/model/model_interface.md create mode 100644 docs/llm_web_kit/model/readme.md create mode 100644 docs/llm_web_kit/model/rule_based_safety_module.md diff --git a/docs/llm_web_kit/model/model_interface.md b/docs/llm_web_kit/model/model_interface.md new file mode 100644 index 00000000..e69de29b diff --git a/docs/llm_web_kit/model/readme.md b/docs/llm_web_kit/model/readme.md new file mode 100644 index 00000000..83b1f7cc --- /dev/null +++ b/docs/llm_web_kit/model/readme.md @@ -0,0 +1,16 @@ +# 面向用户的接口 + +## html分类 +html_simplify_classify.md + +## 语言检测 +lang_id.md + +## 清洗模型 +clean_module.md + +## 安全规则 +rule_based_safety_module.md + +## 安全模型 +model_interface.md \ No newline at end of file diff --git a/docs/llm_web_kit/model/rule_based_safety_module.md b/docs/llm_web_kit/model/rule_based_safety_module.md new file mode 100644 index 00000000..e69de29b diff --git a/llm_web_kit/model/model_impl.py b/llm_web_kit/model/model_impl.py index 0d008a32..00c410f7 100644 --- a/llm_web_kit/model/model_impl.py +++ b/llm_web_kit/model/model_impl.py @@ -209,9 +209,9 @@ def _load_model(self): raise ModelInitException(f'Failed to init the zh porn model: {e}') def get_resource_requirement(self): - # S2 cluster has 128 CPUs, 1TB memory, 8 GPUs - # so we can use 16 CPUs, 64GB memory, 1 GPU for this model - return ResourceRequirement(num_cpus=16, memory_GB=64, num_gpus=1) + # S2 cluster has at least 96 CPUs, 1TB memory, 8 GPUs + # so we can use 12 CPUs, 64GB memory, 1 GPU for this model + return ResourceRequirement(num_cpus=12, memory_GB=64, num_gpus=1) def get_batch_config(self) -> BatchProcessConfig: return BatchProcessConfig( diff --git a/llm_web_kit/model/model_interface.py b/llm_web_kit/model/model_interface.py index 009dce5a..48b92fa7 100644 --- a/llm_web_kit/model/model_interface.py +++ b/llm_web_kit/model/model_interface.py @@ -17,7 +17,7 @@ class ModelRequest: class ModelResponse: """通用模型响应基类.""" - remained: bool + is_remained: bool details: Dict[str, Any] = None From 7ba7a08e93a3880d517d4edc27b1a90a36bf6d5a Mon Sep 17 00:00:00 2001 From: yujing Date: Tue, 18 Mar 2025 15:01:10 +0800 Subject: [PATCH 21/32] backup tests --- llm_web_kit/model/model_impl.py | 24 +- llm_web_kit/model/model_interface.py | 25 -- tests/llm_web_kit/model/test_model_impl.py | 312 ++++++++++++++++++ .../llm_web_kit/model/test_model_interface.py | 122 +++++++ 4 files changed, 446 insertions(+), 37 deletions(-) create mode 100644 tests/llm_web_kit/model/test_model_impl.py create mode 100644 tests/llm_web_kit/model/test_model_interface.py diff --git a/llm_web_kit/model/model_impl.py b/llm_web_kit/model/model_impl.py index 00c410f7..d2d44112 100644 --- a/llm_web_kit/model/model_impl.py +++ b/llm_web_kit/model/model_impl.py @@ -91,7 +91,7 @@ def get_resource_requirement(self): def get_batch_config(self) -> BatchProcessConfig: return BatchProcessConfig( - max_batch_size=128, optimal_batch_size=64, min_batch_size=8 + max_batch_size=1000, optimal_batch_size=512, min_batch_size=8 ) def predict_batch(self, contents: List[str]) -> List[dict]: @@ -112,7 +112,7 @@ def convert_result_to_response(self, result: dict) -> ModelResponse: # raise NotImplementedError # TODO convert result to response ensure the threshold return PoliticalResponse( - remained=result['political_prob'] > 0.5, details=result + is_remained=result['political_prob'] > 0.99, details=result ) @@ -165,16 +165,16 @@ def _load_model(self): return PornEnModel() except Exception as e: - raise ModelInitException(f'Failed to init the en pron model: {e}') + raise ModelInitException(f'Failed to init the en porn model: {e}') def get_resource_requirement(self): - # S2 cluster has 128 CPUs, 1TB memory, 8 GPUs - # so we can use 16 CPUs, 64GB memory, 1 GPU for this model - return ResourceRequirement(num_cpus=16, memory_GB=64, num_gpus=1) + # S2 cluster has 96 CPUs, 1TB memory, 8 GPUs + # so we can use 12 CPUs, 64GB memory, 1 GPU for this model + return ResourceRequirement(num_cpus=12, memory_GB=64, num_gpus=1) def get_batch_config(self) -> BatchProcessConfig: return BatchProcessConfig( - max_batch_size=128, optimal_batch_size=64, min_batch_size=8 + max_batch_size=1000, optimal_batch_size=512, min_batch_size=8 ) def predict_batch(self, contents: List[str]) -> List[dict]: @@ -193,7 +193,7 @@ def predict_batch(self, contents: List[str]) -> List[dict]: def convert_result_to_response(self, result: dict) -> ModelResponse: # raise NotImplementedError # TODO convert result to response ensure the threshold - return PornResponse(remained=result['porn_prob'] < 0.5, details=result) + return PornResponse(is_remained=result['porn_prob'] < 0.2, details=result) class PornZhGPUModel(BaseModelResource): @@ -202,9 +202,9 @@ class PornZhGPUModel(BaseModelResource): def _load_model(self): try: from llm_web_kit.model.porn_detector import \ - XlmrModel as PronZhModel + XlmrModel as PornZhModel - return PronZhModel() + return PornZhModel() except Exception as e: raise ModelInitException(f'Failed to init the zh porn model: {e}') @@ -215,7 +215,7 @@ def get_resource_requirement(self): def get_batch_config(self) -> BatchProcessConfig: return BatchProcessConfig( - max_batch_size=128, optimal_batch_size=64, min_batch_size=8 + max_batch_size=300, optimal_batch_size=256, min_batch_size=8 ) def predict_batch(self, contents: List[str]) -> List[dict]: @@ -234,7 +234,7 @@ def predict_batch(self, contents: List[str]) -> List[dict]: def convert_result_to_response(self, result: dict) -> ModelResponse: # raise NotImplementedError # TODO convert result to response ensure the threshold - return PornResponse(remained=result['porn_prob'] < 0.5, details=result) + return PornResponse(is_remained=result['porn_prob'] > 0.95, details=result) class PornPredictorImpl(BasePredictor): diff --git a/llm_web_kit/model/model_interface.py b/llm_web_kit/model/model_interface.py index 48b92fa7..3136cb99 100644 --- a/llm_web_kit/model/model_interface.py +++ b/llm_web_kit/model/model_interface.py @@ -66,31 +66,6 @@ class ResourceType(Enum): DEFAULT = 'default' -# class ResourceRequirement: -# """资源需求配置.""" - -# def __init__(self, resource_type: ResourceType = ResourceType.DEFAULT, num_cpus: int = 1, memory: int = 4 << 30): -# self.resource_type = resource_type -# self.num_cpus = num_cpus -# self.memory = memory - -# def to_ray_resources(self) -> Dict: -# """转换为Ray资源配置.""" -# resources = { -# 'num_cpus': self.num_cpus, -# 'memory': self.memory, -# } - -# # 根据资源类型设置正确的资源配置 -# if self.resource_type == ResourceType.CPU: -# resources['resources'] = {'cpu_only': 1} -# elif self.resource_type == ResourceType.GPU: -# # 使用 num_gpus 而不是在 resources 字典中设置 -# resources['num_gpus'] = 0.25 - -# return resources - - class ResourceRequirement: def __init__(self, num_cpus: float, memory_GB: float, num_gpus: float = 0.0): self.num_cpus = num_cpus diff --git a/tests/llm_web_kit/model/test_model_impl.py b/tests/llm_web_kit/model/test_model_impl.py new file mode 100644 index 00000000..3e656068 --- /dev/null +++ b/tests/llm_web_kit/model/test_model_impl.py @@ -0,0 +1,312 @@ +"""Test cases for model_impl.py.""" + +import unittest +from unittest import TestCase +from unittest.mock import MagicMock, patch + +from llm_web_kit.exception.exception import ModelRuntimeException +from llm_web_kit.model.model_impl import (ModelFactory, ModelType, + PoliticalCPUModel, + PoliticalPredictorImpl, + PornEnGPUModel, PornPredictorImpl, + PornZhGPUModel) +from llm_web_kit.model.model_interface import PornRequest, PornResponse + + +class TestPoliticalCPUModel(TestCase): + """Test cases for PoliticalCPUModel.""" + + @patch.object(PoliticalCPUModel, '_load_model') + def test_load_model(self, mock_load_model): + """Test model loading.""" + mock_load_model.return_value = MagicMock() + model = PoliticalCPUModel() + model._load_model() + assert mock_load_model.call_count == 1 + + @patch.object(PoliticalCPUModel, '_load_model') + def test_get_resource_requirement(self, mock_load_model): + """Test resource requirements.""" + mock_load_model.return_value = MagicMock() + model = PoliticalCPUModel() + resource_requirement = model.get_resource_requirement() + assert resource_requirement.num_cpus == 1 + assert resource_requirement.memory_GB == 4 + assert resource_requirement.num_gpus == 0 + + @patch.object(PoliticalCPUModel, '_load_model') + def test_get_batch_config(self, mock_load_model): + """Test batch configuration.""" + mock_load_model.return_value = MagicMock() + model = PoliticalCPUModel() + batch_config = model.get_batch_config() + assert batch_config.max_batch_size == 1000 + assert batch_config.optimal_batch_size == 512 + assert batch_config.min_batch_size == 8 + + @patch.object(PoliticalCPUModel, '_load_model') + @patch('llm_web_kit.model.model_impl.update_political_by_str') + def test_predict_batch(self, mock_update_political_by_str, mock_load_model): + """Test batch prediction.""" + mock_model = MagicMock() + mock_load_model.return_value = mock_model + mock_update_political_by_str.return_value = {'political_prob': 0.96} + + model = PoliticalCPUModel() + model.model = mock_model + + results = model.predict_batch(['test1', 'test2']) + assert len(results) == 2 + assert results[0]['political_prob'] == 0.96 + assert results[1]['political_prob'] == 0.96 + assert mock_update_political_by_str.call_count == 2 + + @patch.object(PoliticalCPUModel, '_load_model') + def test_convert_result_to_response(self, mock_load_model): + """Test result conversion to response.""" + mock_load_model.return_value = MagicMock() + model = PoliticalCPUModel() + + # Test case where political_prob > 0.99 (should be flagged) + result = {'political_prob': 0.995} + response = model.convert_result_to_response(result) + assert response.is_remained + assert response.details == result + + # Test case where political_prob <= 0.99 (should not be flagged) + result = {'political_prob': 0.985} + response = model.convert_result_to_response(result) + assert not response.is_remained + assert response.details == result + + +class TestPornEnGPUModel(TestCase): + """Test cases for PornEnGPUModel.""" + + from llm_web_kit.model.porn_detector import BertModel as PornEnModel + + @patch.object(PornEnModel, '__init__') + def test_load_model(self, mock_init): + """Test model loading.""" + mock_init.return_value = None + model = PornEnGPUModel() + model._load_model() + assert mock_init.call_count == 1 + + @patch.object(PornEnGPUModel, '_load_model') + def test_get_resource_requirement(self, mock_load_model): + """Test resource requirements.""" + mock_load_model.return_value = MagicMock() + model = PornEnGPUModel() + resource_requirement = model.get_resource_requirement() + assert resource_requirement.num_cpus == 12 + assert resource_requirement.memory_GB == 64 + assert resource_requirement.num_gpus == 1 + + @patch.object(PornEnGPUModel, '_load_model') + def test_get_batch_config(self, mock_load_model): + """Test batch configuration.""" + mock_load_model.return_value = MagicMock() + model = PornEnGPUModel() + batch_config = model.get_batch_config() + assert batch_config.max_batch_size == 1000 + assert batch_config.optimal_batch_size == 512 + assert batch_config.min_batch_size == 8 + + @patch.object(PornEnGPUModel, '_load_model') + def test_predict_batch(self, mock_load_model): + """Test batch prediction.""" + mock_model = MagicMock() + mock_model.get_output_key.return_value = 'prob' + mock_model.predict.return_value = [{'prob': 0.96}, {'prob': 0.94}] + mock_load_model.return_value = mock_model + + model = PornEnGPUModel() + model.model = mock_model + + results = model.predict_batch(['test1', 'test2']) + assert len(results) == 2 + assert results[0]['porn_prob'] == 0.96 + assert results[1]['porn_prob'] == 0.94 + + # Test model not initialized + model.model = None + with self.assertRaises(RuntimeError): + model.predict_batch(['test']) + + @patch.object(PornEnGPUModel, '_load_model') + def test_convert_result_to_response(self, mock_load_model): + """Test result conversion to response.""" + mock_load_model.return_value = MagicMock() + model = PornEnGPUModel() + + # Test with high probability (should be remained) + response = model.convert_result_to_response({'porn_prob': 0.21}) + assert isinstance(response, PornResponse) + assert not response.is_remained + assert response.details == {'porn_prob': 0.21} + + # Test with low probability (should not be remained) + response = model.convert_result_to_response({'porn_prob': 0.19}) + assert isinstance(response, PornResponse) + assert response.is_remained + assert response.details == {'porn_prob': 0.19} + + +class TestPornZhGPUModel(TestCase): + """Test cases for PornZhGPUModel.""" + + from llm_web_kit.model.porn_detector import XlmrModel as PornZhModel + + @patch.object(PornZhModel, '__init__') + def test_load_model(self, mock_init): + """Test model loading.""" + mock_init.return_value = None + model = PornZhGPUModel() + model._load_model() + assert mock_init.call_count == 1 + + @patch.object(PornZhGPUModel, '_load_model') + def test_get_resource_requirement(self, mock_load_model): + """Test resource requirements.""" + mock_load_model.return_value = MagicMock() + model = PornZhGPUModel() + resource_requirement = model.get_resource_requirement() + assert resource_requirement.num_cpus == 12 + assert resource_requirement.memory_GB == 64 + assert resource_requirement.num_gpus == 1 + + @patch.object(PornZhGPUModel, '_load_model') + def test_get_batch_config(self, mock_load_model): + """Test batch configuration.""" + mock_load_model.return_value = MagicMock() + model = PornZhGPUModel() + batch_config = model.get_batch_config() + assert batch_config.max_batch_size == 300 + assert batch_config.optimal_batch_size == 256 + assert batch_config.min_batch_size == 8 + + @patch.object(PornZhGPUModel, '_load_model') + def test_predict_batch(self, mock_load_model): + """Test batch prediction.""" + mock_model = MagicMock() + mock_model.get_output_key.return_value = 'prob' + mock_model.predict.return_value = [{'prob': 0.96}, {'prob': 0.94}] + mock_load_model.return_value = mock_model + + model = PornZhGPUModel() + model.model = mock_model + + results = model.predict_batch(['test1', 'test2']) + assert len(results) == 2 + assert results[0]['porn_prob'] == 0.96 + assert results[1]['porn_prob'] == 0.94 + + # Test model not initialized + model.model = None + with self.assertRaises(RuntimeError): + model.predict_batch(['test']) + + @patch.object(PornZhGPUModel, '_load_model') + def test_convert_result_to_response(self, mock_load_model): + """Test result conversion to response.""" + mock_load_model.return_value = MagicMock() + model = PornZhGPUModel() + + # Test with high probability (should be remained) + response = model.convert_result_to_response({'porn_prob': 0.96}) + assert isinstance(response, PornResponse) + assert response.is_remained + assert response.details == {'porn_prob': 0.96} + + # Test with low probability (should not be remained) + response = model.convert_result_to_response({'porn_prob': 0.94}) + assert isinstance(response, PornResponse) + assert not response.is_remained + assert response.details == {'porn_prob': 0.94} + + +class TestPornPredictorImpl(TestCase): + """Test cases for PornPredictorImpl.""" + + @patch.object(PornEnGPUModel, '_load_model') + @patch.object(PornZhGPUModel, '_load_model') + @patch.object(PornZhGPUModel, 'predict_batch') + @patch.object(PornEnGPUModel, 'predict_batch') + def test_predict_batch(self, mock_predict_batch_en, mock_predict_batch_zh, + mock_load_model_en, mock_load_model_zh): + """Test batch prediction.""" + mock_load_model_en.return_value = MagicMock() + mock_load_model_zh.return_value = MagicMock() + mock_predict_batch_en.return_value = [{'porn_prob': 0.19}, {'porn_prob': 0.3}] + mock_predict_batch_zh.return_value = [{'porn_prob': 0.19}, {'porn_prob': 0.3}] + + predictor = PornPredictorImpl('en') + assert predictor.language == 'en' + with self.assertRaises(ModelRuntimeException): + predictor.predict_batch([ + PornRequest(content='Hello, world!', language='en'), + PornRequest(content='你好', language='zh') + ]) + assert mock_predict_batch_en.call_count == 1 + + results = predictor.predict_batch([ + PornRequest(content='Hello, world!', language='en'), + PornRequest(content='nihao', language='en') + ]) + assert results[0].is_remained + assert not results[1].is_remained + + @patch.object(PornEnGPUModel, '_load_model') + @patch.object(PornZhGPUModel, '_load_model') + def test_create_model(self, mock_load_model_en, mock_load_model_zh): + """Test model creation.""" + mock_load_model_en.return_value = MagicMock() + mock_load_model_zh.return_value = MagicMock() + predictor = PornPredictorImpl('en') + assert predictor.language == 'en' + assert predictor.model is not None + + +def test_model_factory(): + """Test ModelFactory creation.""" + factory = ModelFactory() + assert factory is not None + + +class TestModelFactory(TestCase): + """Test cases for ModelFactory.""" + + @patch.object(PoliticalPredictorImpl, '_create_model') + @patch.object(PoliticalCPUModel, '_load_model') + def test_create_predictor(self, mock_load_model, mock_create_model): + """Test ModelFactory.create_predictor method.""" + mock_load_model.return_value = MagicMock() + mock_create_model.return_value = MagicMock() + predictor = ModelFactory.create_predictor(ModelType.POLITICAL, 'en') + assert isinstance(predictor, PoliticalPredictorImpl) + assert mock_create_model.call_count == 1 + + @patch.object(PornPredictorImpl, '_create_model') + @patch.object(PornEnGPUModel, '_load_model') + def test_create_predictor_porn(self, mock_load_model, mock_create_model): + """Test ModelFactory.create_predictor method for porn model.""" + mock_load_model.return_value = MagicMock() + mock_create_model.return_value = MagicMock() + predictor = ModelFactory.create_predictor(ModelType.PORN, 'en') + assert isinstance(predictor, PornPredictorImpl) + assert mock_create_model.call_count == 1 + + @patch.object(PornPredictorImpl, '_create_model') + @patch.object(PornZhGPUModel, '_load_model') + def test_create_predictor_porn_zh(self, mock_load_model, mock_create_model): + """Test ModelFactory.create_predictor method for porn model.""" + mock_load_model.return_value = MagicMock() + mock_create_model.return_value = MagicMock() + predictor = ModelFactory.create_predictor(ModelType.PORN, 'zh') + assert isinstance(predictor, PornPredictorImpl) + assert mock_create_model.call_count == 1 + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/llm_web_kit/model/test_model_interface.py b/tests/llm_web_kit/model/test_model_interface.py new file mode 100644 index 00000000..d16c2042 --- /dev/null +++ b/tests/llm_web_kit/model/test_model_interface.py @@ -0,0 +1,122 @@ +"""Test cases for model_interface.py.""" + +import pytest + +from llm_web_kit.model.model_interface import (BatchProcessConfig, + ModelPredictor, ModelRequest, + ModelResource, ModelResponse, + PoliticalPredictor, + PoliticalRequest, + PoliticalResponse, + PornPredictor, PornRequest, + PornResponse, + ResourceRequirement, + ResourceType) + + +def test_model_request() -> None: + """Test ModelRequest initialization and attributes.""" + request = ModelRequest(content='Hello, world!', language='en') + assert request.content == 'Hello, world!' + assert request.language == 'en' + + +def test_model_response() -> None: + """Test ModelResponse initialization and attributes.""" + response = ModelResponse(is_remained=True, details={'score': 0.95}) + assert response.is_remained + assert response.details['score'] == 0.95 + + +def test_political_request() -> None: + """Test PoliticalRequest initialization and attributes.""" + request = PoliticalRequest(content='Hello, world!', language='en') + assert request.content == 'Hello, world!' + assert request.language == 'en' + + +def test_political_response() -> None: + """Test PoliticalResponse initialization and attributes.""" + response = PoliticalResponse(is_remained=True, details={'score': 0.95}) + assert response.is_remained + assert response.details['score'] == 0.95 + + +def test_porn_request() -> None: + """Test PornRequest initialization and attributes.""" + request = PornRequest(content='Hello, world!', language='en') + assert request.content == 'Hello, world!' + assert request.language == 'en' + + +def test_porn_response() -> None: + """Test PornResponse initialization and attributes.""" + response = PornResponse(is_remained=True, details={'score': 0.95}) + assert response.is_remained + assert response.details['score'] == 0.95 + + +def test_batch_process_config() -> None: + """Test BatchProcessConfig initialization and attributes.""" + config = BatchProcessConfig( + max_batch_size=100, + optimal_batch_size=50, + min_batch_size=10 + ) + assert config.max_batch_size == 100 + assert config.optimal_batch_size == 50 + assert config.min_batch_size == 10 + + +def test_resource_type() -> None: + """Test ResourceType enum values.""" + assert ResourceType.CPU + assert ResourceType.GPU + assert ResourceType.DEFAULT + + +def test_resource_requirement() -> None: + """Test ResourceRequirement initialization and conversion to ray + resources.""" + requirement = ResourceRequirement(num_cpus=1, memory_GB=4, num_gpus=0.25) + assert requirement.num_cpus == 1 + assert requirement.memory_GB == 4 + assert requirement.num_gpus == 0.25 + assert requirement.to_ray_resources() == { + 'num_cpus': 1, + 'memory': 4 * 2**30, + 'num_gpus': 0.25 + } + + cpu_only_requirement = ResourceRequirement(num_cpus=1, memory_GB=4) + assert cpu_only_requirement.to_ray_resources() == { + 'num_cpus': 1, + 'memory': 4 * 2**30, + 'resources': {'cpu_only': 1} + } + + +def test_model_resource() -> None: + """Test that ModelResource cannot be instantiated directly.""" + with pytest.raises(TypeError): + ModelResource() + + +def test_model_predictor() -> None: + """Test that ModelPredictor cannot be instantiated directly.""" + with pytest.raises(TypeError): + ModelPredictor() + + +def test_political_predictor() -> None: + """Test PoliticalPredictor implementation.""" + with pytest.raises(TypeError): + predictor = PoliticalPredictor() + assert isinstance(predictor, ModelPredictor) + + +def test_porn_predictor() -> None: + """Test PornPredictor implementation.""" + with pytest.raises(TypeError): + predictor = PornPredictor() + assert isinstance(predictor, ModelPredictor) From e906ba59734a8348fcf753a0b5267b8036bd0e13 Mon Sep 17 00:00:00 2001 From: yujing Date: Tue, 18 Mar 2025 15:04:49 +0800 Subject: [PATCH 22/32] backup tests --- .../model/model_based_safety_module.py | 189 ------------------ .../model/test_model_based_safe_model.py | 0 2 files changed, 189 deletions(-) delete mode 100644 llm_web_kit/model/model_based_safety_module.py delete mode 100644 tests/llm_web_kit/model/test_model_based_safe_model.py diff --git a/llm_web_kit/model/model_based_safety_module.py b/llm_web_kit/model/model_based_safety_module.py deleted file mode 100644 index f48a318a..00000000 --- a/llm_web_kit/model/model_based_safety_module.py +++ /dev/null @@ -1,189 +0,0 @@ -from typing import List, Tuple, Any, Type, TypeVar -from llm_web_kit.model.policical import PoliticalDetector, decide_political_by_prob -from llm_web_kit.model.porn_detector import BertModel as EnPornBertModel -from llm_web_kit.model.porn_detector import XlmrModel as ZhPornXlmrModel -from llm_web_kit.exception.exception import ModelInputException - -I = TypeVar("I") # input type -B = TypeVar("B") # batch type - - -def check_type(arg_name: str, arg_value: Any, arg_type: Type): - """check the type of the argument and raise TypeError if the type is not - matched.""" - if not isinstance(arg_value, arg_type): - # TODO change TypeError to custom exception - raise TypeError( - "The type of {} should be {}, but got {}".format( - arg_name, arg_type, type(arg_value) - ) - ) - - -class ModelBasedSafetyDataPack: - """The data pack for the model based safety module.""" - - def __init__(self, content_str: str, language: str, language_details: str): - - self._dict = {} - # the content of the dataset - check_type("content_str", content_str, str) - self._dict["content_str"] = content_str - - # the language of the content - check_type("language", language, str) - self._dict["language"] = language - - # the details of the language - check_type("language_details", language_details, str) - self._dict["language_details"] = language_details - - # the flag of the processed data should be remained or not - self._dict["model_based_safety_remained"] = True - - # the details of the model based safety process - self._dict["model_based_safety_infos"] = {} - - @classmethod - def from_dict(cls, data: dict): - new_data_pack = cls( - content_str=data["content_str"], - language=data["language"], - language_details=data["language_details"], - ) - new_data_pack._dict.update(data) - return new_data_pack - - def as_dict(self) -> dict: - return self._dict - - def set_process_result( - self, model_based_safety_remained: bool, model_based_safety_infos: dict - ) -> None: - """set the process result of the model based safety module.""" - check_type("model_based_safety_remained", model_based_safety_remained, bool) - check_type("model_based_safety_infos", model_based_safety_infos, dict) - if model_based_safety_remained is False: - self._dict["model_based_safety_remained"] = False - self._dict["model_based_safety_infos"].update(model_based_safety_infos) - - def get_output(self) -> dict: - """get the output of the data pack.""" - return { - "model_based_safety_remained": self._dict["model_based_safety_remained"], - "model_based_safety_infos": self._dict["model_based_safety_infos"], - } - - -class ContentStrBatchModel: - def __init__(self, model_config: dict): - self.model_config = model_config - - def check_support(self, data_pack: ModelBasedSafetyDataPack) -> bool: - raise NotImplementedError - - def preprocess(self, data_pack: ModelBasedSafetyDataPack) -> Tuple[dict, I]: - if not self.check_support(data_pack): - # use class name - model_name = self.__class__.__name__ - raise ModelInputException( - f"The data pack is not supported for {model_name}." - ) - return data_pack.as_dict(), data_pack._dict["content_str"] - - def collate_fn(self, lst: List[Tuple[dict, I]]) -> Tuple[List[dict], B]: - infos, batch = zip(*lst) - return list(infos), list(batch) - - def inference(self, batch: B) -> List[dict]: - """(batch: B) -> results""" - raise NotImplementedError() - - def postprocess(self, info: dict, result: dict) -> ModelBasedSafetyDataPack: - """(info: dict, result: dict) -> output""" - # return {**info, **result} - raise NotImplementedError() - - def process_one_core( - self, data_pack: ModelBasedSafetyDataPack - ) -> ModelBasedSafetyDataPack: - info, batch = self.preprocess(data_pack) - batch = self.collate_fn([(info, batch)])[1] - results = self.inference(batch) - return self.postprocess(info, results[0]) - - def process_one( - self, content_str: str, language: str, language_details: str - ) -> dict: - data_pack = ModelBasedSafetyDataPack(content_str, language, language_details) - return self.process_one_core(data_pack).get_output() - - -class ZhEnPoliticalModel(ContentStrBatchModel): - def __init__(self, model_config: dict): - super().__init__(model_config) - self.political_detect = PoliticalDetector( - model_path=model_config["model_path"], - ) - self.threshold = model_config["threshold"] - - def check_support(self, data_pack: ModelBasedSafetyDataPack) -> bool: - return data_pack.language in ["zh", "en"] - - def inference(self, batch: List[str]) -> List[dict]: - result_list = [] - for content_str in batch: - predictions, probabilities = self.political_detect.predict(content_str) - normal_score = decide_political_by_prob(predictions, probabilities) - result_list.append( - { - "political_prob": normal_score, - "political_info": { - "predictions": predictions, - "probabilities": probabilities, - }, - } - ) - return result_list - - def postprocess(self, info: dict, result: dict) -> dict: - remained = result["political_prob"] > self.threshold - datapack = ModelBasedSafetyDataPack.from_dict(info) - datapack.set_process_result( - model_based_safety_remained=remained, model_based_safety_infos=result - ) - return datapack - - -class EnPornModel(ContentStrBatchModel): - def __init__(self, model_config: dict): - super().__init__(model_config) - self.model = EnPornBertModel(model_config["model_path"]) - self.threshold = model_config["threshold"] - - def check_support(self, data_pack: ModelBasedSafetyDataPack) -> bool: - return data_pack.language == "en" - - def inference(self, batch: List[str]) -> List[dict]: - result_list = [] - for content_str in batch: - prob = self.model.predict(content_str) - result_list.append(prob) - return result_list - - def postprocess(self, info: dict, result: dict) -> dict: - porn_prob = list(result[0].values())[0] - remained = porn_prob < self.threshold - datapack = ModelBasedSafetyDataPack.from_dict(info) - datapack.set_process_result( - model_based_safety_remained=remained, - model_based_safety_infos={"porn_prob": porn_prob}, - ) - return datapack - - -class ZhPornModel(EnPornModel): - def __init__(self, model_config: dict): - self.model_config = model_config - self.model = ZhPornXlmrModel(model_config["model_path"]) - self.threshold = model_config["threshold"] diff --git a/tests/llm_web_kit/model/test_model_based_safe_model.py b/tests/llm_web_kit/model/test_model_based_safe_model.py deleted file mode 100644 index e69de29b..00000000 From 55712d62f23498177bab5da87cb7b69861c6b76b Mon Sep 17 00:00:00 2001 From: yujing Date: Tue, 18 Mar 2025 15:19:19 +0800 Subject: [PATCH 23/32] lint code --- llm_web_kit/model/unsafe_words_detector.py | 203 ++++++++++----------- 1 file changed, 101 insertions(+), 102 deletions(-) diff --git a/llm_web_kit/model/unsafe_words_detector.py b/llm_web_kit/model/unsafe_words_detector.py index ce831b5f..3acd57d9 100644 --- a/llm_web_kit/model/unsafe_words_detector.py +++ b/llm_web_kit/model/unsafe_words_detector.py @@ -1,6 +1,6 @@ import os import time -from typing import Any, Tuple, Dict +from typing import Any, Dict, Tuple import ahocorasick @@ -13,66 +13,65 @@ from llm_web_kit.model.resource_utils import (CACHE_DIR, download_auto_file, singleton_resource_manager) - xyz_language_lst = [ - "ar", - "cs", - "hu", - "sr", - "ru", - "ko", - "vi", - "th", - "arb", - "arb_Arab", - "arb_Latn", - "ces", - "ces_Latn", - "hun", - "hun_Latn", - "srp", - "srp_Cyrl", - "rus", - "rus_Cyrl", - "kor", - "kor_Hang", - "vie", - "vie_Latn", - "tha", - "tha_Thai", + 'ar', + 'cs', + 'hu', + 'sr', + 'ru', + 'ko', + 'vi', + 'th', + 'arb', + 'arb_Arab', + 'arb_Latn', + 'ces', + 'ces_Latn', + 'hun', + 'hun_Latn', + 'srp', + 'srp_Cyrl', + 'rus', + 'rus_Cyrl', + 'kor', + 'kor_Hang', + 'vie', + 'vie_Latn', + 'tha', + 'tha_Thai', ] level_score_map = { - "L1": 100, - "L2": 10, - "L3": 1, - "L4": 0.1, + 'L1': 100, + 'L2': 10, + 'L3': 1, + 'L4': 0.1, } -def auto_download(language="zh-en"): - resource_config = load_config()["resources"] - if language == "zh-en": - resource_name = "unsafe_words" - elif language == "xyz": - resource_name = "xyz_internal_unsafe_words" +def auto_download(language='zh-en'): + resource_config = load_config()['resources'] + if language == 'zh-en': + resource_name = 'unsafe_words' + elif language == 'xyz': + resource_name = 'xyz_internal_unsafe_words' else: - raise SafeModelException(f"Unsupported language: {language}") + raise SafeModelException(f'Unsupported language: {language}') language_unsafe_words_config: Dict = resource_config[resource_name] - download_path = language_unsafe_words_config["download_path"] - md5 = language_unsafe_words_config["md5"] + download_path = language_unsafe_words_config['download_path'] + md5 = language_unsafe_words_config['md5'] local_path = os.path.join(CACHE_DIR, resource_name) unsafe_words_file_path = download_auto_file(download_path, local_path, md5) return unsafe_words_file_path -def get_ac(language="zh-en"): +def get_ac(language='zh-en'): t1 = time.time() unsafe_words_file_path = auto_download(language) t2 = time.time() print( - f"-----------------auto_download cost time: {t2-t1} , language: {language}------------------" + f'-----------------auto_download cost time: {t2-t1} , language: {language}------------------' ) - with open(unsafe_words_file_path, "r") as f: + with open(unsafe_words_file_path, 'r') as f: lines = f.readlines() # sub_word: [{ @@ -87,27 +86,27 @@ def get_ac(language="zh-en"): words = {} for line in lines: w = json_loads(line) - word = str(w.get("word") or "").lower() + word = str(w.get('word') or '').lower() if not word: continue if is_pure_en_word(word) and len(word) <= 4: continue - sub_words = word.split("&&&") + sub_words = word.split('&&&') w_info = { - "word": word, - "sub_words": set(sub_words), - "type": w.get("type"), - "level": w.get("level"), - "language": w.get("language"), - "applicable": w.get("applicable"), - "unapplicable": w.get("unapplicable"), + 'word': word, + 'sub_words': set(sub_words), + 'type': w.get('type'), + 'level': w.get('level'), + 'language': w.get('language'), + 'applicable': w.get('applicable'), + 'unapplicable': w.get('unapplicable'), } for sub_word in sub_words: lst = words.get(sub_word, []) - lst.append({"sub_word": sub_word, **w_info}) + lst.append({'sub_word': sub_word, **w_info}) words[sub_word] = lst ac = ahocorasick.Automaton() @@ -138,7 +137,7 @@ def is_word_standalone(sub_word, end_pos): # 遍历所有匹配的子词及其结束位置pos for pos, w_info_lst in ac.iter(content): for w_info in w_info_lst: - sub_word = w_info["sub_word"] + sub_word = w_info['sub_word'] if is_word_standalone(sub_word, pos): all_sub_words.add(sub_word) all_w_info_lst.append(w_info) @@ -146,26 +145,26 @@ def is_word_standalone(sub_word, end_pos): unsafe_words = {} for w_info in all_w_info_lst: # 检查该词的所有子词是否均被匹配到 - if all_sub_words.issuperset(w_info["sub_words"]): - if w_info["word"] not in unsafe_words: - unsafe_words[w_info["word"]] = { - "word": w_info["word"], - "type": w_info["type"], - "level": w_info["level"], - "language": w_info["language"], - "count": 0.0, + if all_sub_words.issuperset(w_info['sub_words']): + if w_info['word'] not in unsafe_words: + unsafe_words[w_info['word']] = { + 'word': w_info['word'], + 'type': w_info['type'], + 'level': w_info['level'], + 'language': w_info['language'], + 'count': 0.0, } - unsafe_words[w_info["word"]]["count"] += 1.0 / len(w_info["sub_words"]) + unsafe_words[w_info['word']]['count'] += 1.0 / len(w_info['sub_words']) return list(unsafe_words.values()) class UnsafeWordChecker: - def __init__(self, language="zh-en") -> None: + def __init__(self, language='zh-en') -> None: t1 = time.time() self.ac = get_ac(language) t2 = time.time() print( - f"---------------UnsafeWordChecker init time: {t2-t1} , language: {language}-----------------" + f'---------------UnsafeWordChecker init time: {t2-t1} , language: {language}-----------------' ) def check_unsafe_words(self, content_str: str) -> list: @@ -173,7 +172,7 @@ def check_unsafe_words(self, content_str: str) -> list: return unsafe_words_list -def get_unsafe_words_checker(language="zh-en") -> UnsafeWordChecker: +def get_unsafe_words_checker(language='zh-en') -> UnsafeWordChecker: if not singleton_resource_manager.has_name(language): singleton_resource_manager.set_resource(language, UnsafeWordChecker(language)) return singleton_resource_manager.get_resource(language) @@ -187,12 +186,12 @@ def decide_data_unsafe_word_by_data_checker( unsafe_words_list = unsafeWordChecker.check_unsafe_words(content_str=content_str) unsafe_word_levels = [] for w in unsafe_words_list: - _, level, _ = w["word"], w["level"], w["count"] + _, level, _ = w['word'], w['level'], w['count'] # "涉政|观测|L4|带头人" unsafe_word_levels.append(level) unsafe_word_levels = list(set(unsafe_word_levels)) - unsafe_word_min_level = min(unsafe_word_levels + ["NF"]) + unsafe_word_min_level = min(unsafe_word_levels + ['NF']) return unsafe_word_min_level @@ -203,12 +202,12 @@ def decide_content_unsafe_word_by_data_checker( unsafe_words_list = unsafeWordChecker.check_unsafe_words(content_str=content_str) unsafe_word_levels = [] for w in unsafe_words_list: - _, level, _ = w["word"], w["level"], w["count"] + _, level, _ = w['word'], w['level'], w['count'] # "涉政|观测|L4|带头人" unsafe_word_levels.append(level) unsafe_word_levels = list(set(unsafe_word_levels)) - unsafe_word_min_level = min(unsafe_word_levels + ["NF"]) + unsafe_word_min_level = min(unsafe_word_levels + ['NF']) return unsafe_word_min_level @@ -217,21 +216,21 @@ def unsafe_words_filter( data_dict: Dict[str, Any], language: str, content_style: str ) -> str: if language in xyz_language_lst: - language = "xyz" + language = 'xyz' elif language in [ - "zh", - "en", - "yue", - "zho", - "eng", - "zho_Hans", - "zho_Hant", - "yue_Hant", - "eng_Latn", + 'zh', + 'en', + 'yue', + 'zho', + 'eng', + 'zho_Hans', + 'zho_Hant', + 'yue_Hant', + 'eng_Latn', ]: - language = "zh-en" + language = 'zh-en' else: - raise SafeModelException(f"Unsupported language: {language}") + raise SafeModelException(f'Unsupported language: {language}') unsafeWordChecker = get_unsafe_words_checker(language) unsafe_word_min_level = decide_data_unsafe_word_by_data_checker( @@ -249,14 +248,14 @@ def unsafe_words_filter_overall( from_domestic_source, ): if from_safe_source: - return {"hit_unsafe_words": False} + return {'hit_unsafe_words': False} if from_domestic_source: - unsafe_range = ("L1",) + unsafe_range = ('L1',) else: - unsafe_range = ("L1", "L2") + unsafe_range = ('L1', 'L2') unsafe_word_min_level = unsafe_words_filter(data_dict, language, content_style) hit = unsafe_word_min_level in unsafe_range - return {"hit_unsafe_words": hit} + return {'hit_unsafe_words': hit} class UnsafeWordsFilter: @@ -273,30 +272,30 @@ def filter( from_domestic_source: bool, ) -> Tuple[bool, Dict[str, Any]]: if language in xyz_language_lst: - language = "xyz" + language = 'xyz' elif language in [ - "zh", - "en", - "yue", - "zho", - "eng", - "zho_Hans", - "zho_Hant", - "yue_Hant", - "eng_Latn", + 'zh', + 'en', + 'yue', + 'zho', + 'eng', + 'zho_Hans', + 'zho_Hant', + 'yue_Hant', + 'eng_Latn', ]: - language = "zh-en" + language = 'zh-en' else: - raise SafeModelException(f"Unsupported language: {language}") + raise SafeModelException(f'Unsupported language: {language}') if from_safe_source: - return True, {"hit_unsafe_words": False} + return True, {'hit_unsafe_words': False} if from_domestic_source: - unsafe_range = ("L1",) + unsafe_range = ('L1',) else: - unsafe_range = ("L1", "L2") + unsafe_range = ('L1', 'L2') unsafe_word_min_level = decide_content_unsafe_word_by_data_checker( content_str, get_unsafe_words_checker(language) ) hit = unsafe_word_min_level in unsafe_range - return not hit, {"hit_unsafe_words": hit} \ No newline at end of file + return not hit, {'hit_unsafe_words': hit} From ba5e3531aa1dcef030ddd0cebdcf685faa175929 Mon Sep 17 00:00:00 2001 From: yujing Date: Tue, 18 Mar 2025 15:41:33 +0800 Subject: [PATCH 24/32] lint readme --- docs/llm_web_kit/model/readme.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/llm_web_kit/model/readme.md b/docs/llm_web_kit/model/readme.md index 83b1f7cc..4f2ad8b3 100644 --- a/docs/llm_web_kit/model/readme.md +++ b/docs/llm_web_kit/model/readme.md @@ -1,16 +1,21 @@ # 面向用户的接口 ## html分类 + html_simplify_classify.md ## 语言检测 + lang_id.md ## 清洗模型 + clean_module.md ## 安全规则 + rule_based_safety_module.md ## 安全模型 -model_interface.md \ No newline at end of file + +model_interface.md From 65f770022f8978f3ca596b452895689c76556180 Mon Sep 17 00:00:00 2001 From: yujing Date: Tue, 18 Mar 2025 16:03:51 +0800 Subject: [PATCH 25/32] lint all code --- .../extractor_chain_input/good_data/html/list_nest_three.html | 2 +- .../good_data/html/table_include_entity.html | 2 +- .../good_data/html/table_include_math_p.html | 2 +- .../good_data/html/table_include_table_math.html | 2 +- .../extractor_chain_input/good_data/html/table_tail_text.html | 2 +- .../extractor_chain_input/good_data/html/test_list_empty.html | 2 +- .../good_data/html/test_table_elem_include_enter.html | 2 +- .../extractor_chain_input/good_data/html_data_input.jsonl | 2 +- .../assets/recognizer/table_involve_complex_code.html | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/list_nest_three.html b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/list_nest_three.html index 018a85ab..8985f943 100644 --- a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/list_nest_three.html +++ b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/list_nest_three.html @@ -27,4 +27,4 @@ - \ No newline at end of file + diff --git a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_include_entity.html b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_include_entity.html index f9b20e14..c473cb4c 100644 --- a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_include_entity.html +++ b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_include_entity.html @@ -1686,4 +1686,4 @@

- \ No newline at end of file + diff --git a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_include_math_p.html b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_include_math_p.html index 257b0bac..b3370ab3 100644 --- a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_include_math_p.html +++ b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_include_math_p.html @@ -2323,4 +2323,4 @@

- \ No newline at end of file + diff --git a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_include_table_math.html b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_include_table_math.html index 16d7b72e..d25a8bed 100644 --- a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_include_table_math.html +++ b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_include_table_math.html @@ -87,4 +87,4 @@

STEM 综合展示表

- \ No newline at end of file + diff --git a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_tail_text.html b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_tail_text.html index 4044b9a3..43d8867d 100644 --- a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_tail_text.html +++ b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_tail_text.html @@ -364,4 +364,4 @@

Comments

- \ No newline at end of file + diff --git a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/test_table_elem_include_enter.html b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/test_table_elem_include_enter.html index 176f4fab..ff573ea4 100644 --- a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/test_table_elem_include_enter.html +++ b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/test_table_elem_include_enter.html @@ -3133,4 +3133,4 @@

پشتیبانی

images=mutation.addedNodes[i].getElementsByTagName('img');is_image=mutation.addedNodes[i].tagName=="IMG";iframes=mutation.addedNodes[i].getElementsByTagName('iframe');is_iframe=mutation.addedNodes[i].tagName=="IFRAME";rocket_lazy=mutation.addedNodes[i].getElementsByClassName('rocket-lazyload');image_count+=images.length;iframe_count+=iframes.length;rocketlazy_count+=rocket_lazy.length;if(is_image){image_count+=1} if(is_iframe){iframe_count+=1}}});if(image_count>0||iframe_count>0||rocketlazy_count>0){lazyLoadInstance.update()}});var b=document.getElementsByTagName("body")[0];var config={childList:!0,subtree:!0};observer.observe(b,config)}},!1) - \ No newline at end of file + diff --git a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html_data_input.jsonl b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html_data_input.jsonl index 3aa72f85..25336245 100644 --- a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html_data_input.jsonl +++ b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html_data_input.jsonl @@ -17,4 +17,4 @@ {"track_id": "table_include_table_math", "dataset_name": "table_include_table_math", "url": "https://test","data_source_category": "HTML", "path":"table_include_table_math.html", "file_bytes": 1000, "meta_info": {"input_datetime": "2020-01-01 00:00:00"}} {"track_id": "test_clean_tags", "dataset_name": "test_pipeline_suit", "url": "https://math.stackexchange.com/questions/4082284/solving-for-vector-contained-in-a-diagonal-matrix","data_source_category": "HTML", "path":"test_clean_tags.html", "file_bytes": 1000, "page_layout_type":"forum", "meta_info": {"input_datetime": "2020-01-01 00:00:00"}} {"track_id": "list_nest_three", "dataset_name": "list_nest_three", "url": "http://test.com","data_source_category": "HTML", "path":"list_nest_three.html", "file_bytes": 1000, "page_layout_type":"forum", "meta_info": {"input_datetime": "2020-01-01 00:00:00"}} -{"track_id": "table_include_entity", "dataset_name": "table_include_entity", "url": "http://math.stackexchange.com/questions/658871/perfectly-centered-break-of-a-perfectly-aligned-pool-ball-rack?answertab=active","data_source_category": "HTML", "path":"table_include_entity.html", "file_bytes": 1000, "page_layout_type":"forum", "meta_info": {"input_datetime": "2020-01-01 00:00:00"}} \ No newline at end of file +{"track_id": "table_include_entity", "dataset_name": "table_include_entity", "url": "http://math.stackexchange.com/questions/658871/perfectly-centered-break-of-a-perfectly-aligned-pool-ball-rack?answertab=active","data_source_category": "HTML", "path":"table_include_entity.html", "file_bytes": 1000, "page_layout_type":"forum", "meta_info": {"input_datetime": "2020-01-01 00:00:00"}} diff --git a/tests/llm_web_kit/extractor/html/recognizer/assets/recognizer/table_involve_complex_code.html b/tests/llm_web_kit/extractor/html/recognizer/assets/recognizer/table_involve_complex_code.html index b929d7e0..d66d62b7 100644 --- a/tests/llm_web_kit/extractor/html/recognizer/assets/recognizer/table_involve_complex_code.html +++ b/tests/llm_web_kit/extractor/html/recognizer/assets/recognizer/table_involve_complex_code.html @@ -234,4 +234,4 @@

ClientNetworkWrapper

- \ No newline at end of file + From 725f6cb00d058ea5e66a58268039ae3a21a65462 Mon Sep 17 00:00:00 2001 From: yujing Date: Tue, 18 Mar 2025 16:50:01 +0800 Subject: [PATCH 26/32] add error code --- llm_web_kit/exception/exception.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llm_web_kit/exception/exception.py b/llm_web_kit/exception/exception.py index dc7939de..0aae2129 100644 --- a/llm_web_kit/exception/exception.py +++ b/llm_web_kit/exception/exception.py @@ -6,7 +6,7 @@ class ErrorMsg: """Error message manager class.""" - _errors = {} + _errors = {'10001': {'module': 'Model', 'error_name': 'ModelRuntimeException', 'message': 'Model runtime exception'}} @classmethod def _load_errors(cls): From e9c55efa249e23c30c6165091beac5599b05a437 Mon Sep 17 00:00:00 2001 From: yujing Date: Wed, 19 Mar 2025 11:58:23 +0800 Subject: [PATCH 27/32] bug fix --- llm_web_kit/exception/exception.jsonc | 4 ++++ llm_web_kit/exception/exception.py | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/llm_web_kit/exception/exception.jsonc b/llm_web_kit/exception/exception.jsonc index c72a62f2..4de7ea29 100644 --- a/llm_web_kit/exception/exception.jsonc +++ b/llm_web_kit/exception/exception.jsonc @@ -146,6 +146,10 @@ "CleanModelUnsupportedLanguageException": { "code": 46100000, "message": "Clean model unsupported language exception" + }, + "ModelRuntimeException": { + "code": 47000000, + "message": "Model runtime exception" } } } diff --git a/llm_web_kit/exception/exception.py b/llm_web_kit/exception/exception.py index 0aae2129..dc7939de 100644 --- a/llm_web_kit/exception/exception.py +++ b/llm_web_kit/exception/exception.py @@ -6,7 +6,7 @@ class ErrorMsg: """Error message manager class.""" - _errors = {'10001': {'module': 'Model', 'error_name': 'ModelRuntimeException', 'message': 'Model runtime exception'}} + _errors = {} @classmethod def _load_errors(cls): From 930d5b3884929604847b118fa56d62e82155f3ab Mon Sep 17 00:00:00 2001 From: yujing Date: Wed, 19 Mar 2025 12:26:19 +0800 Subject: [PATCH 28/32] roll back end of html --- .../extractor_chain_input/good_data/html/list_nest_three.html | 2 +- .../good_data/html/table_include_entity.html | 2 +- .../good_data/html/table_include_math_p.html | 2 +- .../good_data/html/table_include_table_math.html | 2 +- .../extractor_chain_input/good_data/html/table_tail_text.html | 2 +- .../extractor_chain_input/good_data/html/test_list_empty.html | 2 +- .../good_data/html/test_table_elem_include_enter.html | 2 +- .../assets/extractor_chain_input/good_data/html/text3.html | 2 +- .../assets/extractor_chain_input/good_data/html/text9.html | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/list_nest_three.html b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/list_nest_three.html index 8985f943..018a85ab 100644 --- a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/list_nest_three.html +++ b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/list_nest_three.html @@ -27,4 +27,4 @@ - + \ No newline at end of file diff --git a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_include_entity.html b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_include_entity.html index c473cb4c..f9b20e14 100644 --- a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_include_entity.html +++ b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_include_entity.html @@ -1686,4 +1686,4 @@

- + \ No newline at end of file diff --git a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_include_math_p.html b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_include_math_p.html index b3370ab3..257b0bac 100644 --- a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_include_math_p.html +++ b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_include_math_p.html @@ -2323,4 +2323,4 @@

- + \ No newline at end of file diff --git a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_include_table_math.html b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_include_table_math.html index d25a8bed..16d7b72e 100644 --- a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_include_table_math.html +++ b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_include_table_math.html @@ -87,4 +87,4 @@

STEM 综合展示表

- + \ No newline at end of file diff --git a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_tail_text.html b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_tail_text.html index 43d8867d..4044b9a3 100644 --- a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_tail_text.html +++ b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/table_tail_text.html @@ -364,4 +364,4 @@

Comments

- + \ No newline at end of file diff --git a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/test_table_elem_include_enter.html b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/test_table_elem_include_enter.html index ff573ea4..176f4fab 100644 --- a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/test_table_elem_include_enter.html +++ b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/test_table_elem_include_enter.html @@ -3133,4 +3133,4 @@

پشتیبانی

images=mutation.addedNodes[i].getElementsByTagName('img');is_image=mutation.addedNodes[i].tagName=="IMG";iframes=mutation.addedNodes[i].getElementsByTagName('iframe');is_iframe=mutation.addedNodes[i].tagName=="IFRAME";rocket_lazy=mutation.addedNodes[i].getElementsByClassName('rocket-lazyload');image_count+=images.length;iframe_count+=iframes.length;rocketlazy_count+=rocket_lazy.length;if(is_image){image_count+=1} if(is_iframe){iframe_count+=1}}});if(image_count>0||iframe_count>0||rocketlazy_count>0){lazyLoadInstance.update()}});var b=document.getElementsByTagName("body")[0];var config={childList:!0,subtree:!0};observer.observe(b,config)}},!1) - + \ No newline at end of file diff --git a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/text3.html b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/text3.html index 93715a22..bc4bf044 100644 --- a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/text3.html +++ b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/text3.html @@ -1246,4 +1246,4 @@ src=\"https://physicsforums-bernhardtmediall.netdna-ssl.com/copyright.js\">\n\n\n \n\n\n \ No newline at end of file +\n\n\n diff --git a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/text9.html b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/text9.html index c13350c8..cee33aea 100644 --- a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/text9.html +++ b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/text9.html @@ -356,4 +356,4 @@ //]]> - \ No newline at end of file + From e41766dc6262677c5f66a0d5d0cad6c3dcf89c0f Mon Sep 17 00:00:00 2001 From: yujing Date: Wed, 19 Mar 2025 12:36:14 +0800 Subject: [PATCH 29/32] roll back end of html --- .../assets/extractor_chain_input/good_data/html/text3.html | 2 +- .../assets/extractor_chain_input/good_data/html/text9.html | 2 +- .../assets/recognizer/table_involve_complex_code.html | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/text3.html b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/text3.html index bc4bf044..93715a22 100644 --- a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/text3.html +++ b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/text3.html @@ -1246,4 +1246,4 @@ src=\"https://physicsforums-bernhardtmediall.netdna-ssl.com/copyright.js\">\n\n\n \n\n\n +\n\n\n \ No newline at end of file diff --git a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/text9.html b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/text9.html index cee33aea..c13350c8 100644 --- a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/text9.html +++ b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/text9.html @@ -356,4 +356,4 @@ //]]> - + \ No newline at end of file diff --git a/tests/llm_web_kit/extractor/html/recognizer/assets/recognizer/table_involve_complex_code.html b/tests/llm_web_kit/extractor/html/recognizer/assets/recognizer/table_involve_complex_code.html index d66d62b7..b929d7e0 100644 --- a/tests/llm_web_kit/extractor/html/recognizer/assets/recognizer/table_involve_complex_code.html +++ b/tests/llm_web_kit/extractor/html/recognizer/assets/recognizer/table_involve_complex_code.html @@ -234,4 +234,4 @@

ClientNetworkWrapper

- + \ No newline at end of file From 3413fd7feb9d7b96f57b3b15ea4cc723313cf57c Mon Sep 17 00:00:00 2001 From: yujing Date: Wed, 19 Mar 2025 12:55:08 +0800 Subject: [PATCH 30/32] roll back end of jsonl --- .../extractor_chain_input/good_data/html_data_input.jsonl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html_data_input.jsonl b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html_data_input.jsonl index 25336245..3aa72f85 100644 --- a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html_data_input.jsonl +++ b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html_data_input.jsonl @@ -17,4 +17,4 @@ {"track_id": "table_include_table_math", "dataset_name": "table_include_table_math", "url": "https://test","data_source_category": "HTML", "path":"table_include_table_math.html", "file_bytes": 1000, "meta_info": {"input_datetime": "2020-01-01 00:00:00"}} {"track_id": "test_clean_tags", "dataset_name": "test_pipeline_suit", "url": "https://math.stackexchange.com/questions/4082284/solving-for-vector-contained-in-a-diagonal-matrix","data_source_category": "HTML", "path":"test_clean_tags.html", "file_bytes": 1000, "page_layout_type":"forum", "meta_info": {"input_datetime": "2020-01-01 00:00:00"}} {"track_id": "list_nest_three", "dataset_name": "list_nest_three", "url": "http://test.com","data_source_category": "HTML", "path":"list_nest_three.html", "file_bytes": 1000, "page_layout_type":"forum", "meta_info": {"input_datetime": "2020-01-01 00:00:00"}} -{"track_id": "table_include_entity", "dataset_name": "table_include_entity", "url": "http://math.stackexchange.com/questions/658871/perfectly-centered-break-of-a-perfectly-aligned-pool-ball-rack?answertab=active","data_source_category": "HTML", "path":"table_include_entity.html", "file_bytes": 1000, "page_layout_type":"forum", "meta_info": {"input_datetime": "2020-01-01 00:00:00"}} +{"track_id": "table_include_entity", "dataset_name": "table_include_entity", "url": "http://math.stackexchange.com/questions/658871/perfectly-centered-break-of-a-perfectly-aligned-pool-ball-rack?answertab=active","data_source_category": "HTML", "path":"table_include_entity.html", "file_bytes": 1000, "page_layout_type":"forum", "meta_info": {"input_datetime": "2020-01-01 00:00:00"}} \ No newline at end of file From c942af02a8f69b108b2767ce639727e2008f62e5 Mon Sep 17 00:00:00 2001 From: yujing Date: Wed, 19 Mar 2025 18:06:22 +0800 Subject: [PATCH 31/32] clean unused code --- llm_web_kit/model/unsafe_words_detector.py | 55 ++----------- .../model/test_unsafe_words_detector.py | 80 +------------------ 2 files changed, 7 insertions(+), 128 deletions(-) diff --git a/llm_web_kit/model/unsafe_words_detector.py b/llm_web_kit/model/unsafe_words_detector.py index 3acd57d9..3833dd4c 100644 --- a/llm_web_kit/model/unsafe_words_detector.py +++ b/llm_web_kit/model/unsafe_words_detector.py @@ -212,55 +212,9 @@ def decide_content_unsafe_word_by_data_checker( return unsafe_word_min_level -def unsafe_words_filter( - data_dict: Dict[str, Any], language: str, content_style: str -) -> str: - if language in xyz_language_lst: - language = 'xyz' - elif language in [ - 'zh', - 'en', - 'yue', - 'zho', - 'eng', - 'zho_Hans', - 'zho_Hant', - 'yue_Hant', - 'eng_Latn', - ]: - language = 'zh-en' - else: - raise SafeModelException(f'Unsupported language: {language}') - - unsafeWordChecker = get_unsafe_words_checker(language) - unsafe_word_min_level = decide_data_unsafe_word_by_data_checker( - data_dict, unsafeWordChecker - ) - - return unsafe_word_min_level - - -def unsafe_words_filter_overall( - data_dict: Dict[str, Any], - language: str, - content_style: str, - from_safe_source, - from_domestic_source, -): - if from_safe_source: - return {'hit_unsafe_words': False} - if from_domestic_source: - unsafe_range = ('L1',) - else: - unsafe_range = ('L1', 'L2') - unsafe_word_min_level = unsafe_words_filter(data_dict, language, content_style) - hit = unsafe_word_min_level in unsafe_range - return {'hit_unsafe_words': hit} - - class UnsafeWordsFilter: - def __init__(self): - pass + def __init__(self,raise_not_support_language_exception: bool = False): + self.raise_not_support_language_exception = raise_not_support_language_exception def filter( self, @@ -286,7 +240,10 @@ def filter( ]: language = 'zh-en' else: - raise SafeModelException(f'Unsupported language: {language}') + if self.raise_not_support_language_exception: + raise SafeModelException(f'Unsupported language: {language}') + else: + return True, {'hit_unsafe_words': False} if from_safe_source: return True, {'hit_unsafe_words': False} diff --git a/tests/llm_web_kit/model/test_unsafe_words_detector.py b/tests/llm_web_kit/model/test_unsafe_words_detector.py index 16729fcc..2d97a76e 100644 --- a/tests/llm_web_kit/model/test_unsafe_words_detector.py +++ b/tests/llm_web_kit/model/test_unsafe_words_detector.py @@ -4,8 +4,7 @@ from llm_web_kit.exception.exception import SafeModelException from llm_web_kit.model.unsafe_words_detector import ( UnsafeWordChecker, auto_download, decide_data_unsafe_word_by_data_checker, - get_ac, get_unsafe_words, get_unsafe_words_checker, unsafe_words_filter, - unsafe_words_filter_overall) + get_ac, get_unsafe_words, get_unsafe_words_checker) class TestUnsafeWordChecker(unittest.TestCase): @@ -82,83 +81,6 @@ def test_get_unsafe_words_checker(self, mock_get_ac): checker2 = get_unsafe_words_checker('zh-en') self.assertIs(checker1, checker2) # Should return the same instance - @patch('llm_web_kit.model.unsafe_words_detector.get_unsafe_words_checker') - def test_unsafe_words_filter(self, mock_get_checker): - mock_checker = MagicMock() - mock_checker.check_unsafe_words.return_value = [ - {'word': '', 'level': 'L3', 'count': 1} - ] - mock_get_checker.return_value = mock_checker - - data_dict = {'content': 'Test content'} - result = unsafe_words_filter(data_dict, 'en', 'text') - self.assertEqual(result, 'L3') - result = unsafe_words_filter(data_dict, 'ko', 'text') - self.assertEqual(result, 'L3') - with self.assertRaises(SafeModelException): - unsafe_words_filter(data_dict, 'unk', 'text') - - def test_unsafe_words_filter_with_unsupported_language(self): - data_dict = {'content': 'Test content'} - with self.assertRaises(SafeModelException): - unsafe_words_filter(data_dict, 'unsupported_language', 'text') - - @patch('llm_web_kit.model.unsafe_words_detector.unsafe_words_filter') - def test_unsafe_words_filter_overall(self, mock_filter): - mock_filter.return_value = 'L1' - - data_dict = {'content': 'Content with unsafe words.'} - - result = unsafe_words_filter_overall( - data_dict, - language='en', - content_style='text', - from_safe_source=False, - from_domestic_source=False, - ) - self.assertIsInstance(result, dict) - self.assertTrue(result['hit_unsafe_words']) - - result = unsafe_words_filter_overall( - data_dict, - language='en', - content_style='text', - from_safe_source=False, - from_domestic_source=True, - ) - self.assertIsInstance(result, dict) - self.assertTrue(result['hit_unsafe_words']) - - result = unsafe_words_filter_overall( - data_dict, - language='en', - content_style='text', - from_safe_source=True, - from_domestic_source=True, - ) - self.assertIsInstance(result, dict) - self.assertFalse(result['hit_unsafe_words']) - - result = unsafe_words_filter_overall( - data_dict, - language='en', - content_style='text', - from_safe_source=True, - from_domestic_source=False, - ) - self.assertIsInstance(result, dict) - self.assertFalse(result['hit_unsafe_words']) - - result = unsafe_words_filter_overall( - data_dict, - language='ru', - content_style='text', - from_safe_source=False, - from_domestic_source=False, - ) - self.assertIsInstance(result, dict) - self.assertTrue(result['hit_unsafe_words']) - @patch('llm_web_kit.model.unsafe_words_detector.load_config') @patch('llm_web_kit.model.unsafe_words_detector.download_auto_file') def test_auto_download(self, mock_download_auto_file, mock_load_config): From 8dbb66b504e0670a6a8e62899d7915953f16735b Mon Sep 17 00:00:00 2001 From: yujing Date: Thu, 20 Mar 2025 11:52:31 +0800 Subject: [PATCH 32/32] clean unused code --- llm_web_kit/model/unsafe_words_detector.py | 19 ------------------- .../model/test_unsafe_words_detector.py | 19 ++++--------------- 2 files changed, 4 insertions(+), 34 deletions(-) diff --git a/llm_web_kit/model/unsafe_words_detector.py b/llm_web_kit/model/unsafe_words_detector.py index 3833dd4c..28556fc5 100644 --- a/llm_web_kit/model/unsafe_words_detector.py +++ b/llm_web_kit/model/unsafe_words_detector.py @@ -6,7 +6,6 @@ from llm_web_kit.config.cfg_reader import load_config from llm_web_kit.exception.exception import SafeModelException -from llm_web_kit.input.datajson import DataJson from llm_web_kit.libs.standard_utils import json_loads from llm_web_kit.model.basic_functions.format_check import (is_en_letter, is_pure_en_word) @@ -178,24 +177,6 @@ def get_unsafe_words_checker(language='zh-en') -> UnsafeWordChecker: return singleton_resource_manager.get_resource(language) -def decide_data_unsafe_word_by_data_checker( - data_dict: dict, unsafeWordChecker: UnsafeWordChecker -) -> str: - data_obj = DataJson(data_dict) - content_str = data_obj.get_content_list().to_txt() - unsafe_words_list = unsafeWordChecker.check_unsafe_words(content_str=content_str) - unsafe_word_levels = [] - for w in unsafe_words_list: - _, level, _ = w['word'], w['level'], w['count'] - # "涉政|观测|L4|带头人" - unsafe_word_levels.append(level) - - unsafe_word_levels = list(set(unsafe_word_levels)) - unsafe_word_min_level = min(unsafe_word_levels + ['NF']) - - return unsafe_word_min_level - - def decide_content_unsafe_word_by_data_checker( content_str: str, unsafeWordChecker: UnsafeWordChecker ) -> str: diff --git a/tests/llm_web_kit/model/test_unsafe_words_detector.py b/tests/llm_web_kit/model/test_unsafe_words_detector.py index 2d97a76e..f7c63f7f 100644 --- a/tests/llm_web_kit/model/test_unsafe_words_detector.py +++ b/tests/llm_web_kit/model/test_unsafe_words_detector.py @@ -2,9 +2,10 @@ from unittest.mock import MagicMock, Mock, mock_open, patch from llm_web_kit.exception.exception import SafeModelException -from llm_web_kit.model.unsafe_words_detector import ( - UnsafeWordChecker, auto_download, decide_data_unsafe_word_by_data_checker, - get_ac, get_unsafe_words, get_unsafe_words_checker) +from llm_web_kit.model.unsafe_words_detector import (UnsafeWordChecker, + auto_download, get_ac, + get_unsafe_words, + get_unsafe_words_checker) class TestUnsafeWordChecker(unittest.TestCase): @@ -36,18 +37,6 @@ def test_check_unsafe_words(self, mock_get_ac): self.assertIsInstance(result, list) self.assertEqual(len(result), 0) - @patch('llm_web_kit.model.unsafe_words_detector.get_unsafe_words_checker') - def test_decide_unsafe_word_by_data_checker(self, mock_get_checker): - mock_checker = MagicMock() - mock_checker.check_unsafe_words.return_value = [ - {'word': 'unsafe', 'level': 'L2', 'count': 1} - ] - mock_get_checker.return_value = mock_checker - - data_dict = {'content': 'Some content with unsafe elements.'} - result = decide_data_unsafe_word_by_data_checker(data_dict, mock_checker) - self.assertEqual(result, 'L2') - def test_standalone_word_detection(self): """测试独立存在的子词能被正确识别[2,6](@ref)""" ac = Mock()