Skip to content
4 changes: 1 addition & 3 deletions llm_web_kit/exception/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

import commentjson as json

from llm_web_kit.input.datajson import DataJsonKey


class ErrorMsg:
"""Error message manager class."""
Expand Down Expand Up @@ -54,7 +52,7 @@ def __init__(self, custom_message: str | None = None, error_code: int | None = N
self.error_code = error_code
self.message = ErrorMsg.get_error_message(self.error_code)
self.custom_message = custom_message
self.dataset_name = DataJsonKey.DATASET_NAME
self.dataset_name = ''
super().__init__(self.message)
frame = inspect.currentframe().f_back
self.__py_filename = frame.f_code.co_filename
Expand Down
23 changes: 16 additions & 7 deletions llm_web_kit/extractor/extractor_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@

import commentjson as json

from llm_web_kit.exception.exception import (ExtractorChainConfigException,
from llm_web_kit.exception.exception import (ExtractorChainBaseException,
ExtractorChainConfigException,
ExtractorChainInputException,
ExtractorInitException,
ExtractorNotFoundException)
ExtractorNotFoundException,
LlmWebKitBaseException)
from llm_web_kit.extractor.extractor import AbstractExtractor
from llm_web_kit.extractor.post_extractor import AbstractPostExtractor
from llm_web_kit.extractor.pre_extractor import AbstractPreExtractor
from llm_web_kit.input.datajson import DataJson
from llm_web_kit.libs.class_loader import load_python_class_by_name
from llm_web_kit.libs.logger import mylogger


# ##########################################################
Expand Down Expand Up @@ -55,11 +56,19 @@ def extract(self, data: DataJson) -> DataJson:
data = post_ext.post_extract(data)

except KeyError as e:
mylogger.error(f'Required field missing in input data: {str(e)}')
raise ExtractorChainInputException(f'Required field missing in input data: {str(e)}')
except Exception as e:
mylogger.error(f'Error during extraction: {str(e)}')
exc = ExtractorChainInputException(f'Required field missing: {str(e)}')
exc.dataset_name = data.get_dataset_name()
raise exc
except ExtractorChainBaseException as e:
e.dataset_name = data.get_dataset_name()
raise
except LlmWebKitBaseException as e:
e.dataset_name = data.get_dataset_name()
raise
except Exception as e:
wrapped = ExtractorChainBaseException(f'Error during extraction: {str(e)}')
wrapped.dataset_name = data.get_dataset_name()
raise wrapped from e

return data

Expand Down
147 changes: 147 additions & 0 deletions tests/llm_web_kit/exception/test_exception_data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest
from pathlib import Path
from unittest.mock import patch

from llm_web_kit.exception.exception import (CleanModelException,
EbookFileExtractorException,
Expand Down Expand Up @@ -199,10 +200,156 @@ def test_error_code_uniqueness(self):

with open(json_path, 'r', encoding='utf-8') as f:
import commentjson as json

data = json.load(f)

for module in data.values():
for error_info in module.values():
code = error_info['code']
self.assertNotIn(code, error_codes, f'Duplicate error code found: {code}')
error_codes.add(code)

def test_exception_dataset_name(self):
"""Test dataset_name handling in exceptions."""
# Test base exception initialization with empty dataset_name
base_exc = LlmWebKitBaseException('test message')
self.assertEqual(base_exc.dataset_name, '')

# Test custom dataset_name assignment
base_exc.dataset_name = 'test_dataset'
self.assertEqual(base_exc.dataset_name, 'test_dataset')

# Test dataset_name in child exceptions
chain_exc = ExtractorChainBaseException('chain error')
self.assertEqual(chain_exc.dataset_name, '')
chain_exc.dataset_name = 'chain_dataset'
self.assertEqual(chain_exc.dataset_name, 'chain_dataset')

# Test dataset_name in concrete exceptions
test_cases = [
(ExtractorInitException('init error'), 'init_dataset'),
(ExtractorChainInputException('input error'), 'input_dataset'),
(ExtractorChainConfigException('config error'), 'config_dataset'),
(ExtractorNotFoundException('not found error'), 'notfound_dataset'),
]

for exc, dataset_name in test_cases:
with self.subTest(exception_type=type(exc).__name__):
self.assertEqual(exc.dataset_name, '')
exc.dataset_name = dataset_name
self.assertEqual(exc.dataset_name, dataset_name)

# Test exception handling when DataJson has no dataset_name
from llm_web_kit.extractor.extractor_chain import ExtractSimpleFactory
from llm_web_kit.input.datajson import DataJson

config = {
'extractor_pipe': {
'pre_extractor': [
{
'enable': True,
'python_class': 'llm_web_kit.extractor.html.pre_extractor.HTMLFileFormatFilterPreExtractor',
'class_init_kwargs': {},
}
],
'extractor': [
{
'enable': True,
'python_class': 'llm_web_kit.extractor.html.extractor.HTMLFileFormatExtractor',
'class_init_kwargs': {},
}
],
}
}
chain = ExtractSimpleFactory.create(config)

input_data = DataJson(
{
'dataset_name': 'test_dataset',
}
)

with self.assertRaises(ExtractorChainBaseException) as context:
chain.extract(input_data)
self.assertEqual(context.exception.dataset_name, 'test_dataset')

@patch('llm_web_kit.libs.class_loader.load_python_class_by_name')
def test_extractor_chain_exceptions(self, mock_load_class):
"""测试 ExtractorChain 中的异常处理机制."""
from llm_web_kit.extractor.extractor_chain import ExtractSimpleFactory
from llm_web_kit.input.datajson import DataJson

# 定义简单的 Mock 类,每个类负责抛出一种异常
class KeyErrorExtractor:
def __init__(self, config, **kwargs):
pass

def extract(self, data):
raise KeyError('test_key')

class BaseExceptionExtractor:
def __init__(self, config, **kwargs):
pass

def extract(self, data):
raise LlmWebKitBaseException('Base exception')

class ChainExceptionExtractor:
def __init__(self, config, **kwargs):
pass

def extract(self, data):
raise ExtractorChainBaseException('Chain exception')

class GeneralExceptionExtractor:
def __init__(self, config, **kwargs):
pass

def extract(self, data):
raise ValueError('General exception')

mock_load_class.return_value = KeyErrorExtractor(None)

# 基础配置
config = {
'extractor_pipe': {
'pre_extractor': [
{
'enable': True,
'python_class': 'llm_web_kit.extractor.html.pre_extractor.HTMLFileFormatFilterPreExtractor',
'class_init_kwargs': {},
}
],
'extractor': [
{
'enable': True,
'python_class': 'llm_web_kit.extractor.html.extractor.HTMLFileFormatExtractor',
'class_init_kwargs': {},
}
],
}
}

# 测试数据
data = DataJson({'dataset_name': 'test_dataset'})

# 测试场景 1: KeyError -> ExtractorChainInputException
chain = ExtractSimpleFactory.create(config)
with self.assertRaises(ExtractorChainInputException) as context:
chain.extract(data)
self.assertEqual(context.exception.dataset_name, 'test_dataset')
self.assertIn('Required field missing', str(context.exception))

# 测试场景 2: LlmWebKitBaseException 传递
mock_load_class.return_value = BaseExceptionExtractor(None)
chain = ExtractSimpleFactory.create(config)
with self.assertRaises(LlmWebKitBaseException) as context:
chain.extract(data)
self.assertEqual(context.exception.dataset_name, 'test_dataset')

# 测试场景 3: ExtractorChainBaseException 传递
mock_load_class.return_value = ChainExceptionExtractor(None)
chain = ExtractSimpleFactory.create(config)
with self.assertRaises(ExtractorChainBaseException) as context:
chain.extract(data)
self.assertEqual(context.exception.dataset_name, 'test_dataset')
Loading