diff --git a/llm_web_kit/exception/exception.py b/llm_web_kit/exception/exception.py
index 4c28c700..c3f4f5d1 100644
--- a/llm_web_kit/exception/exception.py
+++ b/llm_web_kit/exception/exception.py
@@ -3,8 +3,6 @@
 
 import commentjson as json
 
-from llm_web_kit.input.datajson import DataJsonKey
-
 
 class ErrorMsg:
     """Error message manager class."""
@@ -54,7 +52,7 @@ def __init__(self, custom_message: str | None = None, error_code: int | None = N
         self.error_code = error_code
         self.message = ErrorMsg.get_error_message(self.error_code)
         self.custom_message = custom_message
-        self.dataset_name = DataJsonKey.DATASET_NAME
+        self.dataset_name = ''
         super().__init__(self.message)
         frame = inspect.currentframe().f_back
         self.__py_filename = frame.f_code.co_filename
diff --git a/llm_web_kit/extractor/extractor_chain.py b/llm_web_kit/extractor/extractor_chain.py
index d02d17e0..0d601bcf 100644
--- a/llm_web_kit/extractor/extractor_chain.py
+++ b/llm_web_kit/extractor/extractor_chain.py
@@ -2,16 +2,17 @@
 
 import commentjson as json
 
-from llm_web_kit.exception.exception import (ExtractorChainConfigException,
+from llm_web_kit.exception.exception import (ExtractorChainBaseException,
+                                             ExtractorChainConfigException,
                                              ExtractorChainInputException,
                                              ExtractorInitException,
-                                             ExtractorNotFoundException)
+                                             ExtractorNotFoundException,
+                                             LlmWebKitBaseException)
 from llm_web_kit.extractor.extractor import AbstractExtractor
 from llm_web_kit.extractor.post_extractor import AbstractPostExtractor
 from llm_web_kit.extractor.pre_extractor import AbstractPreExtractor
 from llm_web_kit.input.datajson import DataJson
 from llm_web_kit.libs.class_loader import load_python_class_by_name
-from llm_web_kit.libs.logger import mylogger
 
 
 # ##########################################################
@@ -55,11 +56,17 @@ def extract(self, data: DataJson) -> DataJson:
                 data = post_ext.post_extract(data)
 
         except KeyError as e:
-            mylogger.error(f'Required field missing in input data: {str(e)}')
-            raise ExtractorChainInputException(f'Required field missing in input data: {str(e)}')
-        except Exception as e:
-            mylogger.error(f'Error during extraction: {str(e)}')
+            exc = ExtractorChainInputException(f'Required field missing: {str(e)}')
+            exc.dataset_name = data.get_dataset_name()
+            raise exc from e
+        except LlmWebKitBaseException as e:
+            # Library exceptions (incl. ExtractorChain* subclasses) are tagged and re-raised as-is.
+            e.dataset_name = data.get_dataset_name()
             raise
+        except Exception as e:
+            wrapped = ExtractorChainBaseException(f'Error during extraction: {str(e)}')
+            wrapped.dataset_name = data.get_dataset_name()
+            raise wrapped from e
 
         return data
 
diff --git a/tests/llm_web_kit/exception/test_exception_data.py b/tests/llm_web_kit/exception/test_exception_data.py
index 5f3ac069..b3ff1aa6 100644
--- a/tests/llm_web_kit/exception/test_exception_data.py
+++ b/tests/llm_web_kit/exception/test_exception_data.py
@@ -1,5 +1,6 @@
 import unittest
 from pathlib import Path
+from unittest.mock import patch
 
 from llm_web_kit.exception.exception import (CleanModelException,
                                              EbookFileExtractorException,
@@ -199,6 +200,7 @@ def test_error_code_uniqueness(self):
 
         with open(json_path, 'r', encoding='utf-8') as f:
             import commentjson as json
+
             data = json.load(f)
 
         for module in data.values():
@@ -206,3 +208,151 @@ def test_error_code_uniqueness(self):
                 code = error_info['code']
                 self.assertNotIn(code, error_codes, f'Duplicate error code found: {code}')
                 error_codes.add(code)
+
+    def test_exception_dataset_name(self):
+        """Test dataset_name handling in exceptions."""
+        # Test base exception initialization with empty dataset_name
+        base_exc = LlmWebKitBaseException('test message')
+        self.assertEqual(base_exc.dataset_name, '')
+
+        # Test custom dataset_name assignment
+        base_exc.dataset_name = 'test_dataset'
+        self.assertEqual(base_exc.dataset_name, 'test_dataset')
+
+        # Test dataset_name in child exceptions
+        chain_exc = ExtractorChainBaseException('chain error')
+        self.assertEqual(chain_exc.dataset_name, '')
+        chain_exc.dataset_name = 'chain_dataset'
+        self.assertEqual(chain_exc.dataset_name, 'chain_dataset')
+
+        # Test dataset_name in concrete exceptions
+        test_cases = [
+            (ExtractorInitException('init error'), 'init_dataset'),
+            (ExtractorChainInputException('input error'), 'input_dataset'),
+            (ExtractorChainConfigException('config error'), 'config_dataset'),
+            (ExtractorNotFoundException('not found error'), 'notfound_dataset'),
+        ]
+
+        for exc, dataset_name in test_cases:
+            with self.subTest(exception_type=type(exc).__name__):
+                self.assertEqual(exc.dataset_name, '')
+                exc.dataset_name = dataset_name
+                self.assertEqual(exc.dataset_name, dataset_name)
+
+        # Test that the chain copies the DataJson's dataset_name onto the raised exception
+        from llm_web_kit.extractor.extractor_chain import ExtractSimpleFactory
+        from llm_web_kit.input.datajson import DataJson
+
+        config = {
+            'extractor_pipe': {
+                'pre_extractor': [
+                    {
+                        'enable': True,
+                        'python_class': 'llm_web_kit.extractor.html.pre_extractor.HTMLFileFormatFilterPreExtractor',
+                        'class_init_kwargs': {},
+                    }
+                ],
+                'extractor': [
+                    {
+                        'enable': True,
+                        'python_class': 'llm_web_kit.extractor.html.extractor.HTMLFileFormatExtractor',
+                        'class_init_kwargs': {},
+                    }
+                ],
+            }
+        }
+        chain = ExtractSimpleFactory.create(config)
+
+        input_data = DataJson(
+            {
+                'dataset_name': 'test_dataset',
+            }
+        )
+
+        with self.assertRaises(ExtractorChainBaseException) as context:
+            chain.extract(input_data)
+        self.assertEqual(context.exception.dataset_name, 'test_dataset')
+
+    # NOTE(review): extractor_chain imports load_python_class_by_name by name, so patching
+    # llm_web_kit.libs.class_loader may not affect the chain -- confirm the mock takes effect.
+    @patch('llm_web_kit.libs.class_loader.load_python_class_by_name')
+    def test_extractor_chain_exceptions(self, mock_load_class):
+        """Test the exception-handling mechanism in ExtractorChain."""
+        from llm_web_kit.extractor.extractor_chain import ExtractSimpleFactory
+        from llm_web_kit.input.datajson import DataJson
+
+        # Define simple mock classes; each one raises a single kind of exception
+        class KeyErrorExtractor:
+            def __init__(self, config, **kwargs):
+                pass
+
+            def extract(self, data):
+                raise KeyError('test_key')
+
+        class BaseExceptionExtractor:
+            def __init__(self, config, **kwargs):
+                pass
+
+            def extract(self, data):
+                raise LlmWebKitBaseException('Base exception')
+
+        class ChainExceptionExtractor:
+            def __init__(self, config, **kwargs):
+                pass
+
+            def extract(self, data):
+                raise ExtractorChainBaseException('Chain exception')
+
+        # NOTE(review): not referenced by any scenario below -- the generic-exception path is untested here.
+        class GeneralExceptionExtractor:
+            def __init__(self, config, **kwargs):
+                pass
+
+            def extract(self, data):
+                raise ValueError('General exception')
+
+        mock_load_class.return_value = KeyErrorExtractor(None)
+
+        # Base config
+        config = {
+            'extractor_pipe': {
+                'pre_extractor': [
+                    {
+                        'enable': True,
+                        'python_class': 'llm_web_kit.extractor.html.pre_extractor.HTMLFileFormatFilterPreExtractor',
+                        'class_init_kwargs': {},
+                    }
+                ],
+                'extractor': [
+                    {
+                        'enable': True,
+                        'python_class': 'llm_web_kit.extractor.html.extractor.HTMLFileFormatExtractor',
+                        'class_init_kwargs': {},
+                    }
+                ],
+            }
+        }
+
+        # Test data
+        data = DataJson({'dataset_name': 'test_dataset'})
+
+        # Scenario 1: KeyError -> ExtractorChainInputException
+        chain = ExtractSimpleFactory.create(config)
+        with self.assertRaises(ExtractorChainInputException) as context:
+            chain.extract(data)
+        self.assertEqual(context.exception.dataset_name, 'test_dataset')
+        self.assertIn('Required field missing', str(context.exception))
+
+        # Scenario 2: LlmWebKitBaseException propagates
+        mock_load_class.return_value = BaseExceptionExtractor(None)
+        chain = ExtractSimpleFactory.create(config)
+        with self.assertRaises(LlmWebKitBaseException) as context:
+            chain.extract(data)
+        self.assertEqual(context.exception.dataset_name, 'test_dataset')
+
+        # Scenario 3: ExtractorChainBaseException propagates
+        mock_load_class.return_value = ChainExceptionExtractor(None)
+        chain = ExtractSimpleFactory.create(config)
+        with self.assertRaises(ExtractorChainBaseException) as context:
+            chain.extract(data)
+        self.assertEqual(context.exception.dataset_name, 'test_dataset')
diff --git
a/tests/llm_web_kit/extractor/test_extractor_chain_normal.py b/tests/llm_web_kit/extractor/test_extractor_chain_normal.py
index 8a9ab126..355ed50d 100644
--- a/tests/llm_web_kit/extractor/test_extractor_chain_normal.py
+++ b/tests/llm_web_kit/extractor/test_extractor_chain_normal.py
@@ -1,10 +1,16 @@
 import json
 import os
 import unittest
+from unittest.mock import MagicMock, patch
 
-from llm_web_kit.exception.exception import (ExtractorChainInputException,
-                                             ExtractorNotFoundException)
-from llm_web_kit.extractor.extractor_chain import ExtractSimpleFactory
+from llm_web_kit.exception.exception import (ExtractorChainBaseException,
+                                             ExtractorChainConfigException,
+                                             ExtractorChainInputException,
+                                             ExtractorInitException,
+                                             ExtractorNotFoundException,
+                                             LlmWebKitBaseException)
+from llm_web_kit.extractor.extractor_chain import (ExtractorChain,
+                                                   ExtractSimpleFactory)
 from llm_web_kit.input.datajson import DataJson
 
 
@@ -160,7 +166,11 @@ def test_error_handling(self):
 
         # Test invalid input type
         with self.assertRaises(ExtractorChainInputException):
-            chain.extract(DataJson({'data_source_category': 'html', 'html': '<h1>Test</h1>'}))
+            chain.extract(DataJson({
+                'dataset_name': 'test_dataset',  # add dataset_name
+                'data_source_category': 'html',
+                'html': '<h1>Test</h1>'
+            }))
 
         # Test invalid config
         invalid_config = {'extractor_pipe': {'extractor': [{'enable': True, 'python_class': 'non.existent.Extractor'}]}}
@@ -182,4 +192,242 @@ def test_error_handling(self):
 
         # Test missing required fields
         with self.assertRaises(ExtractorChainInputException):
-            chain.extract(DataJson({'data_source_category': 'html'}))
+            chain.extract(DataJson({'data_source_category': 'html', 'dataset_name': 'test_dataset'}))
+
+    def test_empty_config(self):
+        """Test empty configs and disabled extractors."""
+        # Completely empty config
+        chain = ExtractorChain({})
+        self.assertEqual(len(chain._ExtractorChain__pre_extractors), 0)
+        self.assertEqual(len(chain._ExtractorChain__extractors), 0)
+        self.assertEqual(len(chain._ExtractorChain__post_extractors), 0)
+
+        # Only extractor_pipe present, with nothing configured inside it
+        chain = ExtractorChain({'extractor_pipe': {}})
+        self.assertEqual(len(chain._ExtractorChain__pre_extractors), 0)
+        self.assertEqual(len(chain._ExtractorChain__extractors), 0)
+        self.assertEqual(len(chain._ExtractorChain__post_extractors), 0)
+
+        # Disabled extractors
+        config = {
+            'extractor_pipe': {
+                'pre_extractor': [
+                    {
+                        'enable': False,
+                        'python_class': 'llm_web_kit.extractor.html.pre_extractor.HTMLFileFormatFilterPreExtractor',
+                        'class_init_kwargs': {},
+                    }
+                ],
+                'extractor': [
+                    {
+                        'enable': False,
+                        'python_class': 'llm_web_kit.extractor.html.extractor.HTMLFileFormatExtractor',
+                        'class_init_kwargs': {},
+                    }
+                ],
+                'post_extractor': [
+                    {
+                        'enable': False,
+                        'python_class': 'llm_web_kit.extractor.html.post_extractor.HTMLFileFormatPostExtractor',
+                        'class_init_kwargs': {},
+                    }
+                ]
+            }
+        }
+        chain = ExtractorChain(config)
+        self.assertEqual(len(chain._ExtractorChain__pre_extractors), 0)
+        self.assertEqual(len(chain._ExtractorChain__extractors), 0)
+        self.assertEqual(len(chain._ExtractorChain__post_extractors), 0)
+
+    def test_config_errors(self):
+        """Test configuration errors."""
+        # Missing python_class
+        config = {
+            'extractor_pipe': {
+                'extractor': [
+                    {
+                        'enable': True,
+                        # python_class intentionally omitted
+                        'class_init_kwargs': {},
+                    }
+                ]
+            }
+        }
+        with self.assertRaises(ExtractorChainConfigException) as context:
+            ExtractorChain(config)
+        self.assertIn('python_class not specified', str(context.exception))
+
+    # NOTE(review): extractor_chain imports load_python_class_by_name by name, so patching
+    # llm_web_kit.libs.class_loader may not affect the chain -- confirm for every @patch in this file.
+    @patch('llm_web_kit.libs.class_loader.load_python_class_by_name')
+    def test_extractor_initialization_errors(self, mock_load):
+        """Test extractor initialization errors."""
+        # Import error
+        mock_load.side_effect = ImportError('Module not found')
+
+        config = {
+            'extractor_pipe': {
+                'extractor': [
+                    {
+                        'enable': True,
+                        'python_class': 'llm_web_kit.extractor.html.extractor.NonExistentExtractor',
+                        'class_init_kwargs': {},
+                    }
+                ]
+            }
+        }
+
+        with self.assertRaises(ExtractorChainBaseException) as context:
+            ExtractorChain(config)
+        self.assertIn('Failed to initialize extractor', str(context.exception))
+
+        # Reset the mock and install a new side_effect
+        mock_load.reset_mock()
+        mock_load.side_effect = ValueError('Invalid configuration')
+
+        with self.assertRaises(ExtractorInitException) as context:
+            ExtractorChain(config)
+        self.assertIn('Failed to initialize extractor', str(context.exception))
+
+    @patch('llm_web_kit.libs.class_loader.load_python_class_by_name')
+    def test_exception_handling_with_dataset_name(self, mock_load):
+        """Test that dataset_name is set during exception handling."""
+        # Mock extractor that raises KeyError
+        mock_extractor = MagicMock()
+        mock_extractor.extract.side_effect = KeyError('required_field')
+
+        # Set the mock's return value directly
+        mock_load.return_value = mock_extractor
+
+        config = {
+            'extractor_pipe': {
+                'extractor': [
+                    {
+                        'enable': True,
+                        'python_class': 'llm_web_kit.extractor.html.extractor.HTMLFileFormatExtractor',
+                        'class_init_kwargs': {},
+                    }
+                ]
+            }
+        }
+
+        chain = ExtractorChain(config)
+
+        # Case where dataset_name is present
+        data = DataJson({'dataset_name': 'test_dataset'})
+        with self.assertRaises(ExtractorChainInputException) as context:
+            chain.extract(data)
+        self.assertEqual(context.exception.dataset_name, 'test_dataset')
+        self.assertIn('Required field missing', str(context.exception))
+
+    def test_exception_propagation(self):
+        """Test propagation of the different exception types."""
+        # Mock extractor that raises LlmWebKitBaseException
+        mock_base_error = MagicMock()
+        base_exception = LlmWebKitBaseException('Base error')
+        mock_base_error.extract.side_effect = base_exception
+
+        # Mock extractor that raises ExtractorChainBaseException
+        mock_chain_error = MagicMock()
+        chain_exception = ExtractorChainBaseException('Chain error')
+        mock_chain_error.extract.side_effect = chain_exception
+
+        # Mock extractor that raises a generic exception
+        mock_general_error = MagicMock()
+        mock_general_error.extract.side_effect = ValueError('General error')
+
+        # ExtractorChain subclass used only for this test
+        class TestExtractorChain(ExtractorChain):
+            """ExtractorChain subclass for testing; keeps the mock in a class variable."""
+            current_mock = None
+
+            def __init__(self, config, mock_extractor):
+                # Set the class variable first
+                TestExtractorChain.current_mock = mock_extractor
+                super().__init__(config)
+
+            def _ExtractorChain__create_extractor(self, config):
+                return self.current_mock
+
+        config = {
+            'extractor_pipe': {
+                'extractor': [
+                    {
+                        'enable': True,
+                        'python_class': 'llm_web_kit.extractor.html.extractor.HTMLFileFormatExtractor',
+                        'class_init_kwargs': {},
+                    }
+                ]
+            }
+        }
+
+        # DataJson object containing all the required fields
+        data = DataJson({
+            'dataset_name': 'test_dataset',
+            'data_source_category': 'html',
+            'html': '<h1>Test</h1>',
+            'url': 'https://example.com'
+        })
+
+        # LlmWebKitBaseException propagation
+        chain = TestExtractorChain(config, mock_base_error)
+        with self.assertRaises(LlmWebKitBaseException) as context:
+            chain.extract(data)
+        self.assertEqual(context.exception.dataset_name, 'test_dataset')
+        self.assertIsInstance(context.exception, LlmWebKitBaseException)
+        self.assertIn('Base error', str(context.exception))
+
+        # ExtractorChainBaseException propagation
+        chain = TestExtractorChain(config, mock_chain_error)
+        with self.assertRaises(ExtractorChainBaseException) as context:
+            chain.extract(data)
+        self.assertEqual(context.exception.dataset_name, 'test_dataset')
+        self.assertIsInstance(context.exception, ExtractorChainBaseException)
+        self.assertIn('Chain error', str(context.exception))
+
+        # Generic exceptions are wrapped in ExtractorChainBaseException
+        chain = TestExtractorChain(config, mock_general_error)
+        with self.assertRaises(ExtractorChainBaseException) as context:
+            chain.extract(data)
+        self.assertEqual(context.exception.dataset_name, 'test_dataset')
+        self.assertIn('Error during extraction', str(context.exception))
+        self.assertIsInstance(context.exception.__cause__, ValueError)
+
+    def test_factory_method(self):
+        """Test the factory method."""
+        # ExtractSimpleFactory.create
+        config = self.html_config
+        chain = ExtractSimpleFactory.create(config)
+        self.assertIsInstance(chain, ExtractorChain)
+
+        # Empty config
+        chain = ExtractSimpleFactory.create({})
+        self.assertIsInstance(chain, ExtractorChain)
+        self.assertEqual(len(chain._ExtractorChain__pre_extractors), 0)
+        self.assertEqual(len(chain._ExtractorChain__extractors), 0)
+        self.assertEqual(len(chain._ExtractorChain__post_extractors), 0)
+
+    @patch('llm_web_kit.libs.class_loader.load_python_class_by_name')
+    def test_post_extractor_exceptions(self, mock_load):
+        """Test exception handling in the post-extraction stage."""
+        # A well-behaved extractor
+        mock_extractor = MagicMock()
+        mock_extractor.extract = lambda data: data
+
+        # Post-extractor that raises KeyError
+        mock_key_error_post = MagicMock()
+        mock_key_error_post.post_extract.side_effect = KeyError('post_required_field')
+
+        # Post-extractor that raises ExtractorChainBaseException
+        mock_chain_error_post = MagicMock()
+        chain_exception = ExtractorChainBaseException('Post chain error')
+        mock_chain_error_post.post_extract.side_effect = chain_exception
+
+        # Post-extractor that raises LlmWebKitBaseException
+        mock_base_error_post = MagicMock()
+        base_exception = LlmWebKitBaseException('Post base error')
+        mock_base_error_post.post_extract.side_effect = base_exception
+
+        # Post-extractor that raises a generic exception
+        mock_general_error_post = MagicMock()
+        mock_general_error_post.post_extract.side_effect = ValueError('Post general error')