diff --git a/llm_web_kit/exception/exception.py b/llm_web_kit/exception/exception.py index 4c28c700..c3f4f5d1 100644 --- a/llm_web_kit/exception/exception.py +++ b/llm_web_kit/exception/exception.py @@ -3,8 +3,6 @@ import commentjson as json -from llm_web_kit.input.datajson import DataJsonKey - class ErrorMsg: """Error message manager class.""" @@ -54,7 +52,7 @@ def __init__(self, custom_message: str | None = None, error_code: int | None = N self.error_code = error_code self.message = ErrorMsg.get_error_message(self.error_code) self.custom_message = custom_message - self.dataset_name = DataJsonKey.DATASET_NAME + self.dataset_name = '' super().__init__(self.message) frame = inspect.currentframe().f_back self.__py_filename = frame.f_code.co_filename diff --git a/llm_web_kit/extractor/extractor_chain.py b/llm_web_kit/extractor/extractor_chain.py index d02d17e0..0d601bcf 100644 --- a/llm_web_kit/extractor/extractor_chain.py +++ b/llm_web_kit/extractor/extractor_chain.py @@ -2,16 +2,17 @@ import commentjson as json -from llm_web_kit.exception.exception import (ExtractorChainConfigException, +from llm_web_kit.exception.exception import (ExtractorChainBaseException, + ExtractorChainConfigException, ExtractorChainInputException, ExtractorInitException, - ExtractorNotFoundException) + ExtractorNotFoundException, + LlmWebKitBaseException) from llm_web_kit.extractor.extractor import AbstractExtractor from llm_web_kit.extractor.post_extractor import AbstractPostExtractor from llm_web_kit.extractor.pre_extractor import AbstractPreExtractor from llm_web_kit.input.datajson import DataJson from llm_web_kit.libs.class_loader import load_python_class_by_name -from llm_web_kit.libs.logger import mylogger # ########################################################## @@ -55,11 +56,19 @@ def extract(self, data: DataJson) -> DataJson: data = post_ext.post_extract(data) except KeyError as e: - mylogger.error(f'Required field missing in input data: {str(e)}') - raise ExtractorChainInputException(f'Required field missing in input data: {str(e)}') - except Exception as e: - mylogger.error(f'Error during extraction: {str(e)}') + exc = ExtractorChainInputException(f'Required field missing: {str(e)}') + exc.dataset_name = data.get_dataset_name() + raise exc + except ExtractorChainBaseException as e: + e.dataset_name = data.get_dataset_name() + raise + except LlmWebKitBaseException as e: + e.dataset_name = data.get_dataset_name() raise + except Exception as e: + wrapped = ExtractorChainBaseException(f'Error during extraction: {str(e)}') + wrapped.dataset_name = data.get_dataset_name() + raise wrapped from e return data diff --git a/tests/llm_web_kit/exception/test_exception_data.py b/tests/llm_web_kit/exception/test_exception_data.py index 5f3ac069..b3ff1aa6 100644 --- a/tests/llm_web_kit/exception/test_exception_data.py +++ b/tests/llm_web_kit/exception/test_exception_data.py @@ -1,5 +1,6 @@ import unittest from pathlib import Path +from unittest.mock import patch from llm_web_kit.exception.exception import (CleanModelException, EbookFileExtractorException, @@ -199,6 +200,7 @@ def test_error_code_uniqueness(self): with open(json_path, 'r', encoding='utf-8') as f: import commentjson as json + data = json.load(f) for module in data.values(): @@ -206,3 +208,148 @@ def test_error_code_uniqueness(self): code = error_info['code'] self.assertNotIn(code, error_codes, f'Duplicate error code found: {code}') error_codes.add(code) + + def test_exception_dataset_name(self): + """Test dataset_name handling in exceptions.""" + # Test base exception initialization with empty dataset_name + base_exc = LlmWebKitBaseException('test message') + self.assertEqual(base_exc.dataset_name, '') + + # Test custom dataset_name assignment + base_exc.dataset_name = 'test_dataset' + self.assertEqual(base_exc.dataset_name, 'test_dataset') + + # Test dataset_name in child exceptions + chain_exc = ExtractorChainBaseException('chain error') + self.assertEqual(chain_exc.dataset_name, '') + chain_exc.dataset_name = 'chain_dataset' + self.assertEqual(chain_exc.dataset_name, 'chain_dataset') + + # Test dataset_name in concrete exceptions + test_cases = [ + (ExtractorInitException('init error'), 'init_dataset'), + (ExtractorChainInputException('input error'), 'input_dataset'), + (ExtractorChainConfigException('config error'), 'config_dataset'), + (ExtractorNotFoundException('not found error'), 'notfound_dataset'), + ] + + for exc, dataset_name in test_cases: + with self.subTest(exception_type=type(exc).__name__): + self.assertEqual(exc.dataset_name, '') + exc.dataset_name = dataset_name + self.assertEqual(exc.dataset_name, dataset_name) + + # Test exception handling when DataJson has no dataset_name + from llm_web_kit.extractor.extractor_chain import ExtractSimpleFactory + from llm_web_kit.input.datajson import DataJson + + config = { + 'extractor_pipe': { + 'pre_extractor': [ + { + 'enable': True, + 'python_class': 'llm_web_kit.extractor.html.pre_extractor.HTMLFileFormatFilterPreExtractor', + 'class_init_kwargs': {}, + } + ], + 'extractor': [ + { + 'enable': True, + 'python_class': 'llm_web_kit.extractor.html.extractor.HTMLFileFormatExtractor', + 'class_init_kwargs': {}, + } + ], + } + } + chain = ExtractSimpleFactory.create(config) + + input_data = DataJson( + { + 'dataset_name': 'test_dataset', + } + ) + + with self.assertRaises(ExtractorChainBaseException) as context: + chain.extract(input_data) + self.assertEqual(context.exception.dataset_name, 'test_dataset') + + @patch('llm_web_kit.libs.class_loader.load_python_class_by_name') + def test_extractor_chain_exceptions(self, mock_load_class): + """测试 ExtractorChain 中的异常处理机制.""" + from llm_web_kit.extractor.extractor_chain import ExtractSimpleFactory + from llm_web_kit.input.datajson import DataJson + + # 定义简单的 Mock 类,每个类负责抛出一种异常 + class KeyErrorExtractor: + def __init__(self, config, **kwargs): + pass + + def extract(self, data): + raise KeyError('test_key') + + class BaseExceptionExtractor: + def __init__(self, config, **kwargs): + pass + + def extract(self, data): + raise LlmWebKitBaseException('Base exception') + + class ChainExceptionExtractor: + def __init__(self, config, **kwargs): + pass + + def extract(self, data): + raise ExtractorChainBaseException('Chain exception') + + class GeneralExceptionExtractor: + def __init__(self, config, **kwargs): + pass + + def extract(self, data): + raise ValueError('General exception') + + mock_load_class.return_value = KeyErrorExtractor(None) + + # 基础配置 + config = { + 'extractor_pipe': { + 'pre_extractor': [ + { + 'enable': True, + 'python_class': 'llm_web_kit.extractor.html.pre_extractor.HTMLFileFormatFilterPreExtractor', + 'class_init_kwargs': {}, + } + ], + 'extractor': [ + { + 'enable': True, + 'python_class': 'llm_web_kit.extractor.html.extractor.HTMLFileFormatExtractor', + 'class_init_kwargs': {}, + } + ], + } + } + + # 测试数据 + data = DataJson({'dataset_name': 'test_dataset'}) + + # 测试场景 1: KeyError -> ExtractorChainInputException + chain = ExtractSimpleFactory.create(config) + with self.assertRaises(ExtractorChainInputException) as context: + chain.extract(data) + self.assertEqual(context.exception.dataset_name, 'test_dataset') + self.assertIn('Required field missing', str(context.exception)) + + # 测试场景 2: LlmWebKitBaseException 传递 + mock_load_class.return_value = BaseExceptionExtractor(None) + chain = ExtractSimpleFactory.create(config) + with self.assertRaises(LlmWebKitBaseException) as context: + chain.extract(data) + self.assertEqual(context.exception.dataset_name, 'test_dataset') + + # 测试场景 3: ExtractorChainBaseException 传递 + mock_load_class.return_value = ChainExceptionExtractor(None) + chain = ExtractSimpleFactory.create(config) + with self.assertRaises(ExtractorChainBaseException) as context: + chain.extract(data) + self.assertEqual(context.exception.dataset_name, 'test_dataset') diff --git a/tests/llm_web_kit/extractor/test_extractor_chain_normal.py b/tests/llm_web_kit/extractor/test_extractor_chain_normal.py index 8a9ab126..355ed50d 100644 --- a/tests/llm_web_kit/extractor/test_extractor_chain_normal.py +++ b/tests/llm_web_kit/extractor/test_extractor_chain_normal.py @@ -1,10 +1,16 @@ import json import os import unittest +from unittest.mock import MagicMock, patch -from llm_web_kit.exception.exception import (ExtractorChainInputException, - ExtractorNotFoundException) -from llm_web_kit.extractor.extractor_chain import ExtractSimpleFactory +from llm_web_kit.exception.exception import (ExtractorChainBaseException, + ExtractorChainConfigException, + ExtractorChainInputException, + ExtractorInitException, + ExtractorNotFoundException, + LlmWebKitBaseException) +from llm_web_kit.extractor.extractor_chain import (ExtractorChain, + ExtractSimpleFactory) from llm_web_kit.input.datajson import DataJson @@ -160,7 +166,11 @@ def test_error_handling(self): # Test invalid input type with self.assertRaises(ExtractorChainInputException): - chain.extract(DataJson({'data_source_category': 'html', 'html': '