diff --git a/llm_web_kit/extractor/html/recognizer/list.py b/llm_web_kit/extractor/html/recognizer/list.py
index d564d41e..7694ba1a 100644
--- a/llm_web_kit/extractor/html/recognizer/list.py
+++ b/llm_web_kit/extractor/html/recognizer/list.py
@@ -4,6 +4,7 @@
from lxml.etree import _Element as HtmlElement
from overrides import override
+from llm_web_kit.exception.exception import HtmlListRecognizerException
from llm_web_kit.extractor.html.recognizer.recognizer import (
BaseHTMLElementRecognizer, CCTag)
from llm_web_kit.libs.doc_element_type import DocElementType, ParagraphTextType
@@ -22,13 +23,14 @@ def to_content_list_node(self, base_url: str, parsed_content: str, raw_html_segm
Returns:
"""
- ordered, content_list, _ = self.__get_attribute(parsed_content)
+ ordered, content_list, _, list_nest_level = self.__get_attribute(parsed_content)
ele_node = {
'type': DocElementType.LIST,
'raw_content': raw_html_segment,
'content': {
'items': content_list,
- 'ordered': ordered
+ 'ordered': ordered,
+ 'list_nest_level': list_nest_level
}
}
@@ -148,12 +150,35 @@ def __extract_list_element(self, ele: HtmlElement) -> tuple[int, bool, list[list
return list_nest_level, is_ordered, content_list, raw_html, tail_text
def __get_list_type(self, list_ele:HtmlElement) -> int:
- """获取list嵌套的类型."""
- if list_ele.tag not in ['ul', 'ol', 'dl', 'menu', 'dir']:
- return 0
- ancestor_count = list_ele.xpath('count(ancestor::ul | ancestor::ol)')
- # 层级 = 祖先列表数量 + 自身(1层)
- return int(ancestor_count) + 1
+ """Return the nesting level of a list.
+
+ Computes the maximum nesting depth of a list element by recursively walking all of its child elements.
+ For example:
+ - a list with no nested list returns 1
+ - a list with one level of nesting returns 2
+ - a list with two levels of nesting returns 3
+
+ Args:
+ list_ele: the list HTML element
+
+ Returns:
+ int: the maximum nesting depth of the list
+ """
+ list_type = ['ul', 'ol', 'dl', 'menu', 'dir']
+
+ def get_max_depth(element):
+ max_child_depth = 0
+ for child in element.iterchildren():
+ if child.tag in list_type:
+ # Found a nested list; it contributes at least one level of depth
+ child_depth = 1 + get_max_depth(child)
+ max_child_depth = max(max_child_depth, child_depth)
+ else:
+ # For a non-list element, recurse into its children
+ child_depth = get_max_depth(child)
+ max_child_depth = max(max_child_depth, child_depth)
+ return max_child_depth
+ return get_max_depth(list_ele) + 1
def __extract_list_item_text(self, root:HtmlElement) -> list[list]:
"""提取列表项的文本.
@@ -208,7 +233,7 @@ def __get_attribute(self, html:str) -> Tuple[bool, dict, str]:
ordered = ele.attrib.get('ordered', 'False') in ['True', 'true']
content_list = json.loads(ele.text)
raw_html = ele.attrib.get('html')
- return ordered, content_list, raw_html
+ list_nest_level = ele.attrib.get('list_nest_level', 0)
+ return ordered, content_list, raw_html, list_nest_level
else:
- # TODO 抛出异常, 需要自定义
- raise ValueError(f'{html}中没有cctitle标签')
+ raise HtmlListRecognizerException(f'{html}中没有cctitle标签')
diff --git a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/list_nest_three.html b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/list_nest_three.html
new file mode 100644
index 00000000..018a85ab
--- /dev/null
+++ b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html/list_nest_three.html
@@ -0,0 +1,30 @@
+
+
+ - 外层列表项
+ -
+
+
+ - 第二层列表项
+
+
+
+ - 第二层其他项
+
+
+
+ - 外层另一个列表项
+ -
+
+
+
+
\ No newline at end of file
diff --git a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html_data_input.jsonl b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html_data_input.jsonl
index 6c191184..857adf62 100644
--- a/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html_data_input.jsonl
+++ b/tests/llm_web_kit/extractor/assets/extractor_chain_input/good_data/html_data_input.jsonl
@@ -15,4 +15,5 @@
{"track_id": "list_empty", "dataset_name": "test_list_empty", "url": "https://productcenter.ru/products/27276/naturalnoie-krymskoie-mylo-ruchnoi-raboty-39-raznovidnostiei","data_source_category": "HTML", "path":"test_list_empty.html", "file_bytes": 1000, "meta_info": {"input_datetime": "2020-01-01 00:00:00"}}
{"track_id": "table_include_math_p", "dataset_name": "table_include_math_p", "url": "https://math.stackexchange.com/questions/458323/is-8327-1-a-prime-number?answertab=active","data_source_category": "HTML", "path":"table_include_math_p.html", "file_bytes": 1000, "meta_info": {"input_datetime": "2020-01-01 00:00:00"}}
{"track_id": "table_include_table_math", "dataset_name": "table_include_table_math", "url": "https://test","data_source_category": "HTML", "path":"table_include_table_math.html", "file_bytes": 1000, "meta_info": {"input_datetime": "2020-01-01 00:00:00"}}
-{"track_id": "test_clean_tags", "dataset_name": "test_pipeline_suit", "url": "https://math.stackexchange.com/questions/4082284/solving-for-vector-contained-in-a-diagonal-matrix","data_source_category": "HTML", "path":"test_clean_tags.html", "file_bytes": 1000, "page_layout_type":"forum", "meta_info": {"input_datetime": "2020-01-01 00:00:00"}}
\ No newline at end of file
+{"track_id": "test_clean_tags", "dataset_name": "test_pipeline_suit", "url": "https://math.stackexchange.com/questions/4082284/solving-for-vector-contained-in-a-diagonal-matrix","data_source_category": "HTML", "path":"test_clean_tags.html", "file_bytes": 1000, "page_layout_type":"forum", "meta_info": {"input_datetime": "2020-01-01 00:00:00"}}
+{"track_id": "list_nest_three", "dataset_name": "list_nest_three", "url": "http://test.com","data_source_category": "HTML", "path":"list_nest_three.html", "file_bytes": 1000, "page_layout_type":"forum", "meta_info": {"input_datetime": "2020-01-01 00:00:00"}}
\ No newline at end of file
diff --git a/tests/llm_web_kit/extractor/test_extractor_chain.py b/tests/llm_web_kit/extractor/test_extractor_chain.py
index 40f9c9a5..b58b6964 100644
--- a/tests/llm_web_kit/extractor/test_extractor_chain.py
+++ b/tests/llm_web_kit/extractor/test_extractor_chain.py
@@ -59,7 +59,7 @@ def setUp(self):
for line in f:
self.data_json.append(json.loads(line.strip()))
- assert len(self.data_json) == 18
+ assert len(self.data_json) == 19
# Config for HTML extraction
self.config = {
@@ -434,3 +434,13 @@ def test_clean_tags(self):
result = chain.extract(input_data)
content_md = result.get_content_list().to_mm_md()
self.assertNotIn('begingroup', content_md)
+
+ def test_list_nest_three(self):
+ """Test that a list nested three levels deep reports list_nest_level 3."""
+ chain = ExtractSimpleFactory.create(self.config)
+ self.assertIsNotNone(chain)
+ test_data = self.data_json[18]
+ input_data = DataJson(test_data)
+ result = chain.extract(input_data)
+ result_content_list = result.get_content_list()._get_data()
+ self.assertEqual(int(result_content_list[0][0]['content']['list_nest_level']), 3)