Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
29 changes: 29 additions & 0 deletions llm_web_kit/html_layout_classify/html_layout_classify.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# html layout classify layout分类

## 环境

配置 .xinghe.yaml
配置 .llm_web_kit.jsonc

## 入参

layout_sample_dir: 每个layout_id 随机选取3条的.jsonl文件路径或文件夹路径
layout_classify_dir: 计算每个layout_id 对应的分类结果文件夹路径

layout_sample_dir 字段说明:

| 字段 | 类型 | 描述 | 是否必须 |
| --------- | ------ | ---------------------------- | -------- |
| layout_id | string | layout id | 是 |
| url | string | 数据url | 是 |
| simp_html | string | html原数据经过简化处理的html | 是 |

layout_classify_dir 字段说明:

| 字段 | 类型 | 描述 | 是否必须 |
| ------------- | ------ | --------------------------------------------------------------- | -------- |
| url_list | list | layout id 对应的url | 是 |
| layout_id | string | layout id | 是 |
| page_type | string | layout_id 经过分类之后的分类结果('other', 'article', 'forum') | 是 |
| max_pred_prod | float | 分类模型的分类可靠度 | 是 |
| version | string | 模型版本 | 是 |
115 changes: 115 additions & 0 deletions llm_web_kit/html_layout_classify/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import argparse
import json
from collections import Counter

from loguru import logger

from llm_web_kit.html_layout_classify.s3.client import list_s3_objects
from llm_web_kit.html_layout_classify.s3.read import read_s3_rows
from llm_web_kit.html_layout_classify.s3.write import S3DocWriter
from llm_web_kit.model.html_layout_cls import HTMLLayoutClassifier

# Mapping between the classifier's string labels and their integer codes,
# and the reverse mapping used to decode the majority vote.
CLASSIFY_MAP = {'other': 0, 'article': 1, 'forum': 2}
INT_CLASSIFY_MAP = {0: 'other', 1: 'article', 2: 'forum'}
# NOTE(review): name looks like a typo for MODEL_VERSION; kept as-is because
# it is a public module-level name that may be imported elsewhere.
MODEL_VERESION = '0.0.2'


def __list_layout_sample_dir(s3_dir: str) -> list:
    """List all layout-sample ``.jsonl`` files for *s3_dir*.

    A path ending with ``/`` is treated as a directory and listed
    recursively for ``.jsonl`` objects; any other path is assumed to be a
    single sample file and returned as a one-element list.
    """
    if s3_dir.endswith('/'):
        # list_s3_objects is a generator: filter it directly instead of
        # materializing an intermediate list first.
        return [f for f in list_s3_objects(s3_dir, recursive=True) if f.endswith('.jsonl')]
    return [s3_dir]


def __parse_predict_res(predict_res: list, layout_samples: list) -> dict:
    """Aggregate per-sample model predictions into one classify record.

    Args:
        predict_res: model output, e.g. ``[{'pred_prob': '0.626298', 'pred_label': 'other'}]``.
        layout_samples: the samples (all sharing one layout_id) the
            predictions were produced for; must be non-empty.

    Returns:
        dict with ``url_list`` / ``layout_id`` / ``page_type`` (majority
        label, ``'other'`` on ties or unknown labels) / ``max_pred_prod``
        (float) / ``version``.
    """
    label_codes = [CLASSIFY_MAP.get(i['pred_label'], 0) for i in predict_res]
    res = {
        'url_list': [i['url'] for i in layout_samples],
        'layout_id': layout_samples[0]['layout_id'],
        'page_type': INT_CLASSIFY_MAP.get(__most_frequent_or_zero(label_codes), 'other'),
        # pred_prob arrives as a string; convert before max() so values are
        # compared numerically, and so the output field is a float as the
        # layout_classify_dir schema documents.
        'max_pred_prod': max(float(i['pred_prob']) for i in predict_res),
        'version': MODEL_VERESION,
    }
    return res


def __most_frequent_or_zero(int_elements):
    """Return the unique majority element of *int_elements*, else 0.

    Generalized to any input length (the previous hand-rolled version only
    handled up to three elements and fell through to an implicit ``None``
    beyond that):

    - empty input -> 0
    - a single strictly-most-frequent element -> that element
    - a tie for most frequent -> 0
    """
    if not int_elements:
        return 0
    ranked = Counter(int_elements).most_common(2)
    # Unique mode: either only one distinct value exists, or the top count
    # strictly beats the runner-up.
    if len(ranked) == 1 or ranked[0][1] > ranked[1][1]:
        return ranked[0][0]
    return 0


def __process_one_layout_sample(layout_sample_file: str, layout_type_dir: str):
    """Classify every layout_id group in one sample file and write the results.

    Rows are streamed from *layout_sample_file* and grouped by consecutive
    layout_id — this assumes rows with the same layout_id are adjacent in
    the file (TODO confirm the upstream sampling job guarantees ordering).
    Each group is run through the module-level ``model`` (loaded by
    ``__set_config``) and the aggregated result is written to an output
    file under *layout_type_dir*.
    """
    # Output keeps the input's basename; layout_type_dir is expected to end
    # with '/' since the two strings are concatenated directly.
    output_file_path = f"{layout_type_dir}{layout_sample_file.split('/')[-1]}"
    writer = S3DocWriter(output_file_path)

    def __get_type_by_layoutid(layout_samples: list):
        # Inputs are already-simplified HTML strings (simp_html field).
        # html_str_input = [general_simplify_html_str(html['html_source']) for html in layout_samples]
        html_str_input = [html['simp_html'] for html in layout_samples]
        layout_classify_lst = model.predict(html_str_input)
        layout_classify = __parse_predict_res(layout_classify_lst, layout_samples)
        return layout_classify

    current_layout_id, samples = None, []
    idx = 0
    for row in read_s3_rows(layout_sample_file):
        idx += 1
        detail_data = json.loads(row.value)
        if current_layout_id == detail_data['layout_id']:
            samples.append(detail_data)
        else:
            # layout_id changed: flush the finished group, start a new one.
            if samples:
                classify_res = __get_type_by_layoutid(samples)
                writer.write(classify_res)
            current_layout_id, samples = detail_data['layout_id'], [detail_data]
    # Flush the trailing group after the loop ends.
    if samples:
        classify_res = __get_type_by_layoutid(samples)
        writer.write(classify_res)
    writer.flush()
    logger.info(f'read {layout_sample_file} file {idx} rows')


def __set_config():
    """Load the HTML layout classifier into the module-level ``model`` global."""
    global model
    model = HTMLLayoutClassifier()


def main():
    """CLI entry point.

    Parses the input/output paths, loads the layout model once, then
    classifies each sample file and writes results to the output dir.
    """
    parser = argparse.ArgumentParser(description='Process files with specified function.')
    parser.add_argument('layout_sample_dir', help='待分类文件夹路径或文件路径')
    parser.add_argument('layout_classify_dir', help='已分类结果输出路径')

    args = parser.parse_args()

    try:
        # 加载模型
        __set_config()
        layout_sample_files = __list_layout_sample_dir(args.layout_sample_dir)
        # 读取每个json文件的数据,根据每个layout_id为一簇,计算每个layout_id 对应的 layout_classify,并将结果写入s3
        for layout_sample_file in layout_sample_files:
            __process_one_layout_sample(layout_sample_file, args.layout_classify_dir)
    except Exception as e:
        # logger.exception keeps the full traceback; plain logger.error with
        # str(e) loses the stack and makes failures hard to diagnose.
        logger.exception(f'get layout classify fail: {e}')


# Script entry point: run the classification pipeline when executed directly.
if __name__ == '__main__':
    main()
Empty file.
139 changes: 139 additions & 0 deletions llm_web_kit/html_layout_classify/s3/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import time
from typing import Dict, List, Union

import boto3
from boto3.s3.transfer import TransferConfig
from botocore.client import Config
from botocore.exceptions import ClientError

from .conf import get_s3_config
from .path import split_s3_path


def get_s3_client(path: Union[str, List[str]], outside=False):
    """Build a boto3 S3 client for *path* using the matching credential set.

    Tries a full retry/timeout configuration first and falls back to a
    minimal one on older boto3 releases.
    """
    s3_config = get_s3_config(path, outside)
    credentials = dict(
        aws_access_key_id=s3_config['ak'],
        aws_secret_access_key=s3_config['sk'],
        endpoint_url=s3_config['endpoint'],
    )
    try:
        return boto3.client(
            's3',
            config=Config(
                s3={'addressing_style': 'path'},
                retries={'max_attempts': 8, 'mode': 'standard'},
                connect_timeout=600,
                read_timeout=600,
            ),
            **credentials,
        )
    except Exception:
        # older boto3 do not support retries.mode param.
        return boto3.client(
            's3',
            config=Config(s3={'addressing_style': 'path'}, retries={'max_attempts': 8}),
            **credentials,
        )


def is_s3_404_error(e: Exception):
    """Return True when *e* is a botocore ClientError describing a missing object."""
    if not isinstance(e, ClientError):
        return False
    error = e.response.get('Error', {})
    if error.get('Code') in ('404', 'NoSuchKey'):
        return True
    if error.get('Message') == 'Not Found':
        return True
    return e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') == 404


def head_s3_object(path: str, raise_404=False, client=None) -> Union[Dict, None]:
    """HEAD the object at *path* and return the response dict.

    Returns None when the object is absent, unless ``raise_404`` is True, in
    which case the underlying ClientError propagates.  Non-404 errors always
    propagate.
    """
    client = client or get_s3_client(path)
    bucket, key = split_s3_path(path)
    try:
        return client.head_object(Bucket=bucket, Key=key)
    except ClientError as e:
        if raise_404 or not is_s3_404_error(e):
            raise
        return None


def _restore_and_wait(client, bucket: str, key: str, path: str):
    """Start a Glacier restore for the object and poll until it completes."""
    while True:
        restore_status = client.head_object(Bucket=bucket, Key=key).get('Restore', '')
        if 'ongoing-request="false"' in restore_status:
            # Restoration finished; the object is readable again.
            print(f"restoration-complete: {path}")
            return
        if not restore_status:
            # No restore in flight yet: kick one off.
            request = {'Days': 1, 'GlacierJobParameters': {'Tier': 'Standard'}}
            client.restore_object(Bucket=bucket, Key=key, RestoreRequest=request)
            print(f"restoration-started: {path}")
        elif 'ongoing-request="true"' in restore_status:
            print(f"restoration-ongoing: {path}")
        time.sleep(3)


def get_s3_object(path: str, client=None, **kwargs) -> dict:
    """GET the object at *path*, transparently restoring it from Glacier first."""
    client = client or get_s3_client(path)
    bucket, key = split_s3_path(path)
    try:
        return client.get_object(Bucket=bucket, Key=key, **kwargs)
    except ClientError as e:
        error_code = e.response.get('Error', {}).get('Code')
        if error_code != 'GlacierObjectNotRestore':
            raise
        # Object is archived: restore it, then retry the read once.
        _restore_and_wait(client, bucket, key, path)
        return client.get_object(Bucket=bucket, Key=key, **kwargs)


def list_s3_objects(path: str, recursive=False, is_prefix=False, limit=0, client=None):
    """Yield only the ``s3://`` URLs from :func:`list_s3_objects_detailed`."""
    for object_url, _detail in list_s3_objects_detailed(path, recursive, is_prefix, limit, client=client):
        yield object_url


def list_s3_objects_detailed(path: str, recursive=False, is_prefix=False, limit=0, client=None):
    """List objects (and, when non-recursive, common prefixes) under *path*.

    Yields ``(s3_url, raw_entry)`` tuples using the ListObjects V1 API.
    With ``limit > 0`` at most one page of up to *limit* entries is
    returned (the loop breaks after the first response).  Keys ending in
    '/' (directory placeholders) are skipped.
    """
    client = client or get_s3_client(path)
    if limit > 1000:
        raise Exception('limit greater than 1000 is not supported.')
    # Treat path as a directory unless the caller marked it a raw prefix.
    if not path.endswith('/') and not is_prefix:
        path += '/'
    bucket, prefix = split_s3_path(path)
    marker = None
    while True:
        list_kwargs = dict(MaxKeys=1000, Bucket=bucket, Prefix=prefix)
        if limit > 0:
            list_kwargs['MaxKeys'] = limit
        if not recursive:
            # Delimiter makes S3 roll sub-"directories" up into CommonPrefixes.
            list_kwargs['Delimiter'] = '/'
        if marker:
            list_kwargs['Marker'] = marker
        response = client.list_objects(**list_kwargs)
        marker = None
        if not recursive:
            common_prefixes = response.get('CommonPrefixes', [])
            for cp in common_prefixes:
                yield (f"s3://{bucket}/{cp['Prefix']}", cp)
            if common_prefixes:
                marker = common_prefixes[-1]['Prefix']
        contents = response.get('Contents', [])
        for content in contents:
            # Skip directory placeholder keys.
            if not content['Key'].endswith('/'):
                yield (f"s3://{bucket}/{content['Key']}", content)
        if contents:
            last_key = contents[-1]['Key']
            # Continue from whichever of prefix/key sorts later.
            if not marker or last_key > marker:
                marker = last_key
        # Stop when a limit was requested, the listing is complete, or
        # there is nothing to continue from.
        if limit or not response.get('IsTruncated') or not marker:
            break


def upload_s3_object(path: str, local_file_path: str, client=None):
    """Upload *local_file_path* to the S3 location *path* (multipart for large files)."""
    client = client or get_s3_client(path)
    bucket, key = split_s3_path(path)
    megabyte = 1024 ** 2
    transfer_config = TransferConfig(
        multipart_threshold=128 * megabyte,
        multipart_chunksize=16 * megabyte,  # 156.25GiB maximum
    )
    client.upload_file(local_file_path, bucket, key, Config=transfer_config)
Loading
Loading