diff --git a/llm_web_kit/model/code_detector.py b/llm_web_kit/model/code_detector.py index dac7dd4c..c419d1c4 100644 --- a/llm_web_kit/model/code_detector.py +++ b/llm_web_kit/model/code_detector.py @@ -7,12 +7,10 @@ from llm_web_kit.config.cfg_reader import load_config from llm_web_kit.libs.logger import mylogger as logger -from llm_web_kit.model.resource_utils.download_assets import ( - CACHE_DIR, download_auto_file) -from llm_web_kit.model.resource_utils.singleton_resource_manager import \ - singleton_resource_manager -from llm_web_kit.model.resource_utils.unzip_ext import (get_unzip_dir, - unzip_local_file) +from llm_web_kit.model.resource_utils import (CACHE_DIR, download_auto_file, + get_unzip_dir, + singleton_resource_manager, + unzip_local_file) class CodeClassification: @@ -139,13 +137,17 @@ def decide_code_func(content_str: str, code_detect: CodeClassification) -> float if str_len > 10000: logger.warning('Content string is too long, truncate to 10000 characters') start_idx = (str_len - 10000) // 2 - content_str = content_str[start_idx:start_idx + 10000] + content_str = content_str[start_idx : start_idx + 10000] # check if the content string contains latex environment if detect_latex_env(content_str): - logger.warning('Content string contains latex environment, may be misclassified') + logger.warning( + 'Content string contains latex environment, may be misclassified' + ) - def decide_code_by_prob_v3(predictions: Tuple[str], probabilities: Tuple[float]) -> float: + def decide_code_by_prob_v3( + predictions: Tuple[str], probabilities: Tuple[float] + ) -> float: idx = predictions.index('__label__1') true_prob = probabilities[idx] return true_prob @@ -154,7 +156,9 @@ def decide_code_by_prob_v3(predictions: Tuple[str], probabilities: Tuple[float]) predictions, probabilities = code_detect.predict(content_str) result = decide_code_by_prob_v3(predictions, probabilities) else: - raise ValueError(f'Unsupported version: {code_detect.version}. Supported versions: {[CODE_CL_SUPPORTED_VERSIONS]}') + raise ValueError( + f'Unsupported version: {code_detect.version}. Supported versions: {[CODE_CL_SUPPORTED_VERSIONS]}' + ) return result diff --git a/llm_web_kit/model/html_layout_cls.py b/llm_web_kit/model/html_layout_cls.py index 8f566709..e4c86694 100644 --- a/llm_web_kit/model/html_layout_cls.py +++ b/llm_web_kit/model/html_layout_cls.py @@ -4,10 +4,8 @@ from llm_web_kit.config.cfg_reader import load_config from llm_web_kit.libs.logger import mylogger as logger from llm_web_kit.model.html_classify.model import Markuplm -from llm_web_kit.model.resource_utils.download_assets import ( - CACHE_DIR, download_auto_file) -from llm_web_kit.model.resource_utils.unzip_ext import (get_unzip_dir, - unzip_local_file) +from llm_web_kit.model.resource_utils import (CACHE_DIR, download_auto_file, + get_unzip_dir, unzip_local_file) class HTMLLayoutClassifier: diff --git a/llm_web_kit/model/lang_id.py b/llm_web_kit/model/lang_id.py index a2e898ce..7b08c04f 100644 --- a/llm_web_kit/model/lang_id.py +++ b/llm_web_kit/model/lang_id.py @@ -6,10 +6,8 @@ from llm_web_kit.config.cfg_reader import load_config from llm_web_kit.libs.logger import mylogger as logger -from llm_web_kit.model.resource_utils.download_assets import ( - CACHE_DIR, download_auto_file) -from llm_web_kit.model.resource_utils.singleton_resource_manager import \ - singleton_resource_manager +from llm_web_kit.model.resource_utils import (CACHE_DIR, download_auto_file, + singleton_resource_manager) language_dict = { 'srp': 'sr', 'swe': 'sv', 'dan': 'da', 'ita': 'it', 'spa': 'es', 'pes': 'fa', 'slk': 'sk', 'hun': 'hu', 'bul': 'bg', 'cat': 'ca', diff --git a/llm_web_kit/model/libgomp.so.1 b/llm_web_kit/model/libgomp.so.1 new file mode 100644 index 00000000..91f39643 Binary files /dev/null and b/llm_web_kit/model/libgomp.so.1 differ diff --git a/llm_web_kit/model/policical.py b/llm_web_kit/model/policical.py index a46e933d..dff0e50a 100644 --- a/llm_web_kit/model/policical.py +++ b/llm_web_kit/model/policical.py @@ -2,30 +2,34 @@ from typing import Any, Dict, Tuple import fasttext -from transformers import AutoTokenizer from llm_web_kit.config.cfg_reader import load_config from llm_web_kit.exception.exception import ModelInputException from llm_web_kit.input.datajson import DataJson from llm_web_kit.libs.logger import mylogger as logger -from llm_web_kit.model.resource_utils.download_assets import ( - CACHE_DIR, download_auto_file) -from llm_web_kit.model.resource_utils.singleton_resource_manager import \ - singleton_resource_manager -from llm_web_kit.model.resource_utils.unzip_ext import (get_unzip_dir, - unzip_local_file) +from llm_web_kit.model.resource_utils import (CACHE_DIR, download_auto_file, + get_unzip_dir, + singleton_resource_manager, + unzip_local_file) class PoliticalDetector: def __init__(self, model_path: str = None): + # import AutoTokenizer here to avoid isort error + # must set the HF_HOME to the CACHE_DIR at this point + os.environ['HF_HOME'] = CACHE_DIR + from transformers import AutoTokenizer + if not model_path: model_path = self.auto_download() model_bin_path = os.path.join(model_path, 'model.bin') tokenizer_path = os.path.join(model_path, 'internlm2-chat-20b') self.model = fasttext.load_model(model_bin_path) - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=False, trust_remote_code=True) + self.tokenizer = AutoTokenizer.from_pretrained( + tokenizer_path, use_fast=False, trust_remote_code=True + ) def auto_download(self): """Default download the 24m7.zip model.""" @@ -46,7 +50,9 @@ def auto_download(self): if not os.path.exists(zip_path): logger.info(f'zip_path: {zip_path} does not exist') logger.info(f'downloading {political_24m7_s3}') - zip_path = download_auto_file(political_24m7_s3, zip_path, political_24m7_md5) + zip_path = download_auto_file( + political_24m7_s3, zip_path, political_24m7_md5 + ) logger.info(f'unzipping {zip_path}') unzip_path = unzip_local_file(zip_path, unzip_path) return unzip_path @@ -54,7 +60,9 @@ def auto_download(self): def predict(self, text: str) -> Tuple[str, float]: text = text.replace('\n', ' ') input_ids = self.tokenizer(text)['input_ids'] - predictions, probabilities = self.model.predict(' '.join([str(i) for i in input_ids]), k=-1) + predictions, probabilities = self.model.predict( + ' '.join([str(i) for i in input_ids]), k=-1 + ) return predictions, probabilities @@ -77,13 +85,17 @@ def get_singleton_political_detect() -> PoliticalDetector: return singleton_resource_manager.get_resource('political_detect') -def decide_political_by_prob(predictions: Tuple[str], probabilities: Tuple[float]) -> float: +def decide_political_by_prob( + predictions: Tuple[str], probabilities: Tuple[float] +) -> float: idx = predictions.index('__label__normal') normal_score = probabilities[idx] return normal_score -def decide_political_func(content_str: str, political_detect: PoliticalDetector) -> float: +def decide_political_func( + content_str: str, political_detect: PoliticalDetector +) -> float: # Limit the length of the content to 2560000 content_str = content_str[:2560000] predictions, probabilities = political_detect.predict(content_str) @@ -111,7 +123,9 @@ def political_filter_cpu(data_dict: Dict[str, Any], language: str): test_cases.append('hello, nice to meet you!') test_cases.append('你好,唔該幫我一個忙?') test_cases.append('Bawo ni? Mo nife Yoruba. ') - test_cases.append('你好,我很高兴见到你,请多多指教!你今天吃饭了吗?hello, nice to meet you!') + test_cases.append( + '你好,我很高兴见到你,请多多指教!你今天吃饭了吗?hello, nice to meet you!' + ) test_cases.append('איך בין אַ גרויסער פֿאַן פֿון די וויסנשאַפֿט. מיר האָבן פֿיל צו לערנען.') test_cases.append('გამარჯობა, როგორ ხარ? მე ვარ კარგად, მადლობა.') test_cases.append('გამარჯობა, როგორ ხართ? ეს ჩემი ქვეყანაა, საქართველო.') diff --git a/llm_web_kit/model/porn_detector.py b/llm_web_kit/model/porn_detector.py index ae3df185..d60c6b7a 100644 --- a/llm_web_kit/model/porn_detector.py +++ b/llm_web_kit/model/porn_detector.py @@ -7,24 +7,28 @@ from llm_web_kit.config.cfg_reader import load_config from llm_web_kit.libs.logger import mylogger as logger -from llm_web_kit.model.resource_utils.download_assets import ( - CACHE_DIR, download_auto_file) -from llm_web_kit.model.resource_utils.unzip_ext import (get_unzip_dir, - unzip_local_file) +from llm_web_kit.model.resource_utils import (CACHE_DIR, download_auto_file, + get_unzip_dir, unzip_local_file) -class BertModel(): +class BertModel: def __init__(self, model_path: str = None) -> None: if not model_path: model_path = self.auto_download() - self.model = AutoModelForSequenceClassification.from_pretrained(os.path.join(model_path, 'porn_classifier/classifier_hf')) - with open(os.path.join(model_path, 'porn_classifier/extra_parameters.json')) as reader: + self.model = AutoModelForSequenceClassification.from_pretrained( + os.path.join(model_path, 'porn_classifier/classifier_hf') + ) + with open( + os.path.join(model_path, 'porn_classifier/extra_parameters.json') + ) as reader: model_config = json.load(reader) self.cls_index = int(model_config.get('cls_index', 1)) self.use_sigmoid = bool(model_config.get('use_sigmoid', False)) self.max_tokens = int(model_config.get('max_tokens', 512)) - self.remain_tail = min(self.max_tokens - 1, int(model_config.get('remain_tail', -1))) + self.remain_tail = min( + self.max_tokens - 1, int(model_config.get('remain_tail', -1)) + ) self.device = model_config.get('device', 'cpu') self.model.eval() @@ -33,7 +37,9 @@ def __init__(self, model_path: str = None) -> None: if hasattr(self.model, 'to_bettertransformer'): self.model = self.model.to_bettertransformer() - self.tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_path, 'porn_classifier/classifier_hf')) + self.tokenizer = AutoTokenizer.from_pretrained( + os.path.join(model_path, 'porn_classifier/classifier_hf') + ) self.tokenizer_config = { 'padding': True, 'truncation': self.remain_tail <= 0, @@ -86,22 +92,36 @@ def pre_process(self, samples: Union[List[str], str]) -> Dict: length = tokens_id.index(self.tokenizer.sep_token_id) + 1 # 如果tokens的长度小于等于max_tokens,则直接在尾部补0,不需要截断 if length <= self.max_tokens: - tokens = tokens_id[:length] + [self.tokenizer.pad_token_id] * (self.max_tokens - length) + tokens = tokens_id[:length] + [self.tokenizer.pad_token_id] * ( + self.max_tokens - length + ) attn = [1] * length + [0] * (self.max_tokens - length) # 如果tokens的长度大于max_tokens,则需要取头部max_tokens-remain_tail个tokens和尾部remain_tail个tokens else: head_length = self.max_tokens - self.remain_tail tail_length = self.remain_tail - tokens = tokens_id[:head_length] + tokens_id[length - tail_length : length] + tokens = ( + tokens_id[:head_length] + + tokens_id[length - tail_length : length] + ) attn = [1] * self.max_tokens # 将处理后的tokens添加到新的inputs列表中 - processed_inputs.append({'input_ids': torch.tensor(tokens), 'attention_mask': torch.tensor(attn)}) + processed_inputs.append( + { + 'input_ids': torch.tensor(tokens), + 'attention_mask': torch.tensor(attn), + } + ) # 将所有inputs整合成一个batch inputs = { - 'input_ids': torch.cat([inp['input_ids'].unsqueeze(0) for inp in processed_inputs]), - 'attention_mask': torch.cat([inp['attention_mask'].unsqueeze(0) for inp in processed_inputs]), + 'input_ids': torch.cat( + [inp['input_ids'].unsqueeze(0) for inp in processed_inputs] + ), + 'attention_mask': torch.cat( + [inp['attention_mask'].unsqueeze(0) for inp in processed_inputs] + ), } inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()} return {'inputs': inputs} diff --git a/llm_web_kit/model/quality_model.py b/llm_web_kit/model/quality_model.py index 837c91ea..f6d95bd1 100644 --- a/llm_web_kit/model/quality_model.py +++ b/llm_web_kit/model/quality_model.py @@ -19,10 +19,8 @@ stats_html_entity, stats_ngram_mini, stats_punctuation_end_sentence, stats_stop_words, stats_unicode) from llm_web_kit.model.basic_functions.utils import div_zero -from llm_web_kit.model.resource_utils.download_assets import ( - CACHE_DIR, download_auto_file) -from llm_web_kit.model.resource_utils.unzip_ext import (get_unzip_dir, - unzip_local_file) +from llm_web_kit.model.resource_utils import (CACHE_DIR, download_auto_file, + get_unzip_dir, unzip_local_file) _global_quality_model = {} _model_resource_map = { diff --git a/llm_web_kit/model/resource_utils/__init__.py b/llm_web_kit/model/resource_utils/__init__.py index e69de29b..79ea734a 100644 --- a/llm_web_kit/model/resource_utils/__init__.py +++ b/llm_web_kit/model/resource_utils/__init__.py @@ -0,0 +1,6 @@ +from .download_assets import download_auto_file +from .singleton_resource_manager import singleton_resource_manager +from .unzip_ext import get_unzip_dir, unzip_local_file +from .utils import CACHE_DIR, CACHE_TMP_DIR + +__all__ = ['download_auto_file', 'unzip_local_file', 'get_unzip_dir', 'CACHE_DIR', 'CACHE_TMP_DIR', 'singleton_resource_manager'] diff --git a/llm_web_kit/model/resource_utils/download_assets.py b/llm_web_kit/model/resource_utils/download_assets.py index ab7cbead..9ea85f95 100644 --- a/llm_web_kit/model/resource_utils/download_assets.py +++ b/llm_web_kit/model/resource_utils/download_assets.py @@ -1,75 +1,126 @@ +"""本模块提供从 S3 或 HTTP 下载文件的功能,支持校验和验证和并发下载锁机制。 + +主要功能: +1. 计算文件的 MD5 和 SHA256 校验和 +2. 通过 S3 或 HTTP 连接下载文件 +3. 使用文件锁防止并发下载冲突 +4. 自动校验文件完整性 + +类说明: +- Connection: 抽象基类,定义下载连接接口 +- S3Connection: 实现 S3 文件下载连接 +- HttpConnection: 实现 HTTP 文件下载连接 + +函数说明: +- calc_file_md5/sha256: 计算文件哈希值 +- verify_file_checksum: 校验文件哈希 +- download_auto_file_core: 核心下载逻辑 +- download_auto_file: 自动下载入口函数(含锁机制) +""" + import hashlib import os -import shutil import tempfile +from functools import partial from typing import Iterable, Optional import requests from tqdm import tqdm -from llm_web_kit.config.cfg_reader import load_config from llm_web_kit.exception.exception import ModelResourceException from llm_web_kit.libs.logger import mylogger as logger from llm_web_kit.model.resource_utils.boto3_ext import (get_s3_client, is_s3_path, split_s3_path) -from llm_web_kit.model.resource_utils.utils import FileLockContext, try_remove +from llm_web_kit.model.resource_utils.process_with_lock import \ + process_and_verify_file_with_lock +from llm_web_kit.model.resource_utils.utils import CACHE_TMP_DIR + +def calc_file_md5(file_path: str) -> str: + """计算文件的 MD5 校验和. -def decide_cache_dir(): - """Get the cache directory for the web kit. The. + Args: + file_path: 文件路径 Returns: - _type_: _description_ + MD5 哈希字符串(32位十六进制) """ - cache_dir = '~/.llm_web_kit_cache' + with open(file_path, 'rb') as f: + return hashlib.md5(f.read()).hexdigest() - if 'WEB_KIT_CACHE_DIR' in os.environ: - cache_dir = os.environ['WEB_KIT_CACHE_DIR'] - try: - config = load_config() - cache_dir = config['resources']['common']['cache_path'] - except Exception: - pass +def calc_file_sha256(file_path: str) -> str: + """计算文件的 SHA256 校验和. - if cache_dir.startswith('~/'): - cache_dir = os.path.expanduser(cache_dir) + Args: + file_path: 文件路径 - return cache_dir + Returns: + SHA256 哈希字符串(64位十六进制) + """ + with open(file_path, 'rb') as f: + return hashlib.sha256(f.read()).hexdigest() -CACHE_DIR = decide_cache_dir() +def verify_file_checksum( + file_path: str, md5_sum: Optional[str] = None, sha256_sum: Optional[str] = None +) -> bool: + """验证文件的 MD5 或 SHA256 校验和. + Args: + file_path: 待验证文件路径 + md5_sum: 预期 MD5 值(与 sha256_sum 二选一) + sha256_sum: 预期 SHA256 值(与 md5_sum 二选一) -def calc_file_md5(file_path: str) -> str: - """Calculate the MD5 checksum of a file.""" - with open(file_path, 'rb') as f: - return hashlib.md5(f.read()).hexdigest() + Returns: + bool: 校验是否通过 + Raises: + ModelResourceException: 当未提供或同时提供两个校验和时 + """ + if not (bool(md5_sum) ^ bool(sha256_sum)): + raise ModelResourceException( + 'Exactly one of md5_sum or sha256_sum must be provided' + ) + if not os.path.exists(file_path): + return False + if md5_sum: + actual = calc_file_md5(file_path) + if actual != md5_sum: + logger.warning( + f'MD5 mismatch: expect {md5_sum[:8]}..., got {actual[:8]}...' + ) + return False -def calc_file_sha256(file_path: str) -> str: - """Calculate the sha256 checksum of a file.""" - with open(file_path, 'rb') as f: - return hashlib.sha256(f.read()).hexdigest() + if sha256_sum: + actual = calc_file_sha256(file_path) + if actual != sha256_sum: + logger.warning( + f'SHA256 mismatch: expect {sha256_sum[:8]}..., got {actual[:8]}...' + ) + return False + return True -class Connection: - def __init__(self, *args, **kwargs): - pass +class Connection: + """下载连接的抽象基类.""" def get_size(self) -> int: + """获取文件大小(字节)""" raise NotImplementedError def read_stream(self) -> Iterable[bytes]: + """返回数据流的迭代器.""" raise NotImplementedError class S3Connection(Connection): + """S3 文件下载连接.""" def __init__(self, resource_path: str): - super().__init__(resource_path) + super().__init__() self.client = get_s3_client(resource_path) self.bucket, self.key = split_s3_path(resource_path) self.obj = self.client.get_object(Bucket=self.bucket, Key=self.key) @@ -83,13 +134,15 @@ def read_stream(self) -> Iterable[bytes]: yield chunk def __del__(self): - self.obj['Body'].close() + if hasattr(self, 'obj') and 'Body' in self.obj: + self.obj['Body'].close() class HttpConnection(Connection): + """HTTP 文件下载连接.""" def __init__(self, resource_path: str): - super().__init__(resource_path) + super().__init__() self.response = requests.get(resource_path, stream=True) self.response.raise_for_status() @@ -102,67 +155,70 @@ def read_stream(self) -> Iterable[bytes]: yield chunk def __del__(self): - self.response.close() + if hasattr(self, 'response'): + self.response.close() -def verify_file_checksum( - file_path: str, md5_sum: Optional[str] = None, sha256_sum: Optional[str] = None -) -> bool: - """校验文件哈希值.""" - if not sum([bool(md5_sum), bool(sha256_sum)]) == 1: - raise ModelResourceException( - 'Exactly one of md5_sum or sha256_sum must be provided' - ) +def download_to_temp(conn: Connection, progress_bar: tqdm, download_path: str): + """下载文件到临时目录. - if md5_sum: - actual = calc_file_md5(file_path) - if actual != md5_sum: - logger.warning( - f'MD5 mismatch: expect {md5_sum[:8]}..., got {actual[:8]}...' - ) - return False - - if sha256_sum: - actual = calc_file_sha256(file_path) - if actual != sha256_sum: - logger.warning( - f'SHA256 mismatch: expect {sha256_sum[:8]}..., got {actual[:8]}...' - ) - return False + Args: + conn: 下载连接 + progress_bar: 进度条 + download_path: 临时文件路径 + """ - return True + with open(download_path, 'wb') as f: + for chunk in conn.read_stream(): + if chunk: # 防止空chunk导致进度条卡死 + f.write(chunk) + progress_bar.update(len(chunk)) -def download_to_temp(conn, progress_bar) -> str: - """下载到临时文件.""" - with tempfile.NamedTemporaryFile(delete=False) as tmp_file: - tmp_path = tmp_file.name - logger.info(f'Downloading to temporary file: {tmp_path}') +def download_auto_file_core( + resource_path: str, + target_path: str, +) -> str: + """下载文件的核心逻辑(无锁) - try: - with open(tmp_path, 'wb') as f: - for chunk in conn.read_stream(): - if chunk: # 防止空chunk导致进度条卡死 - f.write(chunk) - progress_bar.update(len(chunk)) - return tmp_path - except Exception: - try_remove(tmp_path) - raise - - -def move_to_target(tmp_path: str, target_path: str, expected_size: int): - """移动文件并验证.""" - if os.path.getsize(tmp_path) != expected_size: - raise ModelResourceException( - f'File size mismatch: {os.path.getsize(tmp_path)} vs {expected_size}' - ) + Args: + resource_path: 源文件路径(S3或HTTP URL) + target_path: 目标保存路径 - os.makedirs(os.path.dirname(target_path), exist_ok=True) - shutil.move(tmp_path, target_path) # 原子操作替换 + Returns: + 下载后的文件路径 - if not os.path.exists(target_path): - raise ModelResourceException(f'Move failed: {tmp_path} -> {target_path}') + Raises: + ModelResourceException: 下载失败或文件大小不匹配时 + """ + # 初始化连接 + conn_cls = S3Connection if is_s3_path(resource_path) else HttpConnection + conn = conn_cls(resource_path) + total_size = conn.get_size() + + # 配置进度条 + logger.info(f'Downloading {resource_path} => {target_path}') + progress = tqdm(total=total_size, unit='iB', unit_scale=True) + + # 使用临时目录确保原子性 + with tempfile.TemporaryDirectory(dir=CACHE_TMP_DIR) as temp_dir: + download_path = os.path.join(temp_dir, 'download_file') + try: + download_to_temp(conn, progress, download_path) + + # 验证文件大小 + actual_size = os.path.getsize(download_path) + if total_size != actual_size: + raise ModelResourceException( + f'Size mismatch: expected {total_size}, got {actual_size}' + ) + + # 移动到目标路径 + os.makedirs(os.path.dirname(target_path), exist_ok=True) + os.rename(download_path, target_path) # 替换 os.rename + return target_path + finally: + progress.close() def download_auto_file( @@ -170,71 +226,27 @@ def download_auto_file( target_path: str, md5_sum: str = '', sha256_sum: str = '', - exist_ok=True, - lock_timeout: int = 300, + lock_suffix: str = '.lock', + lock_timeout: float = 60, ) -> str: - """Download a file from a given resource path (either an S3 path or an HTTP - URL) to a target path on the local file system. - - This function will first download the file to a temporary file, then move the temporary file to the target path after - the download is complete. A progress bar will be displayed during the download. - - If the size of the downloaded file does not match the expected size, an exception will be raised. + """自动下载文件(含锁机制和校验) Args: - resource_path (str): The path of the resource to download. This can be either an S3 path (e.g., "s3://bucket/key") - or an HTTP URL (e.g., "http://example.com/file"). - target_path (str): The path on the local file system where the downloaded file should be saved.\ - exist_ok (bool, optional): If False, raise an exception if the target path already exists. Defaults to True. + resource_path: 源文件路径 + target_path: 目标保存路径 + md5_sum: 预期 MD5 值(与 sha256_sum 二选一) + sha256_sum: 预期 SHA256 值(与 md5_sum 二选一) + lock_suffix: 锁文件后缀 + lock_timeout: 锁超时时间(秒) Returns: - str: The path where the downloaded file was saved. + 下载后的文件路径 Raises: - Exception: If an error occurs during the download, or if the size of the downloaded file does not match the - expected size, or if the temporary file cannot be moved to the target path. + ModelResourceException: 校验失败或下载错误时 """ - - """线程安全的文件下载函数""" - lock_path = f'{target_path}.lock' - - def check_callback(): - return verify_file_checksum(target_path, md5_sum, sha256_sum) - - if os.path.exists(target_path): - if not exist_ok: - raise ModelResourceException( - f'File exists with invalid checksum: {target_path}' - ) - - if verify_file_checksum(target_path, md5_sum, sha256_sum): - logger.info(f'File already exists with valid checksum: {target_path}') - return target_path - else: - logger.warning(f'Removing invalid file: {target_path}') - try_remove(target_path) - - with FileLockContext(lock_path, check_callback, timeout=lock_timeout) as lock: - if lock is True: - logger.info( - f'File already exists with valid checksum: {target_path} while waiting' - ) - return target_path - - # 创建连接 - conn_cls = S3Connection if is_s3_path(resource_path) else HttpConnection - conn = conn_cls(resource_path) - total_size = conn.get_size() - - # 下载流程 - logger.info(f'Downloading {resource_path} => {target_path}') - progress = tqdm(total=total_size, unit='iB', unit_scale=True) - - try: - tmp_path = download_to_temp(conn, progress) - move_to_target(tmp_path, target_path, total_size) - - return target_path - finally: - progress.close() - try_remove(tmp_path) # 确保清理临时文件 + process_func = partial(download_auto_file_core, resource_path, target_path) + verify_func = partial(verify_file_checksum, target_path, md5_sum, sha256_sum) + return process_and_verify_file_with_lock( + process_func, verify_func, target_path, lock_suffix, lock_timeout + ) diff --git a/llm_web_kit/model/resource_utils/process_with_lock.py b/llm_web_kit/model/resource_utils/process_with_lock.py new file mode 100644 index 00000000..b9c7fef3 --- /dev/null +++ b/llm_web_kit/model/resource_utils/process_with_lock.py @@ -0,0 +1,109 @@ +import os +import time +from typing import Callable + +from filelock import SoftFileLock, Timeout + +from llm_web_kit.model.resource_utils.utils import try_remove + + +def get_path_mtime(target_path: str) -> float: + """获得文件或目录的最新修改时间. 如果是文件,则直接返回 mtime. 如果是目录,则遍历目录获取最新的 mtime. + + Args: + target_path: 文件或目录路径 + + Returns: + float: 最新修改时间 + """ + if os.path.isdir(target_path): + # walk through the directory and get the latest mtime + latest_mtime = None + for root, _, files in os.walk(target_path): + for file in files: + file_path = os.path.join(root, file) + mtime = os.path.getmtime(file_path) + if latest_mtime is None or mtime > latest_mtime: + latest_mtime = mtime + return latest_mtime + else: + return os.path.getmtime(target_path) + + +def process_and_verify_file_with_lock( + process_func: Callable[[], str], # 无参数,返回目标路径 + verify_func: Callable[[], bool], # 无参数,返回验证结果 + target_path: str, + lock_suffix: str = '.lock', + timeout: float = 60, +) -> str: + # """通用处理验证框架. + + # :param process_func: 无参数的处理函数,返回最终目标路径 + # :param verify_func: 无参数的验证函数,返回布尔值 + # :param target_path: 目标路径(文件或目录) + # :param lock_suffix: 锁文件后缀 + # :param timeout: 处理超时时间(秒) + # """ + """ + 通用使用文件锁进行资源处理与资源验证的框架. + 使用文件锁保证处理函数调用时是唯一的。 + 资源校验不在锁保护范围内从而提高效率。 + 当资源校验不通过时,会删除目标文件并重新处理。 + 简易逻辑为: + 1. 检查目标是否存在且有效,如果是则直接返回目标路径 + 2. 如果目标不存在或无效,则尝试获取锁 + 3. 如果锁存在且陈旧,则删除锁和目标文件重新处理 + 4. 如果锁存在且未陈旧,则等待锁释放 + 5. 如果锁不存在,则执行处理函数 + 6. 处理完成后返回目标路径 + + Args: + process_func: 无参数的处理函数,返回最终目标路径 + verify_func: 无参数的验证函数,返回布尔值 + target_path: 目标路径(文件或目录) + lock_suffix: 锁文件后缀 + timeout: 处理超时时间(秒) + Returns: + str: 最终目标路径 + """ + lock_path = target_path + lock_suffix + + while True: + # 检查目标是否存在且有效 + if os.path.exists(target_path): + if verify_func(): + return target_path + else: + # 目标存在但验证失败 + if os.path.exists(lock_path): + now = time.time() + try: + mtime = get_path_mtime(target_path) + if now - mtime < timeout: + time.sleep(1) + continue + else: + try_remove(lock_path) + try_remove(target_path) + except FileNotFoundError: + pass + else: + try_remove(target_path) + else: + + # 尝试获取锁 + file_lock = SoftFileLock(lock_path) + try: + file_lock.acquire(timeout=1) + # 二次验证(可能其他进程已处理完成) + if os.path.exists(target_path) and verify_func(): + return target_path + # 执行处理 + return process_func() + except Timeout: + time.sleep(1) + continue + finally: + if file_lock.is_locked: + file_lock.release() diff --git a/llm_web_kit/model/resource_utils/unzip_ext.py b/llm_web_kit/model/resource_utils/unzip_ext.py index 6579fd62..66622595 100644 --- a/llm_web_kit/model/resource_utils/unzip_ext.py +++ b/llm_web_kit/model/resource_utils/unzip_ext.py @@ -1,12 +1,14 @@ import os -import shutil import tempfile import zipfile +from functools import partial from typing import Optional from llm_web_kit.exception.exception import ModelResourceException from llm_web_kit.libs.logger import mylogger as logger -from llm_web_kit.model.resource_utils.utils import FileLockContext, try_remove +from llm_web_kit.model.resource_utils.download_assets import CACHE_TMP_DIR +from llm_web_kit.model.resource_utils.process_with_lock import \ + process_and_verify_file_with_lock def get_unzip_dir(zip_path: str) -> str: @@ -25,33 +27,40 @@ def get_unzip_dir(zip_path: str) -> str: return os.path.join(zip_dir, base_name + '_unzip') -def check_zip_file(zip_ref: zipfile.ZipFile, target_dir: str) -> bool: +def check_zip_path( + zip_path: str, target_dir: str, password: Optional[str] = None +) -> bool: """Check if the zip file is correctly unzipped to the target directory. Args: - zip_ref (zipfile.ZipFile): The zip file object. + zip_path (str): The path to the zip file. target_dir (str): The target directory. + password (Optional[str], optional): The password to the zip file. Defaults to None. Returns: bool: True if the zip file is correctly unzipped to the target directory, False otherwise. """ - - zip_info_list = [info for info in zip_ref.infolist() if not info.is_dir()] - for info in zip_info_list: - file_path = os.path.join(target_dir, info.filename) - if not os.path.exists(file_path): - return False - if os.path.getsize(file_path) != info.file_size: - return False - return True - - -def unzip_local_file( + if not os.path.exists(zip_path): + logger.error(f'zip file {zip_path} does not exist') + return False + with zipfile.ZipFile(zip_path, 'r') as zip_ref: + if password: + zip_ref.setpassword(password.encode()) + + zip_info_list = [info for info in zip_ref.infolist() if not info.is_dir()] + for info in zip_info_list: + file_path = os.path.join(target_dir, info.filename) + if not os.path.exists(file_path): + return False + if os.path.getsize(file_path) != info.file_size: + return False + return True + + +def unzip_local_file_core( zip_path: str, target_dir: str, password: Optional[str] = None, - exist_ok: bool = True, - lock_timeout: float = 300, ) -> str: """Unzip a zip file to a target directory. @@ -59,64 +68,52 @@ def unzip_local_file( zip_path (str): The path to the zip file. target_dir (str): The directory to unzip the files to. password (Optional[str], optional): The password to the zip file. Defaults to None. - exist_ok (bool, optional): If True, overwrite the files in the target directory if it already exists. - If False, raise an exception if the target directory already exists. Defaults to False. Raises: ModelResourceException: If the zip file does not exist. - ModelResourceException: If the target directory already exists and exist_ok is False + ModelResourceException: If the target directory already exists. Returns: str: The path to the target directory. """ - lock_path = f'{zip_path}.lock' - if not os.path.exists(zip_path): logger.error(f'zip file {zip_path} does not exist') raise ModelResourceException(f'zip file {zip_path} does not exist') - def check_zip(): - with zipfile.ZipFile(zip_path, 'r') as zip_ref: - if password: - zip_ref.setpassword(password.encode()) - return check_zip_file(zip_ref, target_dir) - if os.path.exists(target_dir): - if not exist_ok: - raise ModelResourceException( - f'Target directory {target_dir} already exists' - ) - - if check_zip(): - logger.info(f'zip file {zip_path} is already unzipped to {target_dir}') - return target_dir - else: - logger.warning( - f'zip file {zip_path} is not correctly unzipped to {target_dir}, retry to unzip' - ) - try_remove(target_dir) - - with FileLockContext(lock_path, check_zip, timeout=lock_timeout) as lock: - if lock is True: - logger.info( - f'zip file {zip_path} is already unzipped to {target_dir} while waiting' - ) - return target_dir - - # ensure target directory not exists - - # 创建临时解压目录 - with tempfile.TemporaryDirectory() as temp_dir: + raise ModelResourceException(f'Target directory {target_dir} already exists') + + with zipfile.ZipFile(zip_path, 'r') as zip_ref: + if password: + zip_ref.setpassword(password.encode()) + with tempfile.TemporaryDirectory(dir=CACHE_TMP_DIR) as temp_dir: extract_dir = os.path.join(temp_dir, 'temp') os.makedirs(extract_dir, exist_ok=True) + zip_ref.extractall(extract_dir) + os.rename(extract_dir, target_dir) + return target_dir + - # 解压到临时目录 - with zipfile.ZipFile(zip_path, 'r') as zip_ref: - if password: - zip_ref.setpassword(password.encode()) - zip_ref.extractall(extract_dir) +def unzip_local_file( + zip_path: str, + target_dir: str, + password: Optional[str] = None, + lock_suffix: str = '.unzip.lock', + timeout: float = 60, +) -> str: + """Unzip a zip file to a target directory with a lock. - # 原子性复制到目标目录 - shutil.copytree(extract_dir, target_dir) + Args: + zip_path (str): The path to the zip file. + target_dir (str): The directory to unzip the files to. + password (Optional[str], optional): The password to the zip file. Defaults to None. + timeout (float, optional): The timeout for the lock. Defaults to 60. - return target_dir + Returns: + str: The path to the target directory. + """ + process_func = partial(unzip_local_file_core, zip_path, target_dir, password) + verify_func = partial(check_zip_path, zip_path, target_dir, password) + return process_and_verify_file_with_lock( + process_func, verify_func, target_dir, lock_suffix, timeout + ) diff --git a/llm_web_kit/model/resource_utils/utils.py b/llm_web_kit/model/resource_utils/utils.py index b80a35d0..4ea78dda 100644 --- a/llm_web_kit/model/resource_utils/utils.py +++ b/llm_web_kit/model/resource_utils/utils.py @@ -1,62 +1,51 @@ -import errno import os -import time +import shutil +from llm_web_kit.config.cfg_reader import load_config + + +def decide_cache_dir(): + """Get the cache directory for the web kit. The. + + Returns: + _type_: _description_ + """ + cache_dir = '~/.llm_web_kit_cache' + + if 'WEB_KIT_CACHE_DIR' in os.environ: + cache_dir = os.environ['WEB_KIT_CACHE_DIR'] -def try_remove(path: str): - """Attempt to remove a file, but ignore any exceptions that occur.""" try: - os.remove(path) + config = load_config() + cache_dir = config['resources']['common']['cache_path'] except Exception: pass + if cache_dir.startswith('~/'): + cache_dir = os.path.expanduser(cache_dir) + + cache_tmp_dir = os.path.join(cache_dir, 'tmp') + + return cache_dir, cache_tmp_dir + + +CACHE_DIR, CACHE_TMP_DIR = decide_cache_dir() -class FileLockContext: - """基于文件锁的上下文管理器(跨平台兼容版)""" - - def __init__(self, lock_path: str, check_callback=None, timeout: float = 300): - self.lock_path = lock_path - self.check_callback = check_callback - self.timeout = timeout - self._fd = None - - def __enter__(self): - start_time = time.time() - while True: - if self.check_callback: - if self.check_callback(): - return True - try: - # 原子性创建锁文件(O_EXCL标志是关键) - self._fd = os.open( - self.lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o644 - ) - # 写入进程信息和时间戳 - with os.fdopen(self._fd, 'w') as f: - f.write(f'{os.getpid()}\n{time.time()}') - return self - except OSError as e: - if e.errno != errno.EEXIST: - raise - - # 检查锁是否过期 - try: - with open(self.lock_path, 'r') as f: - pid, timestamp = f.read().split('\n')[:2] - if time.time() - float(timestamp) > self.timeout: - os.remove(self.lock_path) - except (FileNotFoundError, ValueError): - pass - - if time.time() - start_time > self.timeout: - raise TimeoutError(f'Could not acquire lock after {self.timeout}s') - time.sleep(0.1) - - def __exit__(self, exc_type, exc_val, exc_tb): - try: - if self._fd: - os.close(self._fd) - except OSError: - pass - finally: - try_remove(self.lock_path) + +if not os.path.exists(CACHE_DIR): + os.makedirs(CACHE_DIR, exist_ok=True) + +if not os.path.exists(CACHE_TMP_DIR): + os.makedirs(CACHE_TMP_DIR, exist_ok=True) + + +def try_remove(path: str): + """Attempt to remove a file by os.remove or to remove a directory by + shutil.rmtree and ignore exceptions.""" + try: + if os.path.isdir(path): + shutil.rmtree(path) + else: + os.remove(path) + except Exception: + pass diff --git a/llm_web_kit/model/unsafe_words_detector.py b/llm_web_kit/model/unsafe_words_detector.py index 8e336d9f..8c05059c 100644 --- a/llm_web_kit/model/unsafe_words_detector.py +++ b/llm_web_kit/model/unsafe_words_detector.py @@ -10,10 +10,8 @@ from llm_web_kit.libs.standard_utils import json_loads from llm_web_kit.model.basic_functions.format_check import (is_en_letter, is_pure_en_word) -from llm_web_kit.model.resource_utils.download_assets import ( - CACHE_DIR, download_auto_file) -from llm_web_kit.model.resource_utils.singleton_resource_manager import \ - singleton_resource_manager +from llm_web_kit.model.resource_utils import (CACHE_DIR, download_auto_file, + singleton_resource_manager) xyz_language_lst = [ 'ar', @@ -257,5 +255,5 @@ def unsafe_words_filter_overall( unsafe_range = ('L1',) else: unsafe_range = ('L1', 'L2') - hit = (unsafe_word_min_level in unsafe_range) + hit = unsafe_word_min_level in unsafe_range return {'hit_unsafe_words': hit} diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 786070f7..cec5055c 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -3,6 +3,7 @@ cairosvg==2.7.1 click==8.1.8 commentjson==0.9.0 fasttext-wheel==0.9.2 +filelock==3.16.1 jieba-fast==0.53 lightgbm==4.5.0 loguru==0.7.2 diff --git a/tests/llm_web_kit/model/resource_utils/test_download_assets.py b/tests/llm_web_kit/model/resource_utils/test_download_assets.py index b3e55e74..f93dba3f 100644 --- a/tests/llm_web_kit/model/resource_utils/test_download_assets.py +++ b/tests/llm_web_kit/model/resource_utils/test_download_assets.py @@ -1,482 +1,192 @@ -import io +import hashlib import os import tempfile +import time import unittest -from typing import Tuple -from unittest.mock import MagicMock, call, mock_open, patch +from unittest.mock import MagicMock, call, patch from llm_web_kit.exception.exception import ModelResourceException from llm_web_kit.model.resource_utils.download_assets import ( HttpConnection, S3Connection, calc_file_md5, calc_file_sha256, - decide_cache_dir, download_auto_file, download_to_temp, move_to_target, + download_auto_file, download_auto_file_core, download_to_temp, verify_file_checksum) -class Test_decide_cache_dir: - - @patch('os.environ', {'WEB_KIT_CACHE_DIR': '/env/cache_dir'}) - @patch('llm_web_kit.model.resource_utils.download_assets.load_config') - def test_only_env(self, get_configMock): - get_configMock.side_effect = Exception - assert decide_cache_dir() == '/env/cache_dir' - - @patch('os.environ', {}) - @patch('llm_web_kit.model.resource_utils.download_assets.load_config') - def test_only_config(self, get_configMock): - get_configMock.return_value = { - 'resources': {'common': {'cache_path': '/config/cache_dir'}} - } - assert decide_cache_dir() == '/config/cache_dir' - - @patch('os.environ', {}) - @patch('llm_web_kit.model.resource_utils.download_assets.load_config') - def test_default(self, get_configMock): - get_configMock.side_effect = Exception - # if no env or config, use default - assert decide_cache_dir() == os.path.expanduser('~/.llm_web_kit_cache') - - @patch('os.environ', {'WEB_KIT_CACHE_DIR': '/env/cache_dir'}) - @patch('llm_web_kit.model.resource_utils.download_assets.load_config') - def test_both(self, get_configMock): - get_configMock.return_value = { - 'resources': {'common': {'cache_path': '/config/cache_dir'}} - } - # config is preferred - assert decide_cache_dir() == '/config/cache_dir' - - -class Test_calc_file_md5: +class TestChecksumCalculations: def test_calc_file_md5(self): - import hashlib - with tempfile.NamedTemporaryFile() as f: - test_bytes = b'hello world' * 10000 - f.write(test_bytes) + test_data = b'hello world' * 100 + f.write(test_data) f.flush() - assert calc_file_md5(f.name) == hashlib.md5(test_bytes).hexdigest() - - -class Test_calc_file_sha256: + expected = hashlib.md5(test_data).hexdigest() + assert calc_file_md5(f.name) == expected def test_calc_file_sha256(self): - import hashlib - with tempfile.NamedTemporaryFile() as f: - test_bytes = b'hello world' * 10000 - f.write(test_bytes) + test_data = b'hello world' * 100 + f.write(test_data) f.flush() - assert calc_file_sha256(f.name) == hashlib.sha256(test_bytes).hexdigest() - - -def read_mockio_size(mock_io: io.BytesIO, size: int): - while True: - data = mock_io.read(size) - if not data: - break - yield data - - -def get_mock_http_response(test_data: bytes) -> Tuple[MagicMock, int]: - mock_io = io.BytesIO(test_data) - content_length = len(test_data) - response_mock = MagicMock() - response_mock.headers = {'content-length': str(content_length)} - response_mock.iter_content.return_value = read_mockio_size(mock_io, 1024) - return response_mock, content_length - - -def get_mock_s3_response(test_data: bytes) -> Tuple[MagicMock, int]: - mock_io = io.BytesIO(test_data) - content_length = len(test_data) - clientMock = MagicMock() - body = MagicMock() - body.read.side_effect = read_mockio_size(mock_io, 1024) - clientMock.get_object.return_value = {'ContentLength': content_length, 'Body': body} - return clientMock, content_length - - -@patch('llm_web_kit.model.resource_utils.download_assets.get_s3_client') -@patch('llm_web_kit.model.resource_utils.download_assets.split_s3_path') -def test_S3Connection(split_s3_pathMock, get_s3_clientMock): - test_data = b'hello world' * 100 - - # Mock the split_s3_path function - split_s3_pathMock.return_value = ('bucket', 'key') - - # Mock the S3 client - clientMock, content_length = get_mock_s3_response(test_data) - get_s3_clientMock.return_value = clientMock - - # Test the S3Connection class - conn = S3Connection('s3://bucket/key') - assert conn.get_size() == content_length - assert b''.join(conn.read_stream()) == test_data - - -@patch('requests.get') -def test_HttpConnection(requests_get_mock): - test_data = b'hello world' * 100 - response_mock, content_length = get_mock_http_response(test_data) - requests_get_mock.return_value = response_mock + expected = hashlib.sha256(test_data).hexdigest() + assert calc_file_sha256(f.name) == expected + + +class TestConnections: + + @patch('requests.get') + def test_http_connection(self, mock_get): + test_data = b'test data' + mock_response = MagicMock() + mock_response.headers = {'content-length': str(len(test_data))} + mock_response.iter_content.return_value = [test_data] + mock_get.return_value = mock_response + + conn = HttpConnection('http://example.com') + assert conn.get_size() == len(test_data) + assert next(conn.read_stream()) == test_data + del conn + mock_response.close.assert_called() + + @patch('llm_web_kit.model.resource_utils.download_assets.get_s3_client') + def test_s3_connection(self, mock_client): + mock_body = MagicMock() + mock_body.read.side_effect = [b'chunk1', b'chunk2', b''] + mock_client.return_value.get_object.return_value = { + 'ContentLength': 100, + 'Body': mock_body, + } - # Test the HttpConnection class - conn = HttpConnection('http://example.com/file') - assert conn.get_size() == content_length - assert b''.join(conn.read_stream()) == test_data + conn = S3Connection('s3://bucket/key') + assert conn.get_size() == 100 + assert list(conn.read_stream()) == [b'chunk1', b'chunk2'] + del conn + mock_body.close.assert_called() -class TestDownloadAutoFile(unittest.TestCase): +class TestDownloadCoreFunctionality(unittest.TestCase): - @patch('llm_web_kit.model.resource_utils.download_assets.os.path.exists') - @patch('llm_web_kit.model.resource_utils.download_assets.calc_file_md5') - @patch('llm_web_kit.model.resource_utils.download_assets.is_s3_path') @patch('llm_web_kit.model.resource_utils.download_assets.S3Connection') - @patch('llm_web_kit.model.resource_utils.download_assets.HttpConnection') - def test_file_exists_correct_md5( - self, - mock_http_conn, - mock_s3_conn, - mock_is_s3_path, - mock_calc_file_md5, - mock_os_path_exists, - ): - # Arrange - mock_os_path_exists.return_value = True - mock_calc_file_md5.return_value = 'correct_md5' - mock_is_s3_path.return_value = False - mock_http_conn.return_value = MagicMock(get_size=MagicMock(return_value=100)) - - # Act - result = download_auto_file( - 'http://example.com', 'target_path', md5_sum='correct_md5' - ) - - # Assert - assert result == 'target_path' - - mock_os_path_exists.assert_called_once_with('target_path') - mock_calc_file_md5.assert_called_once_with('target_path') - mock_http_conn.assert_not_called() - mock_s3_conn.assert_not_called() - try: - os.remove('target_path.lock') - except FileNotFoundError: - pass - - @patch('llm_web_kit.model.resource_utils.download_assets.os.path.exists') - @patch('llm_web_kit.model.resource_utils.download_assets.calc_file_sha256') - @patch('llm_web_kit.model.resource_utils.download_assets.is_s3_path') - @patch('llm_web_kit.model.resource_utils.download_assets.S3Connection') - @patch('llm_web_kit.model.resource_utils.download_assets.HttpConnection') - def test_file_exists_correct_sha256( - self, - mock_http_conn, - mock_s3_conn, - mock_is_s3_path, - mock_calc_file_sha256, - mock_os_path_exists, - ): - # Arrange - mock_os_path_exists.return_value = True - mock_calc_file_sha256.return_value = 'correct_sha256' - mock_is_s3_path.return_value = False - mock_http_conn.return_value = MagicMock(get_size=MagicMock(return_value=100)) - - # Act - result = download_auto_file( - 'http://example.com', 'sha256_target_path', sha256_sum='correct_sha256' - ) - - # Assert - assert result == 'sha256_target_path' - - mock_os_path_exists.assert_called_once_with('sha256_target_path') - mock_calc_file_sha256.assert_called_once_with('sha256_target_path') - mock_http_conn.assert_not_called() - mock_s3_conn.assert_not_called() - try: - os.remove('sha256_target_path.lock') - except FileNotFoundError: - pass + def test_successful_download(self, mock_conn): + # Mock connection + download_data = b'data' + mock_instance = MagicMock() + mock_instance.read_stream.return_value = [download_data] + mock_instance.get_size.return_value = len(download_data) + mock_conn.return_value = mock_instance - @patch('llm_web_kit.model.resource_utils.download_assets.calc_file_md5') - @patch('llm_web_kit.model.resource_utils.download_assets.os.remove') - @patch('llm_web_kit.model.resource_utils.download_assets.is_s3_path') - @patch('llm_web_kit.model.resource_utils.download_assets.S3Connection') - @patch('llm_web_kit.model.resource_utils.download_assets.HttpConnection') - def test_file_exists_wrong_md5_download_http( - self, - mock_http_conn, - mock_s3_conn, - mock_is_s3_path, - mock_os_remove, - mock_calc_file_md5, - ): - # Arrange - mock_calc_file_md5.return_value = 'wrong_md5' - mock_is_s3_path.return_value = False - - with tempfile.TemporaryDirectory() as tmp_dir: - with open(os.path.join(tmp_dir, 'target_path'), 'wb') as f: - f.write(b'hello world') - response_mock, content_length = get_mock_http_response(b'hello world') - mock_http_conn.return_value = MagicMock( - get_size=MagicMock(return_value=content_length), - read_stream=MagicMock(return_value=response_mock.iter_content()), - ) - - target_path = os.path.join(tmp_dir, 'target_path') - # Act - result = download_auto_file( - 'http://example.com', target_path, md5_sum='correct_md5' - ) - - assert result == target_path - with open(target_path, 'rb') as f: - assert f.read() == b'hello world' + with tempfile.TemporaryDirectory() as tmpdir: + target = os.path.join(tmpdir, 'target.file') + result = download_auto_file_core('s3://bucket/key', target) - @patch('llm_web_kit.model.resource_utils.download_assets.calc_file_sha256') - @patch('llm_web_kit.model.resource_utils.download_assets.os.remove') - @patch('llm_web_kit.model.resource_utils.download_assets.is_s3_path') - @patch('llm_web_kit.model.resource_utils.download_assets.S3Connection') - @patch('llm_web_kit.model.resource_utils.download_assets.HttpConnection') - def test_file_exists_wrong_sha256_download_http( - self, - mock_http_conn, - mock_s3_conn, - mock_is_s3_path, - mock_os_remove, - mock_calc_file_sha256, - ): - # Arrange - mock_calc_file_sha256.return_value = 'wrong_sha256' - mock_is_s3_path.return_value = False - - with tempfile.TemporaryDirectory() as tmp_dir: - with open(os.path.join(tmp_dir, 'target_path'), 'wb') as f: - f.write(b'hello world') - response_mock, content_length = get_mock_http_response(b'hello world') - mock_http_conn.return_value = MagicMock( - get_size=MagicMock(return_value=content_length), - read_stream=MagicMock(return_value=response_mock.iter_content()), - ) - - target_path = os.path.join(tmp_dir, 'target_path') - # Act - result = download_auto_file( - 'http://example.com', target_path, sha256_sum='correct_sha256' - ) - - assert result == target_path - with open(target_path, 'rb') as f: - assert f.read() == b'hello world' + assert result == target + assert os.path.exists(target) - @patch('llm_web_kit.model.resource_utils.download_assets.calc_file_md5') - @patch('llm_web_kit.model.resource_utils.download_assets.os.remove') - @patch('llm_web_kit.model.resource_utils.download_assets.is_s3_path') - @patch('llm_web_kit.model.resource_utils.download_assets.S3Connection') @patch('llm_web_kit.model.resource_utils.download_assets.HttpConnection') - def test_file_not_exists_download_http( - self, - mock_http_conn, - mock_s3_conn, - mock_is_s3_path, - mock_os_remove, - mock_calc_file_md5, - ): - # Arrange - mock_is_s3_path.return_value = False - - with tempfile.TemporaryDirectory() as tmp_dir: - response_mock, content_length = get_mock_http_response(b'hello world') - mock_http_conn.return_value = MagicMock( - get_size=MagicMock(return_value=content_length), - read_stream=MagicMock(return_value=response_mock.iter_content()), - ) - - target_path = os.path.join(tmp_dir, 'target_path') - # Act - result = download_auto_file( - 'http://example.com', target_path, md5_sum='correct_md5' - ) - - assert result == target_path - with open(target_path, 'rb') as f: - assert f.read() == b'hello world' - - -# def verify_file_checksum( -# file_path: str, md5_sum: Optional[str] = None, sha256_sum: Optional[str] = None -# ) -> bool: -# """校验文件哈希值.""" -# if not sum([bool(md5_sum), bool(sha256_sum)]) == 1: -# raise ModelResourceException('Exactly one of md5_sum or sha256_sum must be provided') - -# if md5_sum: -# actual = calc_file_md5(file_path) -# if actual != md5_sum: -# logger.warning( -# f'MD5 mismatch: expect {md5_sum[:8]}..., got {actual[:8]}...' -# ) -# return False - -# if sha256_sum: -# actual = calc_file_sha256(file_path) -# if actual != sha256_sum: -# logger.warning( -# f'SHA256 mismatch: expect {sha256_sum[:8]}..., got {actual[:8]}...' -# ) -# return False - - -# return True -class Test_verify_file_checksum(unittest.TestCase): - # test pass two value - # test pass two None - # test pass one value correct - # test pass one value incorrect + def test_size_mismatch(self, mock_conn): + download_data = b'data' + mock_instance = MagicMock() + mock_instance.read_stream.return_value = [download_data] + mock_instance.get_size.return_value = len(download_data) + 1 - @patch('llm_web_kit.model.resource_utils.download_assets.calc_file_md5') - @patch('llm_web_kit.model.resource_utils.download_assets.calc_file_sha256') - def test_pass_two_value(self, mock_calc_file_sha256, mock_calc_file_md5): - file_path = 'file_path' - md5_sum = 'md5_sum' - sha256_sum = 'sha256_sum' - mock_calc_file_md5.return_value = md5_sum - mock_calc_file_sha256.return_value = sha256_sum - # will raise ModelResourceException - with self.assertRaises(ModelResourceException): - verify_file_checksum(file_path, md5_sum, sha256_sum) + mock_conn.return_value = mock_instance - @patch('llm_web_kit.model.resource_utils.download_assets.calc_file_md5') - @patch('llm_web_kit.model.resource_utils.download_assets.calc_file_sha256') - def test_pass_two_None(self, mock_calc_file_sha256, mock_calc_file_md5): - file_path = 'file_path' - md5_sum = None - sha256_sum = None - # will raise ModelResourceException - with self.assertRaises(ModelResourceException): - verify_file_checksum(file_path, md5_sum, sha256_sum) + with tempfile.TemporaryDirectory() as tmpdir: + target = os.path.join(tmpdir, 'target.file') + with self.assertRaises(ModelResourceException): + download_auto_file_core('http://example.com', target) - @patch('llm_web_kit.model.resource_utils.download_assets.calc_file_md5') - @patch('llm_web_kit.model.resource_utils.download_assets.calc_file_sha256') - def test_pass_one_value_correct(self, mock_calc_file_sha256, mock_calc_file_md5): - file_path = 'file_path' - md5_sum = 'md5_sum' - sha256_sum = None - mock_calc_file_md5.return_value = md5_sum - mock_calc_file_sha256.return_value = None - assert verify_file_checksum(file_path, md5_sum, sha256_sum) is True - @patch('llm_web_kit.model.resource_utils.download_assets.calc_file_md5') - @patch('llm_web_kit.model.resource_utils.download_assets.calc_file_sha256') - def test_pass_one_value_incorrect(self, mock_calc_file_sha256, mock_calc_file_md5): - file_path = 'file_path' - md5_sum = 'md5_sum' - sha256_sum = None - mock_calc_file_md5.return_value = 'wrong_md5' - mock_calc_file_sha256.return_value = None - assert verify_file_checksum(file_path, md5_sum, sha256_sum) is False +class TestDownloadToTemp: + def test_normal_download(self): + mock_conn = MagicMock() + mock_conn.read_stream.return_value = [b'chunk1', b'chunk2'] + mock_progress = MagicMock() -class TestDownloadToTemp(unittest.TestCase): + with tempfile.TemporaryDirectory() as tmpdir: + temp_path = os.path.join(tmpdir, 'temp.file') + download_to_temp(mock_conn, mock_progress, temp_path) - def setUp(self): - self.mock_conn = MagicMock() - self.mock_progress = MagicMock() - - # mock_open - @patch('builtins.open', new_callable=mock_open) - @patch('tempfile.NamedTemporaryFile') - def test_normal_download(self, mock_temp, mock_open_func): - # 模拟下载流数据 - test_data = [b'chunk1', b'chunk2', b'chunk3'] - self.mock_conn.read_stream.return_value = iter(test_data) - - # 配置临时文件mock - mock_temp.return_value.__enter__.return_value.name = '/tmp/fake.tmp' - - result = download_to_temp(self.mock_conn, self.mock_progress) - - mock_open_func.return_value.write.assert_has_calls( - [call(b'chunk1'), call(b'chunk2'), call(b'chunk3')] - ) - # 验证进度条更新 - self.mock_progress.update.assert_has_calls( - [call(6), call(6), call(6)] # 每个chunk的长度是6 - ) - self.assertEqual(result, '/tmp/fake.tmp') - - @patch('builtins.open', new_callable=mock_open) - @patch('tempfile.NamedTemporaryFile') - def test_exception_handling(self, mock_temp, mock_open_func): - # 模拟写入时发生异常 - self.mock_conn.read_stream.return_value = iter([b'data']) - mock_temp.return_value.__enter__.return_value.name = '/tmp/fail.tmp' - - # file_mock = mock_temp.return_value.__enter__.return_value.__enter__.return_value - # file_mock.write.side_effect = IOError("Disk failure") - - mock_open_func.return_value.write.side_effect = IOError('Disk failure') - with self.assertRaises(IOError): - download_to_temp(self.mock_conn, self.mock_progress) + with open(temp_path, 'rb') as f: + assert f.read() == b'chunk1chunk2' + mock_progress.update.assert_has_calls([call(6), call(6)]) def test_empty_chunk_handling(self): - # 测试包含空chunk的情况 - self.mock_conn.read_stream.return_value = iter([b'', b'valid', b'']) + mock_conn = MagicMock() + mock_conn.read_stream.return_value = [b'', b'data', b''] + mock_progress = MagicMock() - with tempfile.NamedTemporaryFile(delete=False) as real_temp: - with patch('tempfile.NamedTemporaryFile') as mock_temp: - mock_temp.return_value.__enter__.return_value.name = real_temp.name - download_to_temp(self.mock_conn, self.mock_progress) + with tempfile.TemporaryDirectory() as tmpdir: + temp_path = os.path.join(tmpdir, 'temp.file') + download_to_temp(mock_conn, mock_progress, temp_path) - # 验证只有有效chunk被写入 - with open(real_temp.name, 'rb') as f: - self.assertEqual(f.read(), b'valid') - os.unlink(real_temp.name) + with open(temp_path, 'rb') as f: + assert f.read() == b'data' -class TestMoveToTarget(unittest.TestCase): +class TestVerifyChecksum(unittest.TestCase): def setUp(self): - self.tmp_dir = tempfile.TemporaryDirectory() - self.target_path = os.path.join(self.tmp_dir.name, 'subdir/target.file') + self.temp_file = tempfile.NamedTemporaryFile() + self.temp_file.write(b'test data') + self.temp_file.flush() def tearDown(self): - self.tmp_dir.cleanup() + self.temp_file.close() - def test_normal_move(self): - # 创建测试文件 - tmp_path = os.path.join(self.tmp_dir.name, 'test.tmp') - with open(tmp_path, 'wb') as f: - f.write(b'test content') - - move_to_target(tmp_path, self.target_path, 12) - - # 验证结果 - self.assertTrue(os.path.exists(self.target_path)) - self.assertFalse(os.path.exists(tmp_path)) - self.assertEqual(os.path.getsize(self.target_path), 12) + @patch('llm_web_kit.model.resource_utils.download_assets.calc_file_md5') + def test_valid_md5(self, mock_md5): + mock_md5.return_value = 'correct_md5' + assert verify_file_checksum(self.temp_file.name, md5_sum='correct_md5') is True - def test_size_mismatch(self): - tmp_path = os.path.join(self.tmp_dir.name, 'bad.tmp') - with open(tmp_path, 'wb') as f: - f.write(b'short') + @patch('llm_web_kit.model.resource_utils.download_assets.calc_file_sha256') + def test_invalid_sha256(self, mock_sha): + mock_sha.return_value = 'wrong_sha' + assert verify_file_checksum(self.temp_file.name, sha256_sum='correct_sha') is False - with self.assertRaisesRegex(ModelResourceException, 'size mismatch'): - move_to_target(tmp_path, self.target_path, 100) + def test_no_such_file(self): + assert verify_file_checksum('dummy', md5_sum='a') is False - def test_directory_creation(self): - tmp_path = os.path.join(self.tmp_dir.name, 'test.tmp') - with open(tmp_path, 'wb') as f: - f.write(b'content') + def test_invalid_arguments(self): + with self.assertRaises(ModelResourceException): + verify_file_checksum('dummy', md5_sum='a', sha256_sum='b') - # 目标目录不存在 - deep_path = os.path.join(self.tmp_dir.name, 'a/b/c/target.file') - move_to_target(tmp_path, deep_path, 7) - self.assertTrue(os.path.exists(deep_path)) +class TestDownloadAutoFile(unittest.TestCase): + @patch( + 'llm_web_kit.model.resource_utils.download_assets.process_and_verify_file_with_lock' + ) + @patch('llm_web_kit.model.resource_utils.download_assets.verify_file_checksum') + @patch('llm_web_kit.model.resource_utils.download_assets.download_auto_file_core') + def test_download(self, mock_download, mock_verify, mock_process): + def download_func(resource_path, target_path): + dir = os.path.dirname(target_path) + os.makedirs(dir, exist_ok=True) + with open(target_path, 'w') as f: + time.sleep(1) + f.write(resource_path) + + mock_download.side_effect = download_func + + def verify_func(target_path, md5 ,sha): + with open(target_path, 'r') as f: + return f.read() == md5 + + mock_verify.side_effect = verify_func + + def process_and_verify( + process_func, verify_func, target_path, lock_suffix, timeout + ): + process_func() + if verify_func(): + return target_path + + mock_process.side_effect = process_and_verify + with tempfile.TemporaryDirectory() as tmpdir: + resource_url = 'http://example.com/resource' + target_dir = os.path.join(tmpdir, 'target') + result = download_auto_file(resource_url, target_dir, md5_sum=resource_url) + assert result == os.path.join(tmpdir, 'target') if __name__ == '__main__': diff --git a/tests/llm_web_kit/model/resource_utils/test_process_with_lock.py b/tests/llm_web_kit/model/resource_utils/test_process_with_lock.py new file mode 100644 index 00000000..69b24db8 --- /dev/null +++ b/tests/llm_web_kit/model/resource_utils/test_process_with_lock.py @@ -0,0 +1,310 @@ +import multiprocessing +import os +import shutil +import tempfile +import time +import unittest +from functools import partial +from unittest.mock import Mock, patch + +from filelock import Timeout + +from llm_web_kit.model.resource_utils.process_with_lock import ( + get_path_mtime, process_and_verify_file_with_lock) + + +class TestGetPathMtime(unittest.TestCase): + """测试 get_path_mtime 函数.""" + + def setUp(self): + self.test_dir = 'test_dir' + self.test_file = 'test_file.txt' + os.makedirs(self.test_dir, exist_ok=True) + with open(self.test_file, 'w') as f: + f.write('test') + + def tearDown(self): + if os.path.exists(self.test_file): + os.remove(self.test_file) + if os.path.exists(self.test_dir): + shutil.rmtree(self.test_dir) + + def test_file_mtime(self): + # 测试文件路径 + expected_mtime = os.path.getmtime(self.test_file) + result = get_path_mtime(self.test_file) + self.assertEqual(result, expected_mtime) + + def test_dir_with_files(self): + # 测试包含文件的目录 + file1 = os.path.join(self.test_dir, 'file1.txt') + file2 = os.path.join(self.test_dir, 'file2.txt') + + with open(file1, 'w') as f: + f.write('test1') + time.sleep(0.1) # 确保mtime不同 + with open(file2, 'w') as f: + f.write('test2') + + latest_mtime = max(os.path.getmtime(file1), os.path.getmtime(file2)) + result = get_path_mtime(self.test_dir) + self.assertEqual(result, latest_mtime) + + def test_empty_dir(self): + # 测试空目录(预期返回0) + empty_dir = 'empty_dir' + os.makedirs(empty_dir, exist_ok=True) + try: + result = get_path_mtime(empty_dir) + self.assertEqual(result, None) # 根据当前函数逻辑返回0 + finally: + shutil.rmtree(empty_dir) + + +class TestProcessAndVerifyFileWithLock(unittest.TestCase): + """测试 process_and_verify_file_with_lock 函数.""" + + def setUp(self): + self.target_path = 'target.txt' + self.lock_path = self.target_path + '.lock' + + def tearDown(self): + if os.path.exists(self.target_path): + os.remove(self.target_path) + if os.path.exists(self.lock_path): + os.remove(self.lock_path) + + @patch('os.path.exists') + @patch('llm_web_kit.model.resource_utils.process_with_lock.try_remove') + def test_target_exists_and_valid(self, mock_remove, mock_exists): + # 目标存在且验证成功 + mock_exists.side_effect = lambda path: path == self.target_path + process_func = Mock() + verify_func = Mock(return_value=True) + + result = process_and_verify_file_with_lock( + process_func, verify_func, self.target_path + ) + + self.assertEqual(result, self.target_path) + process_func.assert_not_called() + verify_func.assert_called_once() + + @patch('os.path.exists') + @patch('llm_web_kit.model.resource_utils.process_with_lock.try_remove') + @patch('time.sleep') + def test_target_not_exists_acquire_lock_success( + self, mock_sleep, mock_remove, mock_exists + ): + # 目标不存在,成功获取锁 + mock_exists.side_effect = lambda path: False + process_func = Mock(return_value=self.target_path) + verify_func = Mock() + + result = process_and_verify_file_with_lock( + process_func, verify_func, self.target_path + ) + + process_func.assert_called_once() + self.assertEqual(result, self.target_path) + + @patch('os.path.exists') + @patch('llm_web_kit.model.resource_utils.process_with_lock.try_remove') + @patch('time.sleep') + def test_second_validation_after_lock(self, mock_sleep, mock_remove, mock_exists): + # 获取锁后二次验证成功(其他进程已完成) + mock_exists.side_effect = lambda path: { + self.lock_path: False, + self.target_path: True, + } + verify_func = Mock(return_value=True) + process_func = Mock() + + result = process_and_verify_file_with_lock( + process_func, verify_func, self.target_path + ) + + process_func.assert_not_called() + self.assertEqual(result, self.target_path) + + @patch('os.path.exists') + @patch('llm_web_kit.model.resource_utils.process_with_lock.SoftFileLock') + @patch('time.sleep') + def test_lock_timeout_retry_success(self, mock_sleep, mock_lock, mock_exists): + # 第一次获取锁超时,重试后成功 + lock_str = self.target_path + '.lock' + mock_exists.return_value = False + lock_instance = Mock() + mock_lock.return_value = lock_instance + + # 第一次acquire抛出Timeout,第二次成功 + lock_instance.acquire.side_effect = [Timeout(lock_str), None] + process_func = Mock(return_value=self.target_path) + verify_func = Mock() + + process_and_verify_file_with_lock(process_func, verify_func, self.target_path) + + self.assertEqual(lock_instance.acquire.call_count, 2) + process_func.assert_called_once() + + +class TestProcessWithLockRealFiles(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + self.target_name = 'test_target.dat' + self.lock_suffix = '.lock' + + self.target_path = os.path.join(self.temp_dir.name, self.target_name) + self.lock_path = self.target_path + self.lock_suffix + + def tearDown(self): + self.temp_dir.cleanup() + + def test_zombie_process_recovery(self): + # 准备过期文件和僵尸锁文件 + with open(self.target_path, 'w') as f: + f.write('old content') + with open(self.lock_path, 'w') as f: + f.write('lock') + + # 设置文件修改时间为超时前(60秒超时,设置为2分钟前) + old_mtime = time.time() - 59 + os.utime(self.target_path, (old_mtime, old_mtime)) + + # Mock验证函数和处理函数 + def verify_func(): + # 验证文件内容 + with open(self.target_path) as f: + content = f.read() + return content == 'new content' + + process_called = [False] # 使用list实现nonlocal效果 + + def real_process(): + # 真实写入文件 + with open(self.target_path, 'w') as f: + f.write('new content') + process_called[0] = True + return self.target_path + + # 执行测试 + result = process_and_verify_file_with_lock( + process_func=real_process, + verify_func=verify_func, + target_path=self.target_path, + lock_suffix=self.lock_suffix, + timeout=60, + ) + + # 验证结果 + self.assertTrue(os.path.exists(self.target_path)) + self.assertFalse(os.path.exists(self.lock_path)) + self.assertTrue(process_called[0]) + self.assertEqual(result, self.target_path) + + # 验证文件内容 + with open(self.target_path) as f: + content = f.read() + self.assertEqual(content, 'new content') + + +def dummy_process_func(target_path, data_content): + # 写入文件 + with open(target_path, 'w') as f: + time.sleep(1) + f.write(data_content) + return target_path + + +def dummy_verify_func(target_path, data_content): + # 验证文件内容 + try: + with open(target_path) as f: + content = f.read() + except FileNotFoundError: + return False + return content == data_content + + +class TestMultiProcessWithLock(unittest.TestCase): + + def setUp(self): + # 临时文件夹 + self.temp_dir = tempfile.TemporaryDirectory() + self.target_name = 'test_target.dat' + self.lock_suffix = '.lock' + self.data_content = 'test content' + self.target_path = os.path.join(self.temp_dir.name, self.target_name) + self.lock_path = self.target_path + self.lock_suffix + + def tearDown(self): + self.temp_dir.cleanup() + + # 开始时什么都没有,多个进程尝试拿锁,一个进程拿到锁并用1s写入一个资源文件(指定内容,verify尝试检查)。然后所有的进程都发现这个文件被写入,成功返回。资源文件存在,并且锁文件被删掉 + def test_multi_process_with_lock(self): + process_func = partial(dummy_process_func, self.target_path, self.data_content) + verify_func = partial(dummy_verify_func, self.target_path, self.data_content) + + # 多进程同时执行 + # 构建多个进程 然后同时执行 + pool = multiprocessing.Pool(16) + process = partial( + process_and_verify_file_with_lock, + process_func, + verify_func, + self.target_path, + self.lock_suffix, + ) + results = pool.map(process, [60] * 32) + pool.close() + # 检查文件是否存在 + self.assertTrue(os.path.exists(self.target_path)) + self.assertFalse(os.path.exists(self.lock_path)) + # 检查结果 + for result in results: + self.assertEqual(result, self.target_path) + # 检查文件内容 + with open(self.target_path) as f: + content = f.read() + self.assertEqual(content, self.data_content) + + # 开始时有个文件,这个文件mtime比较早,且有个锁文件(模拟之前下载失败了)。然后同样多进程尝试执行,有某个进程尝试删掉这个文件和锁文件,然后还原为场景1,最终大家成功返回。 + def test_multi_process_with_zombie_files(self): + # 准备过期文件和僵尸锁文件 + with open(self.target_path, 'w') as f: + f.write('old content') + with open(self.lock_path, 'w') as f: + f.write('lock') + + # 设置文件修改时间为超时前(60秒超时,设置为2分钟前) + old_mtime = time.time() - 59 + os.utime(self.target_path, (old_mtime, old_mtime)) + + process_func = partial(dummy_process_func, self.target_path, self.data_content) + verify_func = partial(dummy_verify_func, self.target_path, self.data_content) + + # 多进程同时执行 + pool = multiprocessing.Pool(16) + process = partial( + process_and_verify_file_with_lock, + process_func, + verify_func, + self.target_path, + self.lock_suffix, + ) + results = pool.map(process, [60] * 32) + pool.close() + # 检查文件是否存在 + self.assertTrue(os.path.exists(self.target_path)) + self.assertFalse(os.path.exists(self.lock_path)) + # 检查结果 + for result in results: + self.assertEqual(result, self.target_path) + # 检查文件内容 + with open(self.target_path) as f: + content = f.read() + self.assertEqual(content, self.data_content) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/llm_web_kit/model/resource_utils/test_resource_utils.py b/tests/llm_web_kit/model/resource_utils/test_resource_utils.py index 15448bcf..0206e631 100644 --- a/tests/llm_web_kit/model/resource_utils/test_resource_utils.py +++ b/tests/llm_web_kit/model/resource_utils/test_resource_utils.py @@ -1,9 +1,7 @@ -import errno import os -import unittest -from unittest.mock import MagicMock, mock_open, patch +from unittest.mock import patch -from llm_web_kit.model.resource_utils.utils import FileLockContext, try_remove +from llm_web_kit.model.resource_utils.utils import decide_cache_dir, try_remove class Test_try_remove: @@ -20,123 +18,42 @@ def test_remove_exception(self, removeMock): removeMock.assert_called_once_with('path') -class TestFileLock(unittest.TestCase): - - def setUp(self): - self.lock_path = 'test.lock' - - @patch('os.fdopen') - @patch('os.open') - @patch('os.close') - @patch('os.remove') - def test_acquire_and_release_lock( - self, mock_remove, mock_close, mock_open, mock_os_fdopen - ): - # 模拟成功获取锁 - mock_open.return_value = 123 # 假设文件描述符为123 - # 模拟文件描述符 - mock_fd = MagicMock() - mock_fd.__enter__.return_value = mock_fd - mock_fd.write.return_value = None - mock_os_fdopen.return_value = mock_fd - - with FileLockContext(self.lock_path): - mock_open.assert_called_once_with( - self.lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o644 - ) - mock_close.assert_called_once_with(123) - mock_remove.assert_called_once_with(self.lock_path) - - @patch('os.fdopen') - @patch('os.open') - @patch('builtins.open', new_callable=mock_open, read_data='1234\n100') - @patch('time.time') - @patch('os.remove') - def test_remove_stale_lock( - self, mock_remove, mock_time, mock_file_open, mock_os_open, mock_os_fdopen - ): - # 第一次尝试创建锁文件失败(锁已存在) - mock_os_open.side_effect = [ - OSError(errno.EEXIST, 'File exists'), - 123, # 第二次成功 - ] - - # 模拟文件描述符 - mock_fd = MagicMock() - mock_fd.__enter__.return_value = mock_fd - mock_fd.write.return_value = None - mock_os_fdopen.return_value = mock_fd - - # 当前时间设置为超过超时时间(timeout=300) - mock_time.return_value = 401 # 100 + 300 + 1 - - with FileLockContext(self.lock_path, timeout=300): - mock_remove.assert_called_once_with(self.lock_path) - mock_os_open.assert_any_call( - self.lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o644 - ) - - @patch('os.open') - @patch('time.time') - def test_timeout_acquiring_lock(self, mock_time, mock_os_open): - # 总是返回EEXIST错误 - mock_os_open.side_effect = OSError(errno.EEXIST, 'File exists') - # 时间累计超过超时时间 - start_time = 1000 - mock_time.side_effect = [ - start_time, - start_time + 301, - start_time + 302, - start_time + 303, - ] - - with self.assertRaises(TimeoutError): - with FileLockContext(self.lock_path, timeout=300): - pass - - @patch('os.open') - def test_other_os_error(self, mock_os_open): - # 模拟其他OS错误(如权限不足) - mock_os_open.side_effect = OSError(errno.EACCES, 'Permission denied') - with self.assertRaises(OSError): - with FileLockContext(self.lock_path): - pass - - @patch('os.close') - @patch('os.remove') - def test_cleanup_on_exit(self, mock_remove, mock_close): - - mock_close.side_effect = None - # 确保退出上下文时执行清理 - lock_path = 'test.lock' - lock = FileLockContext(lock_path) - lock._fd = 123 # 模拟已打开的文件描述符 - lock.__exit__('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!', None, None) - mock_remove.assert_called_once_with(self.lock_path) - - @patch('os.remove') - def test_cleanup_failure_handled(self, mock_remove): - # 模拟删除锁文件时失败 - mock_remove.side_effect = OSError - lock = FileLockContext(self.lock_path) - lock._fd = 123 - # 不应抛出异常 - lock.__exit__(None, None, None) - - @patch('os.getpid') - @patch('time.time') - def test_lock_file_content(self, mock_time, mock_pid): - # 验证锁文件内容格式 - mock_pid.return_value = 9999 - mock_time.return_value = 123456.789 - - with patch('os.open') as mock_os_open: - mock_os_open.return_value = 123 - with patch('os.fdopen') as mock_fdopen: - # 模拟写入文件描述符 - mock_file = MagicMock() - mock_fdopen.return_value.__enter__.return_value = mock_file - - with FileLockContext(self.lock_path): - mock_fdopen.assert_called_once_with(123, 'w') - mock_file.write.assert_called_once_with('9999\n123456.789') +class TestDecideCacheDir: + + @patch('os.environ', {'WEB_KIT_CACHE_DIR': '/env/cache_dir'}) + @patch('llm_web_kit.model.resource_utils.utils.load_config') + def test_only_env(self, get_config_mock): + get_config_mock.side_effect = Exception + cache_dir, cache_tmp_dir = decide_cache_dir() + assert cache_dir == '/env/cache_dir' + assert cache_tmp_dir == '/env/cache_dir/tmp' + + @patch('os.environ', {}) + @patch('llm_web_kit.model.resource_utils.utils.load_config') + def test_only_config(self, get_config_mock): + get_config_mock.return_value = { + 'resources': {'common': {'cache_path': '/config/cache_dir'}} + } + + cache_dir, cache_tmp_dir = decide_cache_dir() + assert cache_dir == '/config/cache_dir' + assert cache_tmp_dir == '/config/cache_dir/tmp' + + @patch('os.environ', {}) + @patch('llm_web_kit.model.resource_utils.utils.load_config') + def test_default(self, get_config_mock): + get_config_mock.side_effect = Exception + env_result = os.path.expanduser('~/.llm_web_kit_cache') + cache_dir, cache_tmp_dir = decide_cache_dir() + assert cache_dir == env_result + assert cache_tmp_dir == f'{env_result}/tmp' + + @patch('os.environ', {'WEB_KIT_CACHE_DIR': '/env/cache_dir'}) + @patch('llm_web_kit.model.resource_utils.utils.load_config') + def test_priority(self, get_config_mock): + get_config_mock.return_value = { + 'resources': {'common': {'cache_path': '/config/cache_dir'}} + } + cache_dir, cache_tmp_dir = decide_cache_dir() + assert cache_dir == '/config/cache_dir' + assert cache_tmp_dir == '/config/cache_dir/tmp' diff --git a/tests/llm_web_kit/model/resource_utils/test_unzip_ext.py b/tests/llm_web_kit/model/resource_utils/test_unzip_ext.py index bf514d14..4daeccd1 100644 --- a/tests/llm_web_kit/model/resource_utils/test_unzip_ext.py +++ b/tests/llm_web_kit/model/resource_utils/test_unzip_ext.py @@ -2,122 +2,153 @@ import tempfile import zipfile from unittest import TestCase +from unittest.mock import patch from llm_web_kit.exception.exception import ModelResourceException -from llm_web_kit.model.resource_utils.unzip_ext import (check_zip_file, +from llm_web_kit.model.resource_utils.unzip_ext import (check_zip_path, get_unzip_dir, - unzip_local_file) - - -def get_assert_dir(): - file_path = os.path.abspath(__file__) - assert_dir = os.path.join(os.path.dirname(os.path.dirname(file_path)), 'assets') - return assert_dir + unzip_local_file, + unzip_local_file_core) class TestGetUnzipDir(TestCase): - - def test_get_unzip_dir_case1(self): - assert get_unzip_dir('/path/to/test.zip') == '/path/to/test_unzip' - - def test_get_unzip_dir_case2(self): - assert get_unzip_dir('/path/to/test') == '/path/to/test_unzip' - - -class TestCheckZipFile(TestCase): - # # test_zip/ - # # ├── test1.txt "test1\n" - # # ├── folder1 - # # │ └── test2.txt "test2\n" - # # └── folder2 - - def get_zipfile(self): - # 创建一个临时文件夹 - zip_path = os.path.join(get_assert_dir(), 'zip_demo.zip') - zip_file = zipfile.ZipFile(zip_path, 'r') - return zip_file - - def test_check_zip_file_cese1(self): - zip_file = self.get_zipfile() - # # test_zip/ - # # ├── test1.txt - # # ├── folder1 - # # │ └── test2.txt - # # └── folder2 - - with tempfile.TemporaryDirectory() as temp_dir: - root_dir = os.path.join(temp_dir, 'test_zip') - os.makedirs(os.path.join(root_dir, 'test_zip')) - os.makedirs(os.path.join(root_dir, 'folder1')) - os.makedirs(os.path.join(root_dir, 'folder2')) - with open(os.path.join(root_dir, 'test1.txt'), 'w') as f: - f.write('test1\n') - with open(os.path.join(root_dir, 'folder1', 'test2.txt'), 'w') as f: - f.write('test2\n') - - assert check_zip_file(zip_file, temp_dir) is True - - def test_check_zip_file_cese2(self): - zip_file = self.get_zipfile() - with tempfile.TemporaryDirectory() as temp_dir: - root_dir = os.path.join(temp_dir, 'test_zip') - os.makedirs(os.path.join(root_dir, 'test_zip')) - os.makedirs(os.path.join(root_dir, 'folder1')) - with open(os.path.join(root_dir, 'test1.txt'), 'w') as f: - f.write('test1\n') - with open(os.path.join(root_dir, 'folder1', 'test2.txt'), 'w') as f: - f.write('test2\n') - - assert check_zip_file(zip_file, temp_dir) is True - - def test_check_zip_file_cese3(self): - zip_file = self.get_zipfile() - with tempfile.TemporaryDirectory() as temp_dir: - root_dir = os.path.join(temp_dir, 'test_zip') - os.makedirs(os.path.join(root_dir, 'test_zip')) - os.makedirs(os.path.join(root_dir, 'folder1')) - with open(os.path.join(root_dir, 'folder1', 'test2.txt'), 'w') as f: - f.write('test2\n') - - assert check_zip_file(zip_file, temp_dir) is False - - def test_check_zip_file_cese4(self): - zip_file = self.get_zipfile() - with tempfile.TemporaryDirectory() as temp_dir: - root_dir = os.path.join(temp_dir, 'test_zip') - os.makedirs(os.path.join(root_dir, 'test_zip')) - os.makedirs(os.path.join(root_dir, 'folder1')) - with open(os.path.join(root_dir, 'test1.txt'), 'w') as f: - f.write('test1\n') - with open(os.path.join(root_dir, 'folder1', 'test2.txt'), 'w') as f: - f.write('test123\n') - - assert check_zip_file(zip_file, temp_dir) is False - - -def test_unzip_local_file(): - # creat a temp dir to test - with tempfile.TemporaryDirectory() as temp_dir1, tempfile.TemporaryDirectory() as temp_dir2: - # test unzip a zip file with 2 txt files - zip_path = os.path.join(temp_dir1, 'test.zip') - target_dir = os.path.join(temp_dir2, 'target') - # zip 2 txt files - with zipfile.ZipFile(zip_path, 'w') as zipf: - zipf.writestr('test1.txt', 'This is a test file') - zipf.writestr('test2.txt', 'This is another test file') - - unzip_local_file(zip_path, target_dir) - with open(os.path.join(target_dir, 'test1.txt')) as f: - assert f.read() == 'This is a test file' - with open(os.path.join(target_dir, 'test2.txt')) as f: - assert f.read() == 'This is another test file' - - unzip_local_file(zip_path, target_dir, exist_ok=True) - with open(os.path.join(target_dir, 'test1.txt')) as f: - assert f.read() == 'This is a test file' - with open(os.path.join(target_dir, 'test2.txt')) as f: - assert f.read() == 'This is another test file' - try: - unzip_local_file(zip_path, target_dir, exist_ok=False) - except ModelResourceException as e: - assert e.custom_message == f'Target directory {target_dir} already exists' + def test_get_unzip_dir_with_zip_extension(self): + self.assertEqual( + get_unzip_dir('/path/to/test.zip'), + '/path/to/test_unzip', + ) + + def test_get_unzip_dir_without_zip_extension(self): + self.assertEqual( + get_unzip_dir('/path/to/test'), + '/path/to/test_unzip', + ) + + +class TestCheckZipPath(TestCase): + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + self.zip_path = os.path.join(self.temp_dir.name, 'test.zip') + self.target_dir = os.path.join(self.temp_dir.name, 'target') + os.makedirs(self.target_dir, exist_ok=True) + + def tearDown(self): + self.temp_dir.cleanup() + + def test_check_valid_zip(self): + # Create valid zip with test file + with zipfile.ZipFile(self.zip_path, 'w') as zipf: + zipf.writestr('file.txt', 'content') + # Properly extract files + with zipfile.ZipFile(self.zip_path, 'r') as zip_ref: + zip_ref.extractall(self.target_dir) + + self.assertTrue(check_zip_path(self.zip_path, self.target_dir)) + + def test_check_missing_file(self): + with zipfile.ZipFile(self.zip_path, 'w') as zipf: + zipf.writestr('file.txt', 'content') + # Extract and then delete file + with zipfile.ZipFile(self.zip_path, 'r') as zip_ref: + zip_ref.extractall(self.target_dir) + os.remove(os.path.join(self.target_dir, 'file.txt')) + + self.assertFalse(check_zip_path(self.zip_path, self.target_dir)) + + def test_check_corrupted_file_size(self): + with zipfile.ZipFile(self.zip_path, 'w') as zipf: + zipf.writestr('file.txt', 'original content') + # Modify extracted file + with zipfile.ZipFile(self.zip_path, 'r') as zip_ref: + zip_ref.extractall(self.target_dir) + with open(os.path.join(self.target_dir, 'file.txt'), 'w') as f: + f.write('modified') + + self.assertFalse(check_zip_path(self.zip_path, self.target_dir)) + + def test_password_protected_zip(self): + password = 'secret' + # Create encrypted zip + with zipfile.ZipFile(self.zip_path, 'w') as zipf: + zipf.writestr('file.txt', 'content') + zipf.setpassword(password.encode()) + # Extract with correct password + with zipfile.ZipFile(self.zip_path, 'r') as zip_ref: + zip_ref.setpassword(password.encode()) + zip_ref.extractall(self.target_dir) + + self.assertTrue( + check_zip_path(self.zip_path, self.target_dir, password=password) + ) + + +class TestUnzipLocalFileCore(TestCase): + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + self.zip_path = os.path.join(self.temp_dir.name, 'test.zip') + self.target_dir = os.path.join(self.temp_dir.name, 'target') + + def tearDown(self): + self.temp_dir.cleanup() + + def test_nonexistent_zip_file(self): + with self.assertRaises(ModelResourceException) as cm: + unzip_local_file_core('invalid.zip', self.target_dir) + self.assertIn('does not exist', str(cm.exception)) + + def test_target_directory_conflict(self): + # Create target directory first + os.makedirs(self.target_dir) + with zipfile.ZipFile(self.zip_path, 'w') as zipf: + zipf.writestr('file.txt', 'content') + + with self.assertRaises(ModelResourceException) as cm: + unzip_local_file_core(self.zip_path, self.target_dir) + self.assertIn('already exists', str(cm.exception)) + + def test_successful_extraction(self): + with zipfile.ZipFile(self.zip_path, 'w') as zipf: + zipf.writestr('file.txt', 'content') + + result = unzip_local_file_core(self.zip_path, self.target_dir) + self.assertEqual(result, self.target_dir) + self.assertTrue(os.path.exists(os.path.join(self.target_dir, 'file.txt'))) + + def test_password_protected_extraction(self): + password = 'secret' + with zipfile.ZipFile(self.zip_path, 'w') as zipf: + zipf.writestr('file.txt', 'content') + zipf.setpassword(password.encode()) + + unzip_local_file_core(self.zip_path, self.target_dir, password=password) + self.assertTrue(os.path.exists(os.path.join(self.target_dir, 'file.txt'))) + + +class TestUnzipLocalFile(TestCase): + + def setUp(self): + self.temp_dir = tempfile.TemporaryDirectory() + self.zip_path = os.path.join(self.temp_dir.name, 'test.zip') + self.target_dir = os.path.join(self.temp_dir.name, 'target') + with zipfile.ZipFile(self.zip_path, 'w') as zipf: + zipf.writestr('file.txt', 'content') + + def tearDown(self): + self.temp_dir.cleanup() + + @patch( + 'llm_web_kit.model.resource_utils.unzip_ext.process_and_verify_file_with_lock' + ) + def test_unzip(self, mock_process): + + def process_and_verify( + process_func, verify_func, target_path, lock_suffix, timeout + ): + process_func() + if verify_func(): + return target_path + + mock_process.side_effect = process_and_verify + result = unzip_local_file(self.zip_path, self.target_dir) + self.assertEqual(result, self.target_dir) + self.assertTrue(os.path.exists(os.path.join(self.target_dir, 'file.txt'))) diff --git a/tests/llm_web_kit/model/test_political.py b/tests/llm_web_kit/model/test_political.py index 3be90563..54cc3437 100644 --- a/tests/llm_web_kit/model/test_political.py +++ b/tests/llm_web_kit/model/test_political.py @@ -21,7 +21,7 @@ class TestPoliticalDetector: - @patch('llm_web_kit.model.policical.AutoTokenizer.from_pretrained') + @patch('transformers.AutoTokenizer.from_pretrained') @patch('llm_web_kit.model.policical.fasttext.load_model') @patch('llm_web_kit.model.policical.PoliticalDetector.auto_download') def test_init(self, mock_auto_download, mock_load_model, mock_auto_tokenizer): @@ -46,7 +46,7 @@ def test_init(self, mock_auto_download, mock_load_model, mock_auto_tokenizer): trust_remote_code=True, ) - @patch('llm_web_kit.model.policical.AutoTokenizer.from_pretrained') + @patch('transformers.AutoTokenizer.from_pretrained') @patch('llm_web_kit.model.policical.fasttext.load_model') @patch('llm_web_kit.model.policical.PoliticalDetector.auto_download') def test_predict(self, mock_auto_download, mock_load_model, mock_auto_tokenizer): diff --git a/tests/llm_web_kit/model/test_quality_model.py b/tests/llm_web_kit/model/test_quality_model.py index 4a879eec..4172be95 100644 --- a/tests/llm_web_kit/model/test_quality_model.py +++ b/tests/llm_web_kit/model/test_quality_model.py @@ -9,8 +9,7 @@ from llm_web_kit.model.quality_model import get_quality_model # noqa: E402 from llm_web_kit.model.quality_model import quality_prober # noqa: E402 from llm_web_kit.model.quality_model import QualityFilter -from llm_web_kit.model.resource_utils.download_assets import \ - CACHE_DIR # noqa: E402 +from llm_web_kit.model.resource_utils.utils import CACHE_DIR current_file_path = os.path.abspath(__file__) parent_dir_path = os.path.join(current_file_path, *[os.pardir] * 4)