diff --git a/docs/llm_web_kit/model/model_based_safety.md b/docs/llm_web_kit/model/model_based_safety.md
new file mode 100644
index 00000000..e69de29b
diff --git a/docs/llm_web_kit/model/model_interface.md b/docs/llm_web_kit/model/model_interface.md
new file mode 100644
index 00000000..e69de29b
diff --git a/docs/llm_web_kit/model/readme.md b/docs/llm_web_kit/model/readme.md
new file mode 100644
index 00000000..4f2ad8b3
--- /dev/null
+++ b/docs/llm_web_kit/model/readme.md
@@ -0,0 +1,21 @@
+# User-Facing Interfaces
+
+## HTML Classification
+
+html_simplify_classify.md
+
+## Language Detection
+
+lang_id.md
+
+## Cleaning Model
+
+clean_module.md
+
+## Safety Rules
+
+rule_based_safety_module.md
+
+## Safety Model
+
+model_interface.md
diff --git a/docs/llm_web_kit/model/rule_based_safety_module.md b/docs/llm_web_kit/model/rule_based_safety_module.md
new file mode 100644
index 00000000..e69de29b
diff --git a/llm_web_kit/exception/exception.jsonc b/llm_web_kit/exception/exception.jsonc
index c72a62f2..4de7ea29 100644
--- a/llm_web_kit/exception/exception.jsonc
+++ b/llm_web_kit/exception/exception.jsonc
@@ -146,6 +146,10 @@
     "CleanModelUnsupportedLanguageException": {
       "code": 46100000,
       "message": "Clean model unsupported language exception"
+    },
+    "ModelRuntimeException": {
+      "code": 47000000,
+      "message": "Model runtime exception"
     }
   }
 }
diff --git a/llm_web_kit/exception/exception.py b/llm_web_kit/exception/exception.py
index f6b92cf7..dc7939de 100644
--- a/llm_web_kit/exception/exception.py
+++ b/llm_web_kit/exception/exception.py
@@ -336,6 +336,14 @@ def __init__(self, custom_message: str | None = None, error_code: int | None = N
         super().__init__(custom_message, error_code)
 
 
+class ModelRuntimeException(ModelBaseException):
+    """Exception raised for errors that occur while a model is running."""
+    def __init__(self, custom_message: str | None = None, error_code: int | None = None):
+        if error_code is None:
+            error_code = ErrorMsg.get_error_code('Model', 'ModelRuntimeException')
+        super().__init__(custom_message, error_code)
+
+
 class ModelOutputException(ModelBaseException):
     """Exception raised for model output data format."""
     def __init__(self, custom_message: str | None = None, error_code: int | None = None):
diff --git a/llm_web_kit/model/domain_safety_detector.py b/llm_web_kit/model/domain_safety_detector.py
new file mode 100644
index 00000000..15ce2799
--- /dev/null
+++ b/llm_web_kit/model/domain_safety_detector.py
@@ -0,0 +1,13 @@
+class DomainFilter:
+    def __init__(self):
+        pass
+
+    def filter(
+        self,
+        content_str: str,
+        language: str,
+        url: str,
+        language_details: str,
+        content_style: str,
+    ) -> tuple[bool, dict]:
+        return True, {}  # placeholder: every domain passes for now
diff --git a/llm_web_kit/model/model_impl.py b/llm_web_kit/model/model_impl.py
new file mode 100644
index 00000000..d2d44112
--- /dev/null
+++ b/llm_web_kit/model/model_impl.py
@@ -0,0 +1,288 @@
+from abc import abstractmethod
+from enum import Enum
+from typing import Dict, List, Type
+
+from llm_web_kit.exception.exception import (ModelInitException,
+                                              ModelInputException,
+                                              ModelRuntimeException)
+from llm_web_kit.model.model_interface import (BatchProcessConfig,
+                                               ModelPredictor, ModelResource,
+                                               ModelResponse, PoliticalRequest,
+                                               PoliticalResponse, PornRequest,
+                                               PornResponse,
+                                               ResourceRequirement)
+from llm_web_kit.model.policical import (get_singleton_political_detect,
+                                         update_political_by_str)
+
+
+class ModelType(Enum):
+    """Model type enum."""
+
+    POLITICAL = 'political'  # political-content model
+    PORN = 'porn'  # porn-content model
+
+
+class DeviceType(Enum):
+    """Device type enum."""
+
+    CPU = 'cpu'
+    GPU = 'gpu'
+
+
+class BaseModelResource(ModelResource):
+    """Base class for model resources."""
+
+    def __init__(self):
+        self.model = None
+
+    def initialize(self) -> None:
+        self.model = self._load_model()
+
+    @abstractmethod
+    def _load_model(self):
+        pass
+
+    @abstractmethod
+    def convert_result_to_response(self, result: dict) -> ModelResponse:
+        pass
+
+    def cleanup(self) -> None:
+        if self.model:
+            self._cleanup_model()
+            self.model = None
+
+    def _cleanup_model(self):
+        pass
+
+
+class BasePredictor(ModelPredictor):
+    """Base class for predictors."""
+
+    def __init__(self, language: str):
+        self.language = language
+        self.model = self._create_model(language)
+
+        # initialize the underlying model resource
+        self.model.initialize()
+
+    @abstractmethod
+    def _create_model(self, language: str) -> ModelResource:
+        pass
+
+    def get_resource_requirement(self):
+        return self.model.get_resource_requirement()
+
+
+# Political-content model implementations
+class PoliticalCPUModel(BaseModelResource):
+    """CPU model for political-content detection."""
+
+    def _load_model(self):
+        try:
+            model = get_singleton_political_detect()
+            if model is None:
+                raise RuntimeError('Failed to load political model')
+            return model
+        except Exception as e:
+            raise RuntimeError(f'Failed to load political CPU model: {e}')
+
+    def get_resource_requirement(self):
+        return ResourceRequirement(num_cpus=1, memory_GB=4, num_gpus=0)
+
+    def get_batch_config(self) -> BatchProcessConfig:
+        return BatchProcessConfig(
+            max_batch_size=1000, optimal_batch_size=512, min_batch_size=8
+        )
+
+    def predict_batch(self, contents: List[str]) -> List[dict]:
+        if not self.model:
+            raise RuntimeError('Model not initialized')
+        try:
+            # the underlying detector only accepts one document at a time
+            results = []
+            for content in contents:
+                result = update_political_by_str(content)
+                results.append(result)
+
+            return results
+        except Exception as e:
+            raise RuntimeError(f'Prediction failed: {e}')
+
+    def convert_result_to_response(self, result: dict) -> ModelResponse:
+        # TODO: the 0.99 acceptance threshold is provisional and should be
+        # validated against labelled data before it is treated as final.
+        return PoliticalResponse(
+            is_remained=result['political_prob'] > 0.99, details=result
+        )
+
+
+class PoliticalPredictorImpl(BasePredictor):
+    """Predictor implementation for political-content detection."""
+
+    def _create_model(self, language: str) -> ModelResource:
+
+        if language in ['zh', 'en']:
+            return PoliticalCPUModel()
+        raise ModelInitException(
+            f'Political model does not support language: {language}'
+        )
+
+    def predict_batch(
+        self, requests: List[PoliticalRequest]
+    ) -> List[PoliticalResponse]:
+        """Batch prediction interface."""
+
+        try:
+            # collect the content of every request
+            batch_contents = []
+            responses: List[PoliticalResponse] = []
+
+            for req in requests:
+                # make sure the request language matches this predictor
+                if req.language != self.language:
+                    raise ModelInputException(
+                        f'Language mismatch: {req.language} vs {self.language}'
+                    )
+                batch_contents.append(req.content)
+
+            if batch_contents:
+                probs = self.model.predict_batch(batch_contents)
+                responses = [self.model.convert_result_to_response(prob) for prob in probs]
+        except Exception as e:
+            raise ModelRuntimeException(f'Political prediction failed: {e}')
+
+        return responses
+
+
+# Porn-content model implementations
+class PornEnGPUModel(BaseModelResource):
+    """GPU model for English porn detection."""
+
+    def _load_model(self):
+        try:
+            from llm_web_kit.model.porn_detector import \
+                BertModel as PornEnModel
+
+            return PornEnModel()
+        except Exception as e:
+            raise ModelInitException(f'Failed to init the en porn model: {e}')
+
+    def get_resource_requirement(self):
+        # S2 cluster has 96 CPUs, 1TB memory, 8 GPUs
+        # so we can use 12 CPUs, 64GB memory, 1 GPU for this model
+        return ResourceRequirement(num_cpus=12, memory_GB=64, num_gpus=1)
+
+    def get_batch_config(self) -> BatchProcessConfig:
+        return BatchProcessConfig(
+            max_batch_size=1000, optimal_batch_size=512, min_batch_size=8
+        )
+
+    def predict_batch(self, contents: List[str]) -> List[dict]:
+        if not self.model:
+            raise RuntimeError('Model not initialized')
+        try:
+            # the underlying porn model natively supports batched inference
+            results = self.model.predict(contents)
+            return [
+                {'porn_prob': result[self.model.get_output_key('prob')]}
+                for result in results
+            ]
+        except Exception as e:
+            raise RuntimeError(f'Prediction failed: {e}')
+
+    def convert_result_to_response(self, result: dict) -> ModelResponse:
+        # TODO: the 0.2 threshold is provisional and should be validated
+        # against labelled data before it is treated as final.
+        return PornResponse(is_remained=result['porn_prob'] < 0.2, details=result)
+
+
+class PornZhGPUModel(BaseModelResource):
+    """GPU model for Chinese porn detection."""
+
+    def _load_model(self):
+        try:
+            from llm_web_kit.model.porn_detector import \
+                XlmrModel as PornZhModel
+
+            return PornZhModel()
+        except Exception as e:
+            raise ModelInitException(f'Failed to init the zh porn model: {e}')
+
+    def get_resource_requirement(self):
+        # S2 cluster has at least 96 CPUs, 1TB memory, 8 GPUs
+        # so we can use 12 CPUs, 64GB memory, 1 GPU for this model
+        return ResourceRequirement(num_cpus=12, memory_GB=64, num_gpus=1)
+
+    def get_batch_config(self) -> BatchProcessConfig:
+        return BatchProcessConfig(
+            max_batch_size=300, optimal_batch_size=256, min_batch_size=8
+        )
+
+    def predict_batch(self, contents: List[str]) -> List[dict]:
+        if not self.model:
+            raise RuntimeError('Model not initialized')
+        try:
+            # the underlying porn model natively supports batched inference
+            results = self.model.predict(contents)
+            return [
+                {'porn_prob': result[self.model.get_output_key('prob')]}
+                for result in results
+            ]
+        except Exception as e:
+            raise RuntimeError(f'Prediction failed: {e}')
+
+    def convert_result_to_response(self, result: dict) -> ModelResponse:
+        # TODO: the 0.95 threshold is provisional and should be validated
+        # against labelled data before it is treated as final.
+        return PornResponse(is_remained=result['porn_prob'] > 0.95, details=result)
+
+
+class PornPredictorImpl(BasePredictor):
+    """Predictor implementation for porn detection."""
+
+    def _create_model(self, language: str) -> ModelResource:
+        if language == 'en':
+            return PornEnGPUModel()
+        elif language == 'zh':
+            return PornZhGPUModel()
+        raise ModelInitException(f'Porn model does not support language: {language}')
+
+    def predict_batch(self, requests: List[PornRequest]) -> List[PornResponse]:
+        """Batch prediction interface."""
+        try:
+            # collect the content of every request
+            batch_contents = []
+            responses: List[PornResponse] = []
+
+            for req in requests:
+                # make sure the request language matches this predictor
+                if req.language != self.language:
+                    raise ModelInputException(
+                        f'Language mismatch: {req.language} vs {self.language}'
+                    )
+                batch_contents.append(req.content)
+
+            if batch_contents:
+                probs = self.model.predict_batch(batch_contents)
+                responses = [self.model.convert_result_to_response(prob) for prob in probs]
+        except Exception as e:
+            raise ModelRuntimeException(f'Porn prediction failed: {e}')
+        return responses
+
+
+# Model factory
+class ModelFactory:
+    """Factory that maps a ModelType to its predictor implementation."""
+
+    _predictor_registry: Dict[ModelType, Type[BasePredictor]] = {
+        ModelType.POLITICAL: PoliticalPredictorImpl,
+        ModelType.PORN: PornPredictorImpl,
+    }
+
+    @classmethod
+    def create_predictor(cls, model_type: ModelType, language: str) -> BasePredictor:
+        """Create a predictor instance for the given model type and language."""
+        predictor_class = cls._predictor_registry.get(model_type)
+        if not predictor_class:
+            raise ValueError(f'No predictor registered for type: {model_type}')
+        return predictor_class(language=language)
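
The factory above is the intended entry point for callers of the safety models. The sketch below is an editorial illustration, not part of the patch itself; it shows how ModelFactory.create_predictor and predict_batch are expected to compose, assuming the porn-model resources referenced in the repository configuration can actually be downloaded and loaded on the current machine.

```python
from llm_web_kit.model.model_impl import ModelFactory, ModelType
from llm_web_kit.model.model_interface import PornRequest

# Build an English porn-detection predictor; this loads the GPU model.
predictor = ModelFactory.create_predictor(ModelType.PORN, language='en')

requests = [
    PornRequest(content='Some extracted web page text ...', language='en'),
    PornRequest(content='Another extracted document ...', language='en'),
]

# All requests must share the predictor's language, otherwise a
# ModelRuntimeException (wrapping the ModelInputException) is raised.
responses = predictor.predict_batch(requests)
for resp in responses:
    # is_remained=True means the document passed the filter
    # (porn_prob below the hard-coded 0.2 threshold of the English model).
    print(resp.is_remained, resp.details)
```

diff --git a/llm_web_kit/model/model_interface.py b/llm_web_kit/model/model_interface.py
new file mode 100644
index 00000000..3136cb99
--- /dev/null
+++ b/llm_web_kit/model/model_interface.py
@@ -0,0 +1,151 @@
+from abc import 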
ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, List + + +@dataclass +class ModelRequest: + """通用模型请求基类.""" + + content: str + language: str + extra_params: Dict[str, Any] = None + + +@dataclass +class ModelResponse: + """通用模型响应基类.""" + + is_remained: bool + details: Dict[str, Any] = None + + +@dataclass +class PoliticalRequest(ModelRequest): + """涉政检测请求.""" + + pass + + +@dataclass +class PoliticalResponse(ModelResponse): + """涉政检测响应.""" + + pass + + +@dataclass +class PornRequest(ModelRequest): + """色情检测请求.""" + + pass + + +@dataclass +class PornResponse(ModelResponse): + """色情检测响应.""" + + pass + + +@dataclass +class BatchProcessConfig: + """批处理配置.""" + + max_batch_size: int + optimal_batch_size: int + min_batch_size: int + + +class ResourceType(Enum): + """资源类型枚举.""" + + CPU = 'cpu_only' + GPU = 'num_gpus' + DEFAULT = 'default' + + +class ResourceRequirement: + def __init__(self, num_cpus: float, memory_GB: float, num_gpus: float = 0.0): + self.num_cpus = num_cpus + self.memory_GB = memory_GB + self.num_gpus = num_gpus + + def to_ray_resources(self) -> Dict: + if self.num_gpus > 0: + resources = { + 'num_cpus': self.num_cpus, + 'memory': self.memory_GB * 2**30, + 'num_gpus': self.num_gpus, + } + else: + # prefer to use CPU on CPU only node + # we set dummy resource "cpu_only" on CPU only node + # so set resources.cpu_only = 1 to ensure the task can be scheduled on CPU only node + resources = { + 'num_cpus': self.num_cpus, + 'memory': self.memory_GB * 2**30, + 'resources': {'cpu_only': 1}, + } + + return resources + + +class ModelResource(ABC): + """模型资源接口.""" + + @abstractmethod + def initialize(self) -> None: + """初始化模型资源.""" + pass + + @abstractmethod + def get_batch_config(self) -> BatchProcessConfig: + """获取模型的批处理配置.""" + pass + + @abstractmethod + def predict_batch(self, contents: List[str]) -> List[dict]: + """批量预测.""" + pass + + @abstractmethod + def cleanup(self) -> None: + """清理资源.""" + pass + + @abstractmethod + def get_resource_requirement(self) -> ResourceRequirement: + """获取资源需求.""" + pass + + +class ModelPredictor(ABC): + """通用预测器接口.""" + + @abstractmethod + def get_resource_requirement(self, language: str) -> ResourceRequirement: + """获取资源需求.""" + pass + + @abstractmethod + def predict_batch(self, requests: List[ModelRequest]) -> List[ModelResponse]: + """批量预测接口 - 同步版本.""" + pass + + +class PoliticalPredictor(ModelPredictor): + """涉政预测器接口.""" + + def predict_batch( + self, requests: List[PoliticalRequest] + ) -> List[PoliticalResponse]: + pass + + +class PornPredictor(ModelPredictor): + """色情预测器接口.""" + + def predict_batch(self, requests: List[PornRequest]) -> List[PornResponse]: + pass diff --git a/llm_web_kit/model/policical.py b/llm_web_kit/model/policical.py index 7894c147..66e7f89b 100644 --- a/llm_web_kit/model/policical.py +++ b/llm_web_kit/model/policical.py @@ -90,7 +90,7 @@ def decide_political_by_prob( ) -> float: idx = predictions.index('__label__normal') normal_score = probabilities[idx] - return normal_score + return float(normal_score) def decide_political_func( diff --git a/llm_web_kit/model/porn_detector.py b/llm_web_kit/model/porn_detector.py index c2b6d443..8475afb3 100644 --- a/llm_web_kit/model/porn_detector.py +++ b/llm_web_kit/model/porn_detector.py @@ -154,3 +154,88 @@ def predict(self, texts: Union[List[str], str]): outputs.append(output) return outputs + + +class XlmrModel(BertModel): + def __init__(self, model_path: str = None) -> None: + if not model_path: + model_path = 
self.auto_download() + + transformers_module = import_transformer() + + self.model = transformers_module.AutoModelForSequenceClassification.from_pretrained( + os.path.join(model_path, 'porn_classifier/classifier_hf') + ) + with open( + os.path.join(model_path, 'porn_classifier/extra_parameters.json') + ) as reader: + model_config = json.load(reader) + + self.clip = bool(model_config.get('clip', False)) + self.max_tokens = int(model_config.get('max_tokens', 300)) + self.remain_tail = min( + self.max_tokens - 1, int(model_config.get('remain_tail', -1)) + ) + self.device = model_config.get('device', 'cpu') + + self.model.eval() + self.model.to(self.device, dtype=torch.float16) + + self.tokenizer = transformers_module.AutoTokenizer.from_pretrained( + os.path.join(model_path, 'porn_classifier/classifier_hf') + ) + self.tokenizer_config = { + 'padding': True, + 'truncation': self.remain_tail <= 0, + 'max_length': self.max_tokens if self.remain_tail <= 0 else None, + 'return_tensors': 'pt' if self.remain_tail <= 0 else None, + } + + self.output_prefix = str(model_config.get('output_prefix', '')).rstrip('_') + self.output_postfix = str(model_config.get('output_postfix', '')).lstrip('_') + + self.model_name = str(model_config.get('model_name', 'porn-24m5')) + + def auto_download(self) -> str: + """Default download the 23w44.zip model.""" + resource_name = 'porn-24m5' + resource_config = load_config()['resources'] + porn_24m5_config: Dict = resource_config[resource_name] + porn_24m5_s3 = porn_24m5_config['download_path'] + porn_24m5_md5 = porn_24m5_config.get('md5', '') + # get the zip path calculated by the s3 path + zip_path = os.path.join(CACHE_DIR, f'{resource_name}.zip') + # the unzip path is calculated by the zip path + unzip_path = get_unzip_dir(zip_path) + logger.info(f'try to make unzip_path: {unzip_path}') + # if the unzip path does not exist, download the zip file and unzip it + if not os.path.exists(unzip_path): + logger.info(f'unzip_path: {unzip_path} does not exist') + logger.info(f'try to unzip from zip_path: {zip_path}') + if not os.path.exists(zip_path): + logger.info(f'zip_path: {zip_path} does not exist') + logger.info(f'downloading {porn_24m5_s3}') + zip_path = download_auto_file(porn_24m5_s3, zip_path, porn_24m5_md5) + logger.info(f'unzipping {zip_path}') + unzip_path = unzip_local_file(zip_path, unzip_path) + else: + logger.info(f'unzip_path: {unzip_path} exist') + return unzip_path + + def predict(self, texts: Union[List[str], str]): + inputs_dict = self.pre_process(texts) + with torch.no_grad(): + logits = self.model(**inputs_dict['inputs']).logits + + if self.clip: + probs = logits.detach().cpu().numpy().clip(min=0, max=1) + else: + probs = logits.detach().cpu().numpy() + + outputs = [] + for prob in probs: + prob = round(float(prob[0]), 6) + output = {self.get_output_key('prob'): prob} + outputs.append(output) + + return outputs diff --git a/llm_web_kit/model/rule_based_safety_module.py b/llm_web_kit/model/rule_based_safety_module.py new file mode 100644 index 00000000..cb1c695d --- /dev/null +++ b/llm_web_kit/model/rule_based_safety_module.py @@ -0,0 +1,145 @@ +from typing import Any, Type + +from llm_web_kit.model.domain_safety_detector import DomainFilter +from llm_web_kit.model.source_safety_detector import SourceFilter +from llm_web_kit.model.unsafe_words_detector import UnsafeWordsFilter + + +def check_type(arg_name: str, arg_value: Any, arg_type: Type): + """check the type of the argument and raise TypeError if the type is not + matched.""" + if not 
isinstance(arg_value, arg_type): + # TODO change TypeError to custom exception + raise TypeError( + 'The type of {} should be {}, but got {}'.format( + arg_name, arg_type, type(arg_value) + ) + ) + + +class RuleBasedSafetyModuleDataPack: + """The data pack for the rule-based-safety module.""" + + def __init__( + self, + content_str: str, + language: str, + language_details: str, + content_style: str, + url: str, + dataset_name: str, + ): + + # the content of the dataset + check_type('content_str', content_str, str) + self.content_str = content_str + + # the language of the content + check_type('language', language, str) + self.language = language + + # the details of the language + check_type('language_details', language_details, str) + self.language_details = language_details + + # the content style of the content + check_type('content_style', content_style, str) + self.content_style = content_style + + # the url of the content + check_type('url', url, str) + self.url = url + + # the data source of the content + check_type('dataset_name', dataset_name, str) + self.dataset_name = dataset_name + + # the flag of the processed data should be remained or not + self.safety_remained = True + # the details of the clean process + self.safety_infos = {} + + def set_process_result(self, safety_remained: bool, safety_infos: dict) -> None: + """set the process result of the rule_based_safety module.""" + check_type('safety_remained', safety_remained, bool) + check_type('safety_infos', safety_infos, dict) + if safety_remained is False: + self.safety_remained = False + self.safety_infos.update(safety_infos) + + def get_output(self) -> dict: + """get the output of the data pack.""" + return { + 'safety_remained': self.safety_remained, + 'safety_infos': self.safety_infos, + } + + +class RuleBasedSafetyModule: + def __init__(self, prod: bool): + # when in production mode + # the process will return immediately when the data is not safe + self.prod = prod + self.domain_filter = DomainFilter() + self.source_filter = SourceFilter() + self.unsafe_words_filter = UnsafeWordsFilter() + + def process( + self, + content_str: str, + language: str, + language_details: str, + content_style: str, + url: str, + dataset_name: str, + ) -> dict: + """The process of the rule based safety.""" + data_pack = RuleBasedSafetyModuleDataPack( + content_str=content_str, + language=language, + language_details=language_details, + content_style=content_style, + url=url, + dataset_name=dataset_name, + ) + data_pack = self.process_core(data_pack) + return data_pack.get_output() + + def process_core( + self, data_pack: RuleBasedSafetyModuleDataPack + ) -> RuleBasedSafetyModuleDataPack: + """The core process of the rule based safety.""" + content_str = data_pack.content_str + language = data_pack.language + language_details = data_pack.language_details + content_style = data_pack.content_style + url = data_pack.url + data_source = data_pack.dataset_name + + domain_safe_remained, domain_safe_info = self.domain_filter.filter( + content_str, language, url, language_details, content_style + ) + data_pack.set_process_result(domain_safe_remained, domain_safe_info) + if not domain_safe_remained and self.prod: + return data_pack + + source_type_dict = self.source_filter.filter( + content_str, language, data_source, content_style + ) + + from_safe_source = source_type_dict['from_safe_source'] + from_domestic_source = source_type_dict['from_domestic_source'] + unsafe_words_remained, process_info = self.unsafe_words_filter.filter( + content_str, + 
language, + language_details, + content_style, + from_safe_source, + from_domestic_source, + ) + data_pack.set_process_result(unsafe_words_remained, process_info) + return data_pack + + def get_version(self): + version_str = '1.0.0' + return version_str diff --git a/llm_web_kit/model/source_safety_detector.py b/llm_web_kit/model/source_safety_detector.py new file mode 100644 index 00000000..7be51f09 --- /dev/null +++ b/llm_web_kit/model/source_safety_detector.py @@ -0,0 +1,13 @@ +class SourceFilter: + def __init__(self): + pass + + def filter( + self, + content_str: str, + language: str, + data_source: str, + language_details: str, + content_style: str, + ) -> dict: + return {'from_safe_source': False, 'from_domestic_source': False} diff --git a/llm_web_kit/model/unsafe_words_detector.py b/llm_web_kit/model/unsafe_words_detector.py index 8c05059c..28556fc5 100644 --- a/llm_web_kit/model/unsafe_words_detector.py +++ b/llm_web_kit/model/unsafe_words_detector.py @@ -1,12 +1,11 @@ import os import time -from typing import Any, Dict +from typing import Any, Dict, Tuple import ahocorasick from llm_web_kit.config.cfg_reader import load_config from llm_web_kit.exception.exception import SafeModelException -from llm_web_kit.input.datajson import DataJson from llm_web_kit.libs.standard_utils import json_loads from llm_web_kit.model.basic_functions.format_check import (is_en_letter, is_pure_en_word) @@ -178,11 +177,9 @@ def get_unsafe_words_checker(language='zh-en') -> UnsafeWordChecker: return singleton_resource_manager.get_resource(language) -def decide_unsafe_word_by_data_checker( - data_dict: dict, unsafeWordChecker: UnsafeWordChecker +def decide_content_unsafe_word_by_data_checker( + content_str: str, unsafeWordChecker: UnsafeWordChecker ) -> str: - data_obj = DataJson(data_dict) - content_str = data_obj.get_content_list().to_txt() unsafe_words_list = unsafeWordChecker.check_unsafe_words(content_str=content_str) unsafe_word_levels = [] for w in unsafe_words_list: @@ -196,64 +193,47 @@ def decide_unsafe_word_by_data_checker( return unsafe_word_min_level -def unsafe_words_filter( - data_dict: Dict[str, Any], language: str, content_style: str -) -> str: - if language in xyz_language_lst: - language = 'xyz' - elif language in [ - 'zh', - 'en', - 'yue', - 'zho', - 'eng', - 'zho_Hans', - 'zho_Hant', - 'yue_Hant', - 'eng_Latn', - ]: - language = 'zh-en' - else: - raise SafeModelException(f'Unsupported language: {language}') - - unsafeWordChecker = get_unsafe_words_checker(language) - unsafe_word_min_level = decide_unsafe_word_by_data_checker( - data_dict, unsafeWordChecker - ) - - return unsafe_word_min_level - - -def unsafe_words_filter_overall( - data_dict: Dict[str, Any], - language: str, - content_style: str, - from_safe_source, - from_domestic_source, -): - unsafe_word_min_level = unsafe_words_filter(data_dict, language, content_style) - - if language in xyz_language_lst: - language = 'xyz' - elif language in [ - 'zh', - 'en', - 'yue', - 'zho', - 'eng', - 'zho_Hans', - 'zho_Hant', - 'yue_Hant', - 'eng_Latn', - ]: - language = 'zh-en' - else: - raise SafeModelException(f'Unsupported language: {language}') - if from_safe_source: - return {'hit_unsafe_words': False} - if from_domestic_source: - unsafe_range = ('L1',) - else: - unsafe_range = ('L1', 'L2') - hit = unsafe_word_min_level in unsafe_range - return {'hit_unsafe_words': hit} +class UnsafeWordsFilter: + def __init__(self,raise_not_support_language_exception: bool = False): + self.raise_not_support_language_exception = 
raise_not_support_language_exception + + def filter( + self, + content_str: str, + language: str, + language_details: str, + content_style: str, + from_safe_source: bool, + from_domestic_source: bool, + ) -> Tuple[bool, Dict[str, Any]]: + if language in xyz_language_lst: + language = 'xyz' + elif language in [ + 'zh', + 'en', + 'yue', + 'zho', + 'eng', + 'zho_Hans', + 'zho_Hant', + 'yue_Hant', + 'eng_Latn', + ]: + language = 'zh-en' + else: + if self.raise_not_support_language_exception: + raise SafeModelException(f'Unsupported language: {language}') + else: + return True, {'hit_unsafe_words': False} + + if from_safe_source: + return True, {'hit_unsafe_words': False} + if from_domestic_source: + unsafe_range = ('L1',) + else: + unsafe_range = ('L1', 'L2') + unsafe_word_min_level = decide_content_unsafe_word_by_data_checker( + content_str, get_unsafe_words_checker(language) + ) + hit = unsafe_word_min_level in unsafe_range + return not hit, {'hit_unsafe_words': hit} diff --git a/tests/llm_web_kit/model/test_domain_safety_detector.py b/tests/llm_web_kit/model/test_domain_safety_detector.py new file mode 100644 index 00000000..d716dddb --- /dev/null +++ b/tests/llm_web_kit/model/test_domain_safety_detector.py @@ -0,0 +1,24 @@ +import unittest + +from llm_web_kit.model.domain_safety_detector import DomainFilter + + +class TestDomainFilter(unittest.TestCase): + def setUp(self): + self.filter = DomainFilter() + + # 测试基础过滤逻辑 + def test_filter_basic_case(self): + result = self.filter.filter( + content_str='Valid content', + language='en', + url='https://example.com', + language_details='formal', + content_style='professional' + ) + self.assertTrue(result[0]) # 预期允许通过 + self.assertEqual(result[1], {}) # 预期无附加信息 + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/llm_web_kit/model/test_model_impl.py b/tests/llm_web_kit/model/test_model_impl.py new file mode 100644 index 00000000..3e656068 --- /dev/null +++ b/tests/llm_web_kit/model/test_model_impl.py @@ -0,0 +1,312 @@ +"""Test cases for model_impl.py.""" + +import unittest +from unittest import TestCase +from unittest.mock import MagicMock, patch + +from llm_web_kit.exception.exception import ModelRuntimeException +from llm_web_kit.model.model_impl import (ModelFactory, ModelType, + PoliticalCPUModel, + PoliticalPredictorImpl, + PornEnGPUModel, PornPredictorImpl, + PornZhGPUModel) +from llm_web_kit.model.model_interface import PornRequest, PornResponse + + +class TestPoliticalCPUModel(TestCase): + """Test cases for PoliticalCPUModel.""" + + @patch.object(PoliticalCPUModel, '_load_model') + def test_load_model(self, mock_load_model): + """Test model loading.""" + mock_load_model.return_value = MagicMock() + model = PoliticalCPUModel() + model._load_model() + assert mock_load_model.call_count == 1 + + @patch.object(PoliticalCPUModel, '_load_model') + def test_get_resource_requirement(self, mock_load_model): + """Test resource requirements.""" + mock_load_model.return_value = MagicMock() + model = PoliticalCPUModel() + resource_requirement = model.get_resource_requirement() + assert resource_requirement.num_cpus == 1 + assert resource_requirement.memory_GB == 4 + assert resource_requirement.num_gpus == 0 + + @patch.object(PoliticalCPUModel, '_load_model') + def test_get_batch_config(self, mock_load_model): + """Test batch configuration.""" + mock_load_model.return_value = MagicMock() + model = PoliticalCPUModel() + batch_config = model.get_batch_config() + assert batch_config.max_batch_size == 1000 + assert 
batch_config.optimal_batch_size == 512 + assert batch_config.min_batch_size == 8 + + @patch.object(PoliticalCPUModel, '_load_model') + @patch('llm_web_kit.model.model_impl.update_political_by_str') + def test_predict_batch(self, mock_update_political_by_str, mock_load_model): + """Test batch prediction.""" + mock_model = MagicMock() + mock_load_model.return_value = mock_model + mock_update_political_by_str.return_value = {'political_prob': 0.96} + + model = PoliticalCPUModel() + model.model = mock_model + + results = model.predict_batch(['test1', 'test2']) + assert len(results) == 2 + assert results[0]['political_prob'] == 0.96 + assert results[1]['political_prob'] == 0.96 + assert mock_update_political_by_str.call_count == 2 + + @patch.object(PoliticalCPUModel, '_load_model') + def test_convert_result_to_response(self, mock_load_model): + """Test result conversion to response.""" + mock_load_model.return_value = MagicMock() + model = PoliticalCPUModel() + + # Test case where political_prob > 0.99 (should be flagged) + result = {'political_prob': 0.995} + response = model.convert_result_to_response(result) + assert response.is_remained + assert response.details == result + + # Test case where political_prob <= 0.99 (should not be flagged) + result = {'political_prob': 0.985} + response = model.convert_result_to_response(result) + assert not response.is_remained + assert response.details == result + + +class TestPornEnGPUModel(TestCase): + """Test cases for PornEnGPUModel.""" + + from llm_web_kit.model.porn_detector import BertModel as PornEnModel + + @patch.object(PornEnModel, '__init__') + def test_load_model(self, mock_init): + """Test model loading.""" + mock_init.return_value = None + model = PornEnGPUModel() + model._load_model() + assert mock_init.call_count == 1 + + @patch.object(PornEnGPUModel, '_load_model') + def test_get_resource_requirement(self, mock_load_model): + """Test resource requirements.""" + mock_load_model.return_value = MagicMock() + model = PornEnGPUModel() + resource_requirement = model.get_resource_requirement() + assert resource_requirement.num_cpus == 12 + assert resource_requirement.memory_GB == 64 + assert resource_requirement.num_gpus == 1 + + @patch.object(PornEnGPUModel, '_load_model') + def test_get_batch_config(self, mock_load_model): + """Test batch configuration.""" + mock_load_model.return_value = MagicMock() + model = PornEnGPUModel() + batch_config = model.get_batch_config() + assert batch_config.max_batch_size == 1000 + assert batch_config.optimal_batch_size == 512 + assert batch_config.min_batch_size == 8 + + @patch.object(PornEnGPUModel, '_load_model') + def test_predict_batch(self, mock_load_model): + """Test batch prediction.""" + mock_model = MagicMock() + mock_model.get_output_key.return_value = 'prob' + mock_model.predict.return_value = [{'prob': 0.96}, {'prob': 0.94}] + mock_load_model.return_value = mock_model + + model = PornEnGPUModel() + model.model = mock_model + + results = model.predict_batch(['test1', 'test2']) + assert len(results) == 2 + assert results[0]['porn_prob'] == 0.96 + assert results[1]['porn_prob'] == 0.94 + + # Test model not initialized + model.model = None + with self.assertRaises(RuntimeError): + model.predict_batch(['test']) + + @patch.object(PornEnGPUModel, '_load_model') + def test_convert_result_to_response(self, mock_load_model): + """Test result conversion to response.""" + mock_load_model.return_value = MagicMock() + model = PornEnGPUModel() + + # Test with high probability (should be remained) + response = 
model.convert_result_to_response({'porn_prob': 0.21}) + assert isinstance(response, PornResponse) + assert not response.is_remained + assert response.details == {'porn_prob': 0.21} + + # Test with low probability (should not be remained) + response = model.convert_result_to_response({'porn_prob': 0.19}) + assert isinstance(response, PornResponse) + assert response.is_remained + assert response.details == {'porn_prob': 0.19} + + +class TestPornZhGPUModel(TestCase): + """Test cases for PornZhGPUModel.""" + + from llm_web_kit.model.porn_detector import XlmrModel as PornZhModel + + @patch.object(PornZhModel, '__init__') + def test_load_model(self, mock_init): + """Test model loading.""" + mock_init.return_value = None + model = PornZhGPUModel() + model._load_model() + assert mock_init.call_count == 1 + + @patch.object(PornZhGPUModel, '_load_model') + def test_get_resource_requirement(self, mock_load_model): + """Test resource requirements.""" + mock_load_model.return_value = MagicMock() + model = PornZhGPUModel() + resource_requirement = model.get_resource_requirement() + assert resource_requirement.num_cpus == 12 + assert resource_requirement.memory_GB == 64 + assert resource_requirement.num_gpus == 1 + + @patch.object(PornZhGPUModel, '_load_model') + def test_get_batch_config(self, mock_load_model): + """Test batch configuration.""" + mock_load_model.return_value = MagicMock() + model = PornZhGPUModel() + batch_config = model.get_batch_config() + assert batch_config.max_batch_size == 300 + assert batch_config.optimal_batch_size == 256 + assert batch_config.min_batch_size == 8 + + @patch.object(PornZhGPUModel, '_load_model') + def test_predict_batch(self, mock_load_model): + """Test batch prediction.""" + mock_model = MagicMock() + mock_model.get_output_key.return_value = 'prob' + mock_model.predict.return_value = [{'prob': 0.96}, {'prob': 0.94}] + mock_load_model.return_value = mock_model + + model = PornZhGPUModel() + model.model = mock_model + + results = model.predict_batch(['test1', 'test2']) + assert len(results) == 2 + assert results[0]['porn_prob'] == 0.96 + assert results[1]['porn_prob'] == 0.94 + + # Test model not initialized + model.model = None + with self.assertRaises(RuntimeError): + model.predict_batch(['test']) + + @patch.object(PornZhGPUModel, '_load_model') + def test_convert_result_to_response(self, mock_load_model): + """Test result conversion to response.""" + mock_load_model.return_value = MagicMock() + model = PornZhGPUModel() + + # Test with high probability (should be remained) + response = model.convert_result_to_response({'porn_prob': 0.96}) + assert isinstance(response, PornResponse) + assert response.is_remained + assert response.details == {'porn_prob': 0.96} + + # Test with low probability (should not be remained) + response = model.convert_result_to_response({'porn_prob': 0.94}) + assert isinstance(response, PornResponse) + assert not response.is_remained + assert response.details == {'porn_prob': 0.94} + + +class TestPornPredictorImpl(TestCase): + """Test cases for PornPredictorImpl.""" + + @patch.object(PornEnGPUModel, '_load_model') + @patch.object(PornZhGPUModel, '_load_model') + @patch.object(PornZhGPUModel, 'predict_batch') + @patch.object(PornEnGPUModel, 'predict_batch') + def test_predict_batch(self, mock_predict_batch_en, mock_predict_batch_zh, + mock_load_model_en, mock_load_model_zh): + """Test batch prediction.""" + mock_load_model_en.return_value = MagicMock() + mock_load_model_zh.return_value = MagicMock() + mock_predict_batch_en.return_value = 
[{'porn_prob': 0.19}, {'porn_prob': 0.3}] + mock_predict_batch_zh.return_value = [{'porn_prob': 0.19}, {'porn_prob': 0.3}] + + predictor = PornPredictorImpl('en') + assert predictor.language == 'en' + with self.assertRaises(ModelRuntimeException): + predictor.predict_batch([ + PornRequest(content='Hello, world!', language='en'), + PornRequest(content='你好', language='zh') + ]) + assert mock_predict_batch_en.call_count == 1 + + results = predictor.predict_batch([ + PornRequest(content='Hello, world!', language='en'), + PornRequest(content='nihao', language='en') + ]) + assert results[0].is_remained + assert not results[1].is_remained + + @patch.object(PornEnGPUModel, '_load_model') + @patch.object(PornZhGPUModel, '_load_model') + def test_create_model(self, mock_load_model_en, mock_load_model_zh): + """Test model creation.""" + mock_load_model_en.return_value = MagicMock() + mock_load_model_zh.return_value = MagicMock() + predictor = PornPredictorImpl('en') + assert predictor.language == 'en' + assert predictor.model is not None + + +def test_model_factory(): + """Test ModelFactory creation.""" + factory = ModelFactory() + assert factory is not None + + +class TestModelFactory(TestCase): + """Test cases for ModelFactory.""" + + @patch.object(PoliticalPredictorImpl, '_create_model') + @patch.object(PoliticalCPUModel, '_load_model') + def test_create_predictor(self, mock_load_model, mock_create_model): + """Test ModelFactory.create_predictor method.""" + mock_load_model.return_value = MagicMock() + mock_create_model.return_value = MagicMock() + predictor = ModelFactory.create_predictor(ModelType.POLITICAL, 'en') + assert isinstance(predictor, PoliticalPredictorImpl) + assert mock_create_model.call_count == 1 + + @patch.object(PornPredictorImpl, '_create_model') + @patch.object(PornEnGPUModel, '_load_model') + def test_create_predictor_porn(self, mock_load_model, mock_create_model): + """Test ModelFactory.create_predictor method for porn model.""" + mock_load_model.return_value = MagicMock() + mock_create_model.return_value = MagicMock() + predictor = ModelFactory.create_predictor(ModelType.PORN, 'en') + assert isinstance(predictor, PornPredictorImpl) + assert mock_create_model.call_count == 1 + + @patch.object(PornPredictorImpl, '_create_model') + @patch.object(PornZhGPUModel, '_load_model') + def test_create_predictor_porn_zh(self, mock_load_model, mock_create_model): + """Test ModelFactory.create_predictor method for porn model.""" + mock_load_model.return_value = MagicMock() + mock_create_model.return_value = MagicMock() + predictor = ModelFactory.create_predictor(ModelType.PORN, 'zh') + assert isinstance(predictor, PornPredictorImpl) + assert mock_create_model.call_count == 1 + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/llm_web_kit/model/test_model_interface.py b/tests/llm_web_kit/model/test_model_interface.py new file mode 100644 index 00000000..d16c2042 --- /dev/null +++ b/tests/llm_web_kit/model/test_model_interface.py @@ -0,0 +1,122 @@ +"""Test cases for model_interface.py.""" + +import pytest + +from llm_web_kit.model.model_interface import (BatchProcessConfig, + ModelPredictor, ModelRequest, + ModelResource, ModelResponse, + PoliticalPredictor, + PoliticalRequest, + PoliticalResponse, + PornPredictor, PornRequest, + PornResponse, + ResourceRequirement, + ResourceType) + + +def test_model_request() -> None: + """Test ModelRequest initialization and attributes.""" + request = ModelRequest(content='Hello, world!', language='en') + assert request.content == 
'Hello, world!' + assert request.language == 'en' + + +def test_model_response() -> None: + """Test ModelResponse initialization and attributes.""" + response = ModelResponse(is_remained=True, details={'score': 0.95}) + assert response.is_remained + assert response.details['score'] == 0.95 + + +def test_political_request() -> None: + """Test PoliticalRequest initialization and attributes.""" + request = PoliticalRequest(content='Hello, world!', language='en') + assert request.content == 'Hello, world!' + assert request.language == 'en' + + +def test_political_response() -> None: + """Test PoliticalResponse initialization and attributes.""" + response = PoliticalResponse(is_remained=True, details={'score': 0.95}) + assert response.is_remained + assert response.details['score'] == 0.95 + + +def test_porn_request() -> None: + """Test PornRequest initialization and attributes.""" + request = PornRequest(content='Hello, world!', language='en') + assert request.content == 'Hello, world!' + assert request.language == 'en' + + +def test_porn_response() -> None: + """Test PornResponse initialization and attributes.""" + response = PornResponse(is_remained=True, details={'score': 0.95}) + assert response.is_remained + assert response.details['score'] == 0.95 + + +def test_batch_process_config() -> None: + """Test BatchProcessConfig initialization and attributes.""" + config = BatchProcessConfig( + max_batch_size=100, + optimal_batch_size=50, + min_batch_size=10 + ) + assert config.max_batch_size == 100 + assert config.optimal_batch_size == 50 + assert config.min_batch_size == 10 + + +def test_resource_type() -> None: + """Test ResourceType enum values.""" + assert ResourceType.CPU + assert ResourceType.GPU + assert ResourceType.DEFAULT + + +def test_resource_requirement() -> None: + """Test ResourceRequirement initialization and conversion to ray + resources.""" + requirement = ResourceRequirement(num_cpus=1, memory_GB=4, num_gpus=0.25) + assert requirement.num_cpus == 1 + assert requirement.memory_GB == 4 + assert requirement.num_gpus == 0.25 + assert requirement.to_ray_resources() == { + 'num_cpus': 1, + 'memory': 4 * 2**30, + 'num_gpus': 0.25 + } + + cpu_only_requirement = ResourceRequirement(num_cpus=1, memory_GB=4) + assert cpu_only_requirement.to_ray_resources() == { + 'num_cpus': 1, + 'memory': 4 * 2**30, + 'resources': {'cpu_only': 1} + } + + +def test_model_resource() -> None: + """Test that ModelResource cannot be instantiated directly.""" + with pytest.raises(TypeError): + ModelResource() + + +def test_model_predictor() -> None: + """Test that ModelPredictor cannot be instantiated directly.""" + with pytest.raises(TypeError): + ModelPredictor() + + +def test_political_predictor() -> None: + """Test PoliticalPredictor implementation.""" + with pytest.raises(TypeError): + predictor = PoliticalPredictor() + assert isinstance(predictor, ModelPredictor) + + +def test_porn_predictor() -> None: + """Test PornPredictor implementation.""" + with pytest.raises(TypeError): + predictor = PornPredictor() + assert isinstance(predictor, ModelPredictor) diff --git a/tests/llm_web_kit/model/test_rule_based_safety_module.py b/tests/llm_web_kit/model/test_rule_based_safety_module.py new file mode 100644 index 00000000..983018cf --- /dev/null +++ b/tests/llm_web_kit/model/test_rule_based_safety_module.py @@ -0,0 +1,141 @@ +import unittest +from unittest import TestCase +from unittest.mock import patch + +from llm_web_kit.model.domain_safety_detector import DomainFilter +# 需要根据实际模块路径调整 +from 
llm_web_kit.model.rule_based_safety_module import (
+    RuleBasedSafetyModule, RuleBasedSafetyModuleDataPack, check_type)
+from llm_web_kit.model.source_safety_detector import SourceFilter
+from llm_web_kit.model.unsafe_words_detector import UnsafeWordsFilter
+
+
+class TestCheckType(TestCase):
+    def test_type_checking(self):
+        """Test the type-checking helper."""
+        # a matching type should pass without raising
+        try:
+            check_type('test', 'string', str)
+        except TypeError:
+            self.fail('Valid type check failed')
+
+        # a mismatched type should raise TypeError
+        with self.assertRaises(TypeError) as cm:
+            check_type('test', 123, str)
+
+        expected_error = (
+            "The type of test should be <class 'str'>, but got <class 'int'>"
+        )
+        self.assertEqual(str(cm.exception), expected_error)
+
+
+class TestRuleBasedSafetyModuleDataPack(TestCase):
+    def test_init_type_checks(self):
+        """Test the type checks performed during initialization."""
+        valid_args = {
+            'content_str': 'test',
+            'language': 'en',
+            'language_details': 'details',
+            'content_style': 'article',
+            'url': 'http://test.com',
+            'dataset_name': 'test_dataset',
+        }
+
+        # all arguments with the correct types
+        try:
+            RuleBasedSafetyModuleDataPack(**valid_args)
+        except TypeError:
+            self.fail('Type check failed for valid types')
+
+        # a wrong type for each argument in turn
+        for param in valid_args:
+            invalid_args = valid_args.copy()
+            invalid_args[param] = 123  # deliberately wrong type
+            with self.assertRaises(TypeError):
+                RuleBasedSafetyModuleDataPack(**invalid_args)
+
+    def test_set_process_result(self):
+        """Test setting the process result."""
+        data_pack = RuleBasedSafetyModuleDataPack('test', 'en', 'details', 'article', 'http://test.com', 'test_dataset')
+
+        # correct argument types
+        data_pack.set_process_result(False, {'key': 'value'})
+        self.assertFalse(data_pack.safety_remained)
+        self.assertEqual(data_pack.safety_infos, {'key': 'value'})
+
+        # wrong argument types
+        with self.assertRaises(TypeError):
+            data_pack.set_process_result('not_bool', {'key': 'value'})
+
+        with self.assertRaises(TypeError):
+            data_pack.set_process_result(False, 'not_dict')
+
+    def test_get_output(self):
+        """Test generation of the output dict."""
+        data_pack = RuleBasedSafetyModuleDataPack('test', 'en', 'details', 'article', 'http://test.com', 'test_dataset')
+        data_pack.set_process_result(False, {'info': 'test'})
+
+        expected_output = {'safety_remained': False, 'safety_infos': {'info': 'test'}}
+        self.assertDictEqual(data_pack.get_output(), expected_output)
+
+
+class TestRuleBasedSafetyModule(TestCase):
+    @patch.object(DomainFilter, 'filter')
+    @patch.object(SourceFilter, 'filter')
+    @patch.object(UnsafeWordsFilter, 'filter')
+    def test_process_core(self, mock_unsafe_words_filter, mock_source_filter, mock_domain_filter):
+        """Test the core processing flow."""
+        # set up the mocked filter results
+        mock_source_filter.return_value = {'from_safe_source': False, 'from_domestic_source': False}
+        mock_domain_filter.return_value = (True, {})
+        mock_unsafe_words_filter.return_value = (False, {'reason': 'test'})
+
+        # build the module and data pack under test
+        safety_module = RuleBasedSafetyModule(prod=True)
+        data_pack = RuleBasedSafetyModuleDataPack('test', 'en', 'details', 'article', 'http://test.com', 'test_dataset')
+
+        # run the core processing
+        result = safety_module.process_core(data_pack)
+
+        # verify the unsafe-words filter was called with the expected arguments
+        mock_unsafe_words_filter.assert_called_once_with('test', 'en', 'details', 'article', False, False)
+        # verify the process result was recorded correctly
+        self.assertFalse(result.safety_remained)
+        self.assertEqual(result.safety_infos, {'reason': 'test'})
+
+    @patch.object(DomainFilter, 'filter')
+    @patch.object(SourceFilter, 'filter')
+    @patch.object(UnsafeWordsFilter, 'filter')
+    def test_process_flow(self, mock_unsafe_words_filter, mock_source_filter, mock_domain_filter):
+        """Test the full processing flow."""
+        mock_source_filter.return_value = {'from_safe_source': False, 
'from_domestic_source': False} + mock_domain_filter.return_value = (True, {}) + mock_unsafe_words_filter.return_value = (True, {}) + + safety_module = RuleBasedSafetyModule(prod=False) + result = safety_module.process( + content_str='content', + language='en', + language_details='details', + content_style='article', + url='http://test.com', + dataset_name='test_dataset', + ) + + expected_result = {'safety_remained': True, 'safety_infos': {}} + self.assertDictEqual(result, expected_result) + + def test_production_mode_effect(self): + """测试生产模式的影响.""" + # 根据实际业务逻辑补充测试 + # 当前代码中prod参数未实际使用,需要根据具体实现调整 + pass + + def test_get_version(self): + """测试版本获取.""" + safety_module = RuleBasedSafetyModule(prod=True) + self.assertEqual(safety_module.get_version(), '1.0.0') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/llm_web_kit/model/test_source_safety_detector.py b/tests/llm_web_kit/model/test_source_safety_detector.py new file mode 100644 index 00000000..355cd277 --- /dev/null +++ b/tests/llm_web_kit/model/test_source_safety_detector.py @@ -0,0 +1,24 @@ +import unittest + +from llm_web_kit.model.source_safety_detector import SourceFilter + + +class TestSourceFilter(unittest.TestCase): + def setUp(self): + self.filter = SourceFilter() + + # 测试非法来源 + def test_unsafe_source(self): + result = self.filter.filter( + content_str='Unverified data', + language='en', + data_source='http://unknown-site.com', + language_details='informal', + content_style='user-generated' + ) + self.assertFalse(result['from_safe_source']) + self.assertFalse(result['from_domestic_source']) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/llm_web_kit/model/test_unsafe_words_detector.py b/tests/llm_web_kit/model/test_unsafe_words_detector.py index a26c608a..f7c63f7f 100644 --- a/tests/llm_web_kit/model/test_unsafe_words_detector.py +++ b/tests/llm_web_kit/model/test_unsafe_words_detector.py @@ -2,10 +2,10 @@ from unittest.mock import MagicMock, Mock, mock_open, patch from llm_web_kit.exception.exception import SafeModelException -from llm_web_kit.model.unsafe_words_detector import ( - UnsafeWordChecker, auto_download, decide_unsafe_word_by_data_checker, - get_ac, get_unsafe_words, get_unsafe_words_checker, unsafe_words_filter, - unsafe_words_filter_overall) +from llm_web_kit.model.unsafe_words_detector import (UnsafeWordChecker, + auto_download, get_ac, + get_unsafe_words, + get_unsafe_words_checker) class TestUnsafeWordChecker(unittest.TestCase): @@ -37,18 +37,6 @@ def test_check_unsafe_words(self, mock_get_ac): self.assertIsInstance(result, list) self.assertEqual(len(result), 0) - @patch('llm_web_kit.model.unsafe_words_detector.get_unsafe_words_checker') - def test_decide_unsafe_word_by_data_checker(self, mock_get_checker): - mock_checker = MagicMock() - mock_checker.check_unsafe_words.return_value = [ - {'word': 'unsafe', 'level': 'L2', 'count': 1} - ] - mock_get_checker.return_value = mock_checker - - data_dict = {'content': 'Some content with unsafe elements.'} - result = decide_unsafe_word_by_data_checker(data_dict, mock_checker) - self.assertEqual(result, 'L2') - def test_standalone_word_detection(self): """测试独立存在的子词能被正确识别[2,6](@ref)""" ac = Mock() @@ -82,92 +70,6 @@ def test_get_unsafe_words_checker(self, mock_get_ac): checker2 = get_unsafe_words_checker('zh-en') self.assertIs(checker1, checker2) # Should return the same instance - @patch('llm_web_kit.model.unsafe_words_detector.get_unsafe_words_checker') - def test_unsafe_words_filter(self, mock_get_checker): - mock_checker = 
MagicMock() - mock_checker.check_unsafe_words.return_value = [ - {'word': '', 'level': 'L3', 'count': 1} - ] - mock_get_checker.return_value = mock_checker - - data_dict = {'content': 'Test content'} - result = unsafe_words_filter(data_dict, 'en', 'text') - self.assertEqual(result, 'L3') - result = unsafe_words_filter(data_dict, 'ko', 'text') - self.assertEqual(result, 'L3') - with self.assertRaises(SafeModelException): - unsafe_words_filter(data_dict, 'unk', 'text') - - def test_unsafe_words_filter_with_unsupported_language(self): - data_dict = {'content': 'Test content'} - with self.assertRaises(SafeModelException): - unsafe_words_filter(data_dict, 'unsupported_language', 'text') - - @patch('llm_web_kit.model.unsafe_words_detector.unsafe_words_filter') - def test_unsafe_words_filter_overall(self, mock_filter): - mock_filter.return_value = 'L1' - - data_dict = {'content': 'Content with unsafe words.'} - - result = unsafe_words_filter_overall( - data_dict, - language='en', - content_style='text', - from_safe_source=False, - from_domestic_source=False, - ) - self.assertIsInstance(result, dict) - self.assertTrue(result['hit_unsafe_words']) - - result = unsafe_words_filter_overall( - data_dict, - language='en', - content_style='text', - from_safe_source=False, - from_domestic_source=True, - ) - self.assertIsInstance(result, dict) - self.assertTrue(result['hit_unsafe_words']) - - result = unsafe_words_filter_overall( - data_dict, - language='en', - content_style='text', - from_safe_source=True, - from_domestic_source=True, - ) - self.assertIsInstance(result, dict) - self.assertFalse(result['hit_unsafe_words']) - - result = unsafe_words_filter_overall( - data_dict, - language='en', - content_style='text', - from_safe_source=True, - from_domestic_source=False, - ) - self.assertIsInstance(result, dict) - self.assertFalse(result['hit_unsafe_words']) - - result = unsafe_words_filter_overall( - data_dict, - language='ru', - content_style='text', - from_safe_source=False, - from_domestic_source=False, - ) - self.assertIsInstance(result, dict) - self.assertTrue(result['hit_unsafe_words']) - - with self.assertRaises(SafeModelException): - result = unsafe_words_filter_overall( - data_dict, - language='unknown', - content_style='text', - from_safe_source=False, - from_domestic_source=False, - ) - @patch('llm_web_kit.model.unsafe_words_detector.load_config') @patch('llm_web_kit.model.unsafe_words_detector.download_auto_file') def test_auto_download(self, mock_download_auto_file, mock_load_config):