diff --git a/llm_web_kit/config/cfg_reader.py b/llm_web_kit/config/cfg_reader.py index 652eaad0..5b648f76 100644 --- a/llm_web_kit/config/cfg_reader.py +++ b/llm_web_kit/config/cfg_reader.py @@ -2,15 +2,22 @@ import commentjson as json +from llm_web_kit.exception.exception import ModelResourceException + def load_config() -> dict: - """_summary_ + """Load the configuration file for the web kit. First try to read the + configuration file from the environment variable LLM_WEB_KIT_CFG_PATH. If + the environment variable is not set, use the default configuration file + path ~/.llm-web-kit.jsonc. If the configuration file does not exist, raise + an exception. - Args: - config_file (_type_): _description_ + Raises: + ModelResourceException: LLM_WEB_KIT_CFG_PATH points to a non-exist file + ModelResourceException: cfg_path does not exist Returns: - _type_: _description_ + config(dict): The configuration dictionary """ # 首先从环境变量LLM_WEB_KIT_CFG_PATH 读取配置文件的位置 # 如果没有配置,就使用默认的配置文件位置 @@ -19,12 +26,15 @@ def load_config() -> dict: if env_cfg_path: cfg_path = env_cfg_path if not os.path.exists(cfg_path): - raise FileNotFoundError(f'environment variable LLM_WEB_KIT_CFG_PATH points to a non-exist file: {cfg_path}') + raise ModelResourceException( + f'environment variable LLM_WEB_KIT_CFG_PATH points to a non-exist file: {cfg_path}' + ) else: cfg_path = os.path.expanduser('~/.llm-web-kit.jsonc') if not os.path.exists(cfg_path): - raise FileNotFoundError( - f'{cfg_path} does not exist, please create one or set environment variable LLM_WEB_KIT_CFG_PATH to a valid file path') + raise ModelResourceException( + f'{cfg_path} does not exist, please create one or set environment variable LLM_WEB_KIT_CFG_PATH to a valid file path' + ) # 读取配置文件 with open(cfg_path, 'r', encoding='utf-8') as f: diff --git a/llm_web_kit/model/resource_utils/boto3_ext.py b/llm_web_kit/model/resource_utils/boto3_ext.py index 11341568..ac75fb8b 100644 --- a/llm_web_kit/model/resource_utils/boto3_ext.py +++ b/llm_web_kit/model/resource_utils/boto3_ext.py @@ -1,20 +1,37 @@ import re -from typing import Dict, List, Union +from typing import Dict, List, Tuple, Union import boto3 from botocore.config import Config from botocore.exceptions import ClientError from llm_web_kit.config.cfg_reader import load_config +from llm_web_kit.exception.exception import ModelResourceException __re_s3_path = re.compile('^s3://([^/]+)(?:/(.*))?$') def is_s3_path(path: str) -> bool: + """check a path is s3 path or not. + + Args: + path (str): path + + Returns: + bool: is s3 path or not + """ return path.startswith('s3://') -def is_s3_404_error(e: Exception): +def is_s3_404_error(e: Exception) -> bool: + """check if an exception is 404 error. + + Args: + e (Exception): exception + + Returns: + bool: is 404 error or not + """ if not isinstance(e, ClientError): return False flag_1 = e.response.get('Error', {}).get('Code') in ['404', 'NoSuchKey'] @@ -23,22 +40,35 @@ def is_s3_404_error(e: Exception): return any([flag_1, flag_2, flag_3]) -def split_s3_path(path: str): - """split bucket and key from path.""" +def split_s3_path(path: str) -> Tuple[str, str]: + """split bucket and key from path. + + Args: + path (str): s3 path + + Returns: + Tuple[str, str]: bucket and key + + Raises: + ModelResourceException: if path is not s3 path + """ + if not is_s3_path(path): + raise ModelResourceException(f'{path} is not a s3 path') m = __re_s3_path.match(path) if m is None: return '', '' return m.group(1), (m.group(2) or '') -def get_s3_config(path: str): +def get_s3_config(path: str) -> Dict: """Get s3 config for a given path by its bucket name from the config file. Args: path (str): s3 path Raises: - ValueError: if bucket not found in config + ModelResourceException: if bucket not in config + ModelResourceException: if path is not s3 path Returns: dict: s3 config @@ -48,10 +78,23 @@ def get_s3_config(path: str): if bucket in config_dict['s3']: return config_dict['s3'][bucket] else: - raise ValueError(f'bucket {bucket} not in config') + raise ModelResourceException(f'bucket {bucket} not in config') + + +def get_s3_client(path: Union[str, List[str]]) -> boto3.client: + """Get s3 client for a given path. + + Args: + path (Union[str, List[str]]): s3 path + + Returns: + boto3.client: s3 client + Raises: + ModelResourceException: if bucket not in config + ModelResourceException: if path is not s3 path + """ -def get_s3_client(path: Union[str, List[str]]): s3_config = get_s3_config(path) try: return boto3.client( @@ -61,10 +104,7 @@ def get_s3_client(path: Union[str, List[str]]): endpoint_url=s3_config['endpoint'], config=Config( s3={'addressing_style': s3_config.get('addressing_style', 'path')}, - retries={ - 'max_attempts': 8, - 'mode': 'standard' - }, + retries={'max_attempts': 8, 'mode': 'standard'}, connect_timeout=600, read_timeout=600, ), @@ -84,6 +124,21 @@ def get_s3_client(path: Union[str, List[str]]): def head_s3_object(client, path: str, raise_404=False) -> Union[Dict, None]: + """Get s3 object metadata. + + Args: + client (boto3.client): the s3 client + path (str): the s3 path + raise_404 (bool, optional): raise 404 error or not. Defaults to False. + + Returns: + Union[Dict, None]: s3 object metadata or None if not found + + Raises: + ClientError: if raise_404 is True and object not + ModelResourceException: if path is not s3 path + ModelResourceException: if bucket not in config + """ bucket, key = split_s3_path(path) try: resp = client.head_object(Bucket=bucket, Key=key) diff --git a/llm_web_kit/model/resource_utils/download_assets.py b/llm_web_kit/model/resource_utils/download_assets.py index efa07619..ab7cbead 100644 --- a/llm_web_kit/model/resource_utils/download_assets.py +++ b/llm_web_kit/model/resource_utils/download_assets.py @@ -1,28 +1,19 @@ -import errno import hashlib import os import shutil import tempfile -import time from typing import Iterable, Optional import requests from tqdm import tqdm from llm_web_kit.config.cfg_reader import load_config -from llm_web_kit.exception.exception import ModelInputException +from llm_web_kit.exception.exception import ModelResourceException from llm_web_kit.libs.logger import mylogger as logger from llm_web_kit.model.resource_utils.boto3_ext import (get_s3_client, is_s3_path, split_s3_path) - - -def try_remove(path: str): - """Attempt to remove a file, but ignore any exceptions that occur.""" - try: - os.remove(path) - except Exception: - pass +from llm_web_kit.model.resource_utils.utils import FileLockContext, try_remove def decide_cache_dir(): @@ -114,59 +105,12 @@ def __del__(self): self.response.close() -class FileLock: - """基于文件锁的上下文管理器(跨平台兼容版)""" - - def __init__(self, lock_path: str, timeout: float = 300): - self.lock_path = lock_path - self.timeout = timeout - self._fd = None - - def __enter__(self): - start_time = time.time() - while True: - try: - # 原子性创建锁文件(O_EXCL标志是关键) - self._fd = os.open( - self.lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o644 - ) - # 写入进程信息和时间戳 - with os.fdopen(self._fd, 'w') as f: - f.write(f'{os.getpid()}\n{time.time()}') - return self - except OSError as e: - if e.errno != errno.EEXIST: - raise - - # 检查锁是否过期 - try: - with open(self.lock_path, 'r') as f: - pid, timestamp = f.read().split('\n')[:2] - if time.time() - float(timestamp) > self.timeout: - os.remove(self.lock_path) - except (FileNotFoundError, ValueError): - pass - - if time.time() - start_time > self.timeout: - raise TimeoutError(f'Could not acquire lock after {self.timeout}s') - time.sleep(0.1) - - def __exit__(self, exc_type, exc_val, exc_tb): - try: - if self._fd: - os.close(self._fd) - except OSError: - pass - finally: - try_remove(self.lock_path) - - def verify_file_checksum( file_path: str, md5_sum: Optional[str] = None, sha256_sum: Optional[str] = None ) -> bool: """校验文件哈希值.""" if not sum([bool(md5_sum), bool(sha256_sum)]) == 1: - raise ModelInputException( + raise ModelResourceException( 'Exactly one of md5_sum or sha256_sum must be provided' ) @@ -210,7 +154,7 @@ def download_to_temp(conn, progress_bar) -> str: def move_to_target(tmp_path: str, target_path: str, expected_size: int): """移动文件并验证.""" if os.path.getsize(tmp_path) != expected_size: - raise ValueError( + raise ModelResourceException( f'File size mismatch: {os.path.getsize(tmp_path)} vs {expected_size}' ) @@ -218,7 +162,7 @@ def move_to_target(tmp_path: str, target_path: str, expected_size: int): shutil.move(tmp_path, target_path) # 原子操作替换 if not os.path.exists(target_path): - raise RuntimeError(f'Move failed: {tmp_path} -> {target_path}') + raise ModelResourceException(f'Move failed: {tmp_path} -> {target_path}') def download_auto_file( @@ -254,20 +198,29 @@ def download_auto_file( """线程安全的文件下载函数""" lock_path = f'{target_path}.lock' - with FileLock(lock_path, timeout=lock_timeout): - # 二次检查(其他进程可能已经完成下载) - if os.path.exists(target_path): - if verify_file_checksum(target_path, md5_sum, sha256_sum): - logger.info(f'File already exists with valid checksum: {target_path}') - return target_path - - if not exist_ok: - raise FileExistsError( - f'File exists with invalid checksum: {target_path}' - ) + def check_callback(): + return verify_file_checksum(target_path, md5_sum, sha256_sum) + + if os.path.exists(target_path): + if not exist_ok: + raise ModelResourceException( + f'File exists with invalid checksum: {target_path}' + ) + + if verify_file_checksum(target_path, md5_sum, sha256_sum): + logger.info(f'File already exists with valid checksum: {target_path}') + return target_path + else: logger.warning(f'Removing invalid file: {target_path}') try_remove(target_path) + with FileLockContext(lock_path, check_callback, timeout=lock_timeout) as lock: + if lock is True: + logger.info( + f'File already exists with valid checksum: {target_path} while waiting' + ) + return target_path + # 创建连接 conn_cls = S3Connection if is_s3_path(resource_path) else HttpConnection conn = conn_cls(resource_path) diff --git a/llm_web_kit/model/resource_utils/singleton_resource_manager.py b/llm_web_kit/model/resource_utils/singleton_resource_manager.py index 24849131..84ddcc53 100644 --- a/llm_web_kit/model/resource_utils/singleton_resource_manager.py +++ b/llm_web_kit/model/resource_utils/singleton_resource_manager.py @@ -1,3 +1,6 @@ +from llm_web_kit.exception.exception import ModelResourceException + + class SingletonResourceManager: def __init__(self): @@ -8,9 +11,11 @@ def has_name(self, name): def set_resource(self, name: str, resource): if not isinstance(name, str): - raise TypeError('name should be a string') + raise ModelResourceException( + f'Name should be a string, but got {type(name)}' + ) if name in self.resources: - raise AssertionError(f'Resource {name} already exists') + raise ModelResourceException(f'Resource {name} already exists') self.resources[name] = resource @@ -18,7 +23,7 @@ def get_resource(self, name): if name in self.resources: return self.resources[name] else: - raise Exception(f'Resource {name} does not exist') + raise ModelResourceException(f'Resource {name} does not exist') def release_resource(self, name): if name in self.resources: diff --git a/llm_web_kit/model/resource_utils/unzip_ext.py b/llm_web_kit/model/resource_utils/unzip_ext.py index ca1a70b2..6579fd62 100644 --- a/llm_web_kit/model/resource_utils/unzip_ext.py +++ b/llm_web_kit/model/resource_utils/unzip_ext.py @@ -4,7 +4,9 @@ import zipfile from typing import Optional -from llm_web_kit.model.resource_utils.download_assets import FileLock +from llm_web_kit.exception.exception import ModelResourceException +from llm_web_kit.libs.logger import mylogger as logger +from llm_web_kit.model.resource_utils.utils import FileLockContext, try_remove def get_unzip_dir(zip_path: str) -> str: @@ -23,11 +25,32 @@ def get_unzip_dir(zip_path: str) -> str: return os.path.join(zip_dir, base_name + '_unzip') +def check_zip_file(zip_ref: zipfile.ZipFile, target_dir: str) -> bool: + """Check if the zip file is correctly unzipped to the target directory. + + Args: + zip_ref (zipfile.ZipFile): The zip file object. + target_dir (str): The target directory. + + Returns: + bool: True if the zip file is correctly unzipped to the target directory, False otherwise. + """ + + zip_info_list = [info for info in zip_ref.infolist() if not info.is_dir()] + for info in zip_info_list: + file_path = os.path.join(target_dir, info.filename) + if not os.path.exists(file_path): + return False + if os.path.getsize(file_path) != info.file_size: + return False + return True + + def unzip_local_file( zip_path: str, target_dir: str, password: Optional[str] = None, - exist_ok: bool = False, + exist_ok: bool = True, lock_timeout: float = 300, ) -> str: """Unzip a zip file to a target directory. @@ -40,20 +63,47 @@ def unzip_local_file( If False, raise an exception if the target directory already exists. Defaults to False. Raises: - Exception: If the target directory already exists and exist_ok is False. + ModelResourceException: If the zip file does not exist. + ModelResourceException: If the target directory already exists and exist_ok is False Returns: str: The path to the target directory. """ lock_path = f'{zip_path}.lock' - with FileLock(lock_path, timeout=lock_timeout): + + if not os.path.exists(zip_path): + logger.error(f'zip file {zip_path} does not exist') + raise ModelResourceException(f'zip file {zip_path} does not exist') + + def check_zip(): + with zipfile.ZipFile(zip_path, 'r') as zip_ref: + if password: + zip_ref.setpassword(password.encode()) + return check_zip_file(zip_ref, target_dir) + + if os.path.exists(target_dir): + if not exist_ok: + raise ModelResourceException( + f'Target directory {target_dir} already exists' + ) + + if check_zip(): + logger.info(f'zip file {zip_path} is already unzipped to {target_dir}') + return target_dir + else: + logger.warning( + f'zip file {zip_path} is not correctly unzipped to {target_dir}, retry to unzip' + ) + try_remove(target_dir) + + with FileLockContext(lock_path, check_zip, timeout=lock_timeout) as lock: + if lock is True: + logger.info( + f'zip file {zip_path} is already unzipped to {target_dir} while waiting' + ) + return target_dir # ensure target directory not exists - if os.path.exists(target_dir): - if exist_ok: - shutil.rmtree(target_dir) - else: - raise Exception(f'Target directory {target_dir} already exists') # 创建临时解压目录 with tempfile.TemporaryDirectory() as temp_dir: diff --git a/llm_web_kit/model/resource_utils/utils.py b/llm_web_kit/model/resource_utils/utils.py new file mode 100644 index 00000000..b80a35d0 --- /dev/null +++ b/llm_web_kit/model/resource_utils/utils.py @@ -0,0 +1,62 @@ +import errno +import os +import time + + +def try_remove(path: str): + """Attempt to remove a file, but ignore any exceptions that occur.""" + try: + os.remove(path) + except Exception: + pass + + +class FileLockContext: + """基于文件锁的上下文管理器(跨平台兼容版)""" + + def __init__(self, lock_path: str, check_callback=None, timeout: float = 300): + self.lock_path = lock_path + self.check_callback = check_callback + self.timeout = timeout + self._fd = None + + def __enter__(self): + start_time = time.time() + while True: + if self.check_callback: + if self.check_callback(): + return True + try: + # 原子性创建锁文件(O_EXCL标志是关键) + self._fd = os.open( + self.lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o644 + ) + # 写入进程信息和时间戳 + with os.fdopen(self._fd, 'w') as f: + f.write(f'{os.getpid()}\n{time.time()}') + return self + except OSError as e: + if e.errno != errno.EEXIST: + raise + + # 检查锁是否过期 + try: + with open(self.lock_path, 'r') as f: + pid, timestamp = f.read().split('\n')[:2] + if time.time() - float(timestamp) > self.timeout: + os.remove(self.lock_path) + except (FileNotFoundError, ValueError): + pass + + if time.time() - start_time > self.timeout: + raise TimeoutError(f'Could not acquire lock after {self.timeout}s') + time.sleep(0.1) + + def __exit__(self, exc_type, exc_val, exc_tb): + try: + if self._fd: + os.close(self._fd) + except OSError: + pass + finally: + try_remove(self.lock_path) diff --git a/tests/llm_web_kit/model/assets/zip_demo.zip b/tests/llm_web_kit/model/assets/zip_demo.zip new file mode 100644 index 00000000..62be3049 Binary files /dev/null and b/tests/llm_web_kit/model/assets/zip_demo.zip differ diff --git a/tests/llm_web_kit/model/resource_utils/test_boto3_ext.py b/tests/llm_web_kit/model/resource_utils/test_boto3_ext.py index 5be52f4a..5510dec0 100644 --- a/tests/llm_web_kit/model/resource_utils/test_boto3_ext.py +++ b/tests/llm_web_kit/model/resource_utils/test_boto3_ext.py @@ -3,6 +3,7 @@ import pytest from botocore.exceptions import ClientError +from llm_web_kit.exception.exception import ModelResourceException from llm_web_kit.model.resource_utils.boto3_ext import (get_s3_client, get_s3_config, head_s3_object, @@ -19,13 +20,8 @@ def test_is_s3_path(): def test_is_s3_404_error(): not_found_error = ClientError( error_response={ - 'Error': { - 'Code': '404', - 'Message': 'Not Found' - }, - 'ResponseMetadata': { - 'HTTPStatusCode': 404 - }, + 'Error': {'Code': '404', 'Message': 'Not Found'}, + 'ResponseMetadata': {'HTTPStatusCode': 404}, }, operation_name='test', ) @@ -33,13 +29,8 @@ def test_is_s3_404_error(): not_404_error = ClientError( error_response={ - 'Error': { - 'Code': '403', - 'Message': 'Forbidden' - }, - 'ResponseMetadata': { - 'HTTPStatusCode': 403 - }, + 'Error': {'Code': '403', 'Message': 'Forbidden'}, + 'ResponseMetadata': {'HTTPStatusCode': 403}, }, operation_name='test', ) @@ -54,20 +45,28 @@ def test_split_s3_path(): @patch('llm_web_kit.model.resource_utils.boto3_ext.load_config') def test_get_s3_config(get_config_mock): - get_config_mock.return_value = {'s3': {'bucket': {'ak': 'test_ak', 'sk': 'test_sk', 'endpoint': 'test_endpoint'}}} + get_config_mock.return_value = { + 's3': { + 'bucket': {'ak': 'test_ak', 'sk': 'test_sk', 'endpoint': 'test_endpoint'} + } + } assert get_s3_config('s3://bucket/key') == { 'ak': 'test_ak', 'sk': 'test_sk', 'endpoint': 'test_endpoint', } - with pytest.raises(ValueError): + with pytest.raises(ModelResourceException): get_s3_config('s3://nonexistent_bucket/key') @patch('llm_web_kit.model.resource_utils.boto3_ext.load_config') @patch('llm_web_kit.model.resource_utils.boto3_ext.boto3.client') def test_get_s3_client(boto3_client_mock, get_config_mock): - get_config_mock.return_value = {'s3': {'bucket': {'ak': 'test_ak', 'sk': 'test_sk', 'endpoint': 'test_endpoint'}}} + get_config_mock.return_value = { + 's3': { + 'bucket': {'ak': 'test_ak', 'sk': 'test_sk', 'endpoint': 'test_endpoint'} + } + } mock_client = MagicMock() boto3_client_mock.return_value = mock_client assert get_s3_client('s3://bucket/key') == mock_client @@ -78,19 +77,18 @@ def test_get_s3_client(boto3_client_mock, get_config_mock): def test_head_s3_object(boto3_client_mock, is_s3_404_error_mock): s3_client_mock = MagicMock() boto3_client_mock.return_value = s3_client_mock - s3_client_mock.head_object.return_value = {'ResponseMetadata': {'HTTPStatusCode': 200}} + s3_client_mock.head_object.return_value = { + 'ResponseMetadata': {'HTTPStatusCode': 200} + } - assert head_s3_object(s3_client_mock, 's3://bucket/key') == {'ResponseMetadata': {'HTTPStatusCode': 200}} + assert head_s3_object(s3_client_mock, 's3://bucket/key') == { + 'ResponseMetadata': {'HTTPStatusCode': 200} + } s3_client_mock.head_object.side_effect = ClientError( error_response={ - 'Error': { - 'Code': '404', - 'Message': 'Not Found' - }, - 'ResponseMetadata': { - 'HTTPStatusCode': 404 - }, + 'Error': {'Code': '404', 'Message': 'Not Found'}, + 'ResponseMetadata': {'HTTPStatusCode': 404}, }, operation_name='test', ) diff --git a/tests/llm_web_kit/model/resource_utils/test_download_assets.py b/tests/llm_web_kit/model/resource_utils/test_download_assets.py index a3c40d5c..b3e55e74 100644 --- a/tests/llm_web_kit/model/resource_utils/test_download_assets.py +++ b/tests/llm_web_kit/model/resource_utils/test_download_assets.py @@ -1,4 +1,3 @@ -import errno import io import os import tempfile @@ -6,25 +5,11 @@ from typing import Tuple from unittest.mock import MagicMock, call, mock_open, patch -from llm_web_kit.exception.exception import ModelInputException +from llm_web_kit.exception.exception import ModelResourceException from llm_web_kit.model.resource_utils.download_assets import ( - FileLock, HttpConnection, S3Connection, calc_file_md5, calc_file_sha256, + HttpConnection, S3Connection, calc_file_md5, calc_file_sha256, decide_cache_dir, download_auto_file, download_to_temp, move_to_target, - try_remove, verify_file_checksum) - - -class Test_try_remove: - - @patch('os.remove') - def test_remove(self, removeMock): - try_remove('path') - removeMock.assert_called_once_with('path') - - @patch('os.remove') - def test_remove_exception(self, removeMock): - removeMock.side_effect = Exception - try_remove('path') - removeMock.assert_called_once_with('path') + verify_file_checksum) class Test_decide_cache_dir: @@ -141,128 +126,6 @@ def test_HttpConnection(requests_get_mock): assert b''.join(conn.read_stream()) == test_data -class TestFileLock(unittest.TestCase): - - def setUp(self): - self.lock_path = 'test.lock' - - @patch('os.fdopen') - @patch('os.open') - @patch('os.close') - @patch('os.remove') - def test_acquire_and_release_lock( - self, mock_remove, mock_close, mock_open, mock_os_fdopen - ): - # 模拟成功获取锁 - mock_open.return_value = 123 # 假设文件描述符为123 - # 模拟文件描述符 - mock_fd = MagicMock() - mock_fd.__enter__.return_value = mock_fd - mock_fd.write.return_value = None - mock_os_fdopen.return_value = mock_fd - - with FileLock(self.lock_path): - mock_open.assert_called_once_with( - self.lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o644 - ) - mock_close.assert_called_once_with(123) - mock_remove.assert_called_once_with(self.lock_path) - - @patch('os.fdopen') - @patch('os.open') - @patch('builtins.open', new_callable=mock_open, read_data='1234\n100') - @patch('time.time') - @patch('os.remove') - def test_remove_stale_lock( - self, mock_remove, mock_time, mock_file_open, mock_os_open, mock_os_fdopen - ): - # 第一次尝试创建锁文件失败(锁已存在) - mock_os_open.side_effect = [ - OSError(errno.EEXIST, 'File exists'), - 123, # 第二次成功 - ] - - # 模拟文件描述符 - mock_fd = MagicMock() - mock_fd.__enter__.return_value = mock_fd - mock_fd.write.return_value = None - mock_os_fdopen.return_value = mock_fd - - # 当前时间设置为超过超时时间(timeout=300) - mock_time.return_value = 401 # 100 + 300 + 1 - - with FileLock(self.lock_path, timeout=300): - mock_remove.assert_called_once_with(self.lock_path) - mock_os_open.assert_any_call( - self.lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o644 - ) - - @patch('os.open') - @patch('time.time') - def test_timeout_acquiring_lock(self, mock_time, mock_os_open): - # 总是返回EEXIST错误 - mock_os_open.side_effect = OSError(errno.EEXIST, 'File exists') - # 时间累计超过超时时间 - start_time = 1000 - mock_time.side_effect = [ - start_time, - start_time + 301, - start_time + 302, - start_time + 303, - ] - - with self.assertRaises(TimeoutError): - with FileLock(self.lock_path, timeout=300): - pass - - @patch('os.open') - def test_other_os_error(self, mock_os_open): - # 模拟其他OS错误(如权限不足) - mock_os_open.side_effect = OSError(errno.EACCES, 'Permission denied') - with self.assertRaises(OSError): - with FileLock(self.lock_path): - pass - - @patch('os.close') - @patch('os.remove') - def test_cleanup_on_exit(self, mock_remove, mock_close): - - mock_close.side_effect = None - # 确保退出上下文时执行清理 - lock_path = 'test.lock' - lock = FileLock(lock_path) - lock._fd = 123 # 模拟已打开的文件描述符 - lock.__exit__('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!', None, None) - mock_remove.assert_called_once_with(self.lock_path) - - @patch('os.remove') - def test_cleanup_failure_handled(self, mock_remove): - # 模拟删除锁文件时失败 - mock_remove.side_effect = OSError - lock = FileLock(self.lock_path) - lock._fd = 123 - # 不应抛出异常 - lock.__exit__(None, None, None) - - @patch('os.getpid') - @patch('time.time') - def test_lock_file_content(self, mock_time, mock_pid): - # 验证锁文件内容格式 - mock_pid.return_value = 9999 - mock_time.return_value = 123456.789 - - with patch('os.open') as mock_os_open: - mock_os_open.return_value = 123 - with patch('os.fdopen') as mock_fdopen: - # 模拟写入文件描述符 - mock_file = MagicMock() - mock_fdopen.return_value.__enter__.return_value = mock_file - - with FileLock(self.lock_path): - mock_fdopen.assert_called_once_with(123, 'w') - mock_file.write.assert_called_once_with('9999\n123456.789') - - class TestDownloadAutoFile(unittest.TestCase): @patch('llm_web_kit.model.resource_utils.download_assets.os.path.exists') @@ -448,7 +311,7 @@ def test_file_not_exists_download_http( # ) -> bool: # """校验文件哈希值.""" # if not sum([bool(md5_sum), bool(sha256_sum)]) == 1: -# raise ModelInputException('Exactly one of md5_sum or sha256_sum must be provided') +# raise ModelResourceException('Exactly one of md5_sum or sha256_sum must be provided') # if md5_sum: # actual = calc_file_md5(file_path) @@ -482,8 +345,8 @@ def test_pass_two_value(self, mock_calc_file_sha256, mock_calc_file_md5): sha256_sum = 'sha256_sum' mock_calc_file_md5.return_value = md5_sum mock_calc_file_sha256.return_value = sha256_sum - # will raise ModelInputException - with self.assertRaises(ModelInputException): + # will raise ModelResourceException + with self.assertRaises(ModelResourceException): verify_file_checksum(file_path, md5_sum, sha256_sum) @patch('llm_web_kit.model.resource_utils.download_assets.calc_file_md5') @@ -492,8 +355,8 @@ def test_pass_two_None(self, mock_calc_file_sha256, mock_calc_file_md5): file_path = 'file_path' md5_sum = None sha256_sum = None - # will raise ModelInputException - with self.assertRaises(ModelInputException): + # will raise ModelResourceException + with self.assertRaises(ModelResourceException): verify_file_checksum(file_path, md5_sum, sha256_sum) @patch('llm_web_kit.model.resource_utils.download_assets.calc_file_md5') @@ -601,7 +464,7 @@ def test_size_mismatch(self): with open(tmp_path, 'wb') as f: f.write(b'short') - with self.assertRaisesRegex(ValueError, 'size mismatch'): + with self.assertRaisesRegex(ModelResourceException, 'size mismatch'): move_to_target(tmp_path, self.target_path, 100) def test_directory_creation(self): diff --git a/tests/llm_web_kit/model/resource_utils/test_resource_utils.py b/tests/llm_web_kit/model/resource_utils/test_resource_utils.py new file mode 100644 index 00000000..15448bcf --- /dev/null +++ b/tests/llm_web_kit/model/resource_utils/test_resource_utils.py @@ -0,0 +1,142 @@ +import errno +import os +import unittest +from unittest.mock import MagicMock, mock_open, patch + +from llm_web_kit.model.resource_utils.utils import FileLockContext, try_remove + + +class Test_try_remove: + + @patch('os.remove') + def test_remove(self, removeMock): + try_remove('path') + removeMock.assert_called_once_with('path') + + @patch('os.remove') + def test_remove_exception(self, removeMock): + removeMock.side_effect = Exception + try_remove('path') + removeMock.assert_called_once_with('path') + + +class TestFileLock(unittest.TestCase): + + def setUp(self): + self.lock_path = 'test.lock' + + @patch('os.fdopen') + @patch('os.open') + @patch('os.close') + @patch('os.remove') + def test_acquire_and_release_lock( + self, mock_remove, mock_close, mock_open, mock_os_fdopen + ): + # 模拟成功获取锁 + mock_open.return_value = 123 # 假设文件描述符为123 + # 模拟文件描述符 + mock_fd = MagicMock() + mock_fd.__enter__.return_value = mock_fd + mock_fd.write.return_value = None + mock_os_fdopen.return_value = mock_fd + + with FileLockContext(self.lock_path): + mock_open.assert_called_once_with( + self.lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o644 + ) + mock_close.assert_called_once_with(123) + mock_remove.assert_called_once_with(self.lock_path) + + @patch('os.fdopen') + @patch('os.open') + @patch('builtins.open', new_callable=mock_open, read_data='1234\n100') + @patch('time.time') + @patch('os.remove') + def test_remove_stale_lock( + self, mock_remove, mock_time, mock_file_open, mock_os_open, mock_os_fdopen + ): + # 第一次尝试创建锁文件失败(锁已存在) + mock_os_open.side_effect = [ + OSError(errno.EEXIST, 'File exists'), + 123, # 第二次成功 + ] + + # 模拟文件描述符 + mock_fd = MagicMock() + mock_fd.__enter__.return_value = mock_fd + mock_fd.write.return_value = None + mock_os_fdopen.return_value = mock_fd + + # 当前时间设置为超过超时时间(timeout=300) + mock_time.return_value = 401 # 100 + 300 + 1 + + with FileLockContext(self.lock_path, timeout=300): + mock_remove.assert_called_once_with(self.lock_path) + mock_os_open.assert_any_call( + self.lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o644 + ) + + @patch('os.open') + @patch('time.time') + def test_timeout_acquiring_lock(self, mock_time, mock_os_open): + # 总是返回EEXIST错误 + mock_os_open.side_effect = OSError(errno.EEXIST, 'File exists') + # 时间累计超过超时时间 + start_time = 1000 + mock_time.side_effect = [ + start_time, + start_time + 301, + start_time + 302, + start_time + 303, + ] + + with self.assertRaises(TimeoutError): + with FileLockContext(self.lock_path, timeout=300): + pass + + @patch('os.open') + def test_other_os_error(self, mock_os_open): + # 模拟其他OS错误(如权限不足) + mock_os_open.side_effect = OSError(errno.EACCES, 'Permission denied') + with self.assertRaises(OSError): + with FileLockContext(self.lock_path): + pass + + @patch('os.close') + @patch('os.remove') + def test_cleanup_on_exit(self, mock_remove, mock_close): + + mock_close.side_effect = None + # 确保退出上下文时执行清理 + lock_path = 'test.lock' + lock = FileLockContext(lock_path) + lock._fd = 123 # 模拟已打开的文件描述符 + lock.__exit__('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!', None, None) + mock_remove.assert_called_once_with(self.lock_path) + + @patch('os.remove') + def test_cleanup_failure_handled(self, mock_remove): + # 模拟删除锁文件时失败 + mock_remove.side_effect = OSError + lock = FileLockContext(self.lock_path) + lock._fd = 123 + # 不应抛出异常 + lock.__exit__(None, None, None) + + @patch('os.getpid') + @patch('time.time') + def test_lock_file_content(self, mock_time, mock_pid): + # 验证锁文件内容格式 + mock_pid.return_value = 9999 + mock_time.return_value = 123456.789 + + with patch('os.open') as mock_os_open: + mock_os_open.return_value = 123 + with patch('os.fdopen') as mock_fdopen: + # 模拟写入文件描述符 + mock_file = MagicMock() + mock_fdopen.return_value.__enter__.return_value = mock_file + + with FileLockContext(self.lock_path): + mock_fdopen.assert_called_once_with(123, 'w') + mock_file.write.assert_called_once_with('9999\n123456.789') diff --git a/tests/llm_web_kit/model/resource_utils/test_singleton_resource_manager.py b/tests/llm_web_kit/model/resource_utils/test_singleton_resource_manager.py index 8e345495..22a9282f 100644 --- a/tests/llm_web_kit/model/resource_utils/test_singleton_resource_manager.py +++ b/tests/llm_web_kit/model/resource_utils/test_singleton_resource_manager.py @@ -1,5 +1,6 @@ import pytest +from llm_web_kit.exception.exception import ModelResourceException from llm_web_kit.model.resource_utils.singleton_resource_manager import \ SingletonResourceManager @@ -27,15 +28,15 @@ def test_set_resource(self): assert self.manager.get_resource('test') == 'resource' # "test" should not be set again - with pytest.raises(AssertionError): + with pytest.raises(ModelResourceException): self.manager.set_resource('test', 'resource') # name should be a string - with pytest.raises(TypeError): + with pytest.raises(ModelResourceException): self.manager.set_resource(1, 'resource') # resource should not be None - with pytest.raises(TypeError): + with pytest.raises(ModelResourceException): self.manager.set_resource(None, 'resource') def test_get_resource(self): @@ -43,7 +44,7 @@ def test_get_resource(self): # "test" should exist after setting and the resource should be "resource" assert self.manager.get_resource('test') == 'resource' # Exception should be raised if the resource does not exist - with pytest.raises(Exception): + with pytest.raises(ModelResourceException): self.manager.get_resource('test1') def test_release_resource(self): @@ -52,7 +53,7 @@ def test_release_resource(self): # "test" should not exist after releasing assert not self.manager.has_name('test') # Exception should be raised if the resource does not exist - with pytest.raises(Exception): + with pytest.raises(ModelResourceException): self.manager.get_resource('test') # Should not raise exception if the resource does not exist self.manager.release_resource('test') diff --git a/tests/llm_web_kit/model/resource_utils/test_unzip_ext.py b/tests/llm_web_kit/model/resource_utils/test_unzip_ext.py index 39267ab1..bf514d14 100644 --- a/tests/llm_web_kit/model/resource_utils/test_unzip_ext.py +++ b/tests/llm_web_kit/model/resource_utils/test_unzip_ext.py @@ -1,14 +1,98 @@ import os import tempfile import zipfile +from unittest import TestCase -from llm_web_kit.model.resource_utils.unzip_ext import (get_unzip_dir, +from llm_web_kit.exception.exception import ModelResourceException +from llm_web_kit.model.resource_utils.unzip_ext import (check_zip_file, + get_unzip_dir, unzip_local_file) -def test_get_unzip_dir(): - assert get_unzip_dir('/path/to/test.zip') == '/path/to/test_unzip' - assert get_unzip_dir('/path/to/test') == '/path/to/test_unzip' +def get_assert_dir(): + file_path = os.path.abspath(__file__) + assert_dir = os.path.join(os.path.dirname(os.path.dirname(file_path)), 'assets') + return assert_dir + + +class TestGetUnzipDir(TestCase): + + def test_get_unzip_dir_case1(self): + assert get_unzip_dir('/path/to/test.zip') == '/path/to/test_unzip' + + def test_get_unzip_dir_case2(self): + assert get_unzip_dir('/path/to/test') == '/path/to/test_unzip' + + +class TestCheckZipFile(TestCase): + # # test_zip/ + # # ├── test1.txt "test1\n" + # # ├── folder1 + # # │ └── test2.txt "test2\n" + # # └── folder2 + + def get_zipfile(self): + # 创建一个临时文件夹 + zip_path = os.path.join(get_assert_dir(), 'zip_demo.zip') + zip_file = zipfile.ZipFile(zip_path, 'r') + return zip_file + + def test_check_zip_file_cese1(self): + zip_file = self.get_zipfile() + # # test_zip/ + # # ├── test1.txt + # # ├── folder1 + # # │ └── test2.txt + # # └── folder2 + + with tempfile.TemporaryDirectory() as temp_dir: + root_dir = os.path.join(temp_dir, 'test_zip') + os.makedirs(os.path.join(root_dir, 'test_zip')) + os.makedirs(os.path.join(root_dir, 'folder1')) + os.makedirs(os.path.join(root_dir, 'folder2')) + with open(os.path.join(root_dir, 'test1.txt'), 'w') as f: + f.write('test1\n') + with open(os.path.join(root_dir, 'folder1', 'test2.txt'), 'w') as f: + f.write('test2\n') + + assert check_zip_file(zip_file, temp_dir) is True + + def test_check_zip_file_cese2(self): + zip_file = self.get_zipfile() + with tempfile.TemporaryDirectory() as temp_dir: + root_dir = os.path.join(temp_dir, 'test_zip') + os.makedirs(os.path.join(root_dir, 'test_zip')) + os.makedirs(os.path.join(root_dir, 'folder1')) + with open(os.path.join(root_dir, 'test1.txt'), 'w') as f: + f.write('test1\n') + with open(os.path.join(root_dir, 'folder1', 'test2.txt'), 'w') as f: + f.write('test2\n') + + assert check_zip_file(zip_file, temp_dir) is True + + def test_check_zip_file_cese3(self): + zip_file = self.get_zipfile() + with tempfile.TemporaryDirectory() as temp_dir: + root_dir = os.path.join(temp_dir, 'test_zip') + os.makedirs(os.path.join(root_dir, 'test_zip')) + os.makedirs(os.path.join(root_dir, 'folder1')) + with open(os.path.join(root_dir, 'folder1', 'test2.txt'), 'w') as f: + f.write('test2\n') + + assert check_zip_file(zip_file, temp_dir) is False + + def test_check_zip_file_cese4(self): + zip_file = self.get_zipfile() + with tempfile.TemporaryDirectory() as temp_dir: + root_dir = os.path.join(temp_dir, 'test_zip') + os.makedirs(os.path.join(root_dir, 'test_zip')) + os.makedirs(os.path.join(root_dir, 'folder1')) + with open(os.path.join(root_dir, 'test1.txt'), 'w') as f: + f.write('test1\n') + with open(os.path.join(root_dir, 'folder1', 'test2.txt'), 'w') as f: + f.write('test123\n') + + assert check_zip_file(zip_file, temp_dir) is False def test_unzip_local_file(): @@ -35,5 +119,5 @@ def test_unzip_local_file(): assert f.read() == 'This is another test file' try: unzip_local_file(zip_path, target_dir, exist_ok=False) - except Exception as e: - assert str(e) == f'Target directory {target_dir} already exists' + except ModelResourceException as e: + assert e.custom_message == f'Target directory {target_dir} already exists'