diff --git a/llm_web_kit/extractor/html/recognizer/image.py b/llm_web_kit/extractor/html/recognizer/image.py
index 30f8241d..53f612dc 100644
--- a/llm_web_kit/extractor/html/recognizer/image.py
+++ b/llm_web_kit/extractor/html/recognizer/image.py
@@ -157,8 +157,7 @@ def __parse_img_elements(self, base_url: str, img_elements: HtmlElement, html_ob
'html': raw_img_html, # 保留原始
标签作为属性值
'format': 'url', # 指定图片格式,url|base
}
- if elem.text and elem.text.strip():
- attributes['caption'] = elem.text.strip()
+ attributes['caption'] = elem.xpath('normalize-space()')
if tag in ['embed', 'object', 'iframe', 'video', 'audio', 'canvas']:
if not [img_elem for img_elem in self.IMG_LABEL if
img_elem in raw_img_html.lower()]:
diff --git a/llm_web_kit/input/datajson.py b/llm_web_kit/input/datajson.py
index 3ded4c2b..bf970f93 100644
--- a/llm_web_kit/input/datajson.py
+++ b/llm_web_kit/input/datajson.py
@@ -298,7 +298,7 @@ def __content_lst_node_2_md(self, content_lst_node: dict, exclude_inline_types:
else:
image_caption = ''
- image_des = image_title if image_title else image_caption if image_caption else ''
+ image_des = image_title if image_title else ''
# 优先使用data, 其次path.其中data是base64编码的图片,path是图片的url
if image_data:
if image_des:
@@ -310,7 +310,13 @@ def __content_lst_node_2_md(self, content_lst_node: dict, exclude_inline_types:
image = f''
else:
image = f''
- return image
+
+ if image_caption:
+ image_with_caption = f'{image}\n\n{image_caption}'
+ else:
+ image_with_caption = image
+
+ return image_with_caption
elif node_type == DocElementType.AUDIO:
return '' # TODO: 音频格式
elif node_type == DocElementType.VIDEO:
diff --git a/tests/llm_web_kit/extractor/html/recognizer/test_image.py b/tests/llm_web_kit/extractor/html/recognizer/test_image.py
index 73a4c3a0..6396c6d1 100644
--- a/tests/llm_web_kit/extractor/html/recognizer/test_image.py
+++ b/tests/llm_web_kit/extractor/html/recognizer/test_image.py
@@ -350,3 +350,36 @@ def test_complex_heading_image_removal(self):
img_in_p.extend(p.xpath('.//img'))
self.assertEqual(len(img_in_p), 0, '段落中不应该有img标签')
+
+ def test_image_caption(self):
+ complex_html = """
+
+
+ Roger Moore in
+
+ For
+ Your Eyes Only
+ . Photo Courtesy: United
+ Artists/Everett Collection
+
+
+
+ """
+ element = html_to_element(complex_html)
+ base_url = 'http://example.com'
+ parts = self.img_recognizer.recognize(base_url, [(element, element)], complex_html)
+ html = element_to_html(parts[0][0])
+ self.assertIn('caption="Roger Moore in For Your Eyes Only . Photo Courtesy: United Artists/Everett Collection', html)
diff --git a/tests/llm_web_kit/extractor/test_extractor_chain.py b/tests/llm_web_kit/extractor/test_extractor_chain.py
index a5f68540..5255efe9 100644
--- a/tests/llm_web_kit/extractor/test_extractor_chain.py
+++ b/tests/llm_web_kit/extractor/test_extractor_chain.py
@@ -105,7 +105,7 @@ def test_html_pipeline(self):
self.assertEqual(html_content['content']['title'], 'image-title')
self.assertEqual(html_content['content']['alt'], 'image-alt')
self.assertEqual(html_content['content']['url'], 'https://www.test.com/test.png')
- self.assertEqual(html_content['content']['caption'], None)
+ self.assertEqual(html_content['content']['caption'], '')
# 然后是simple table
html_content = html_content_list[4]