From dcdb988bc59ae9fbf4188d3b966a7f1421c1da69 Mon Sep 17 00:00:00 2001 From: drunkpig Date: Thu, 13 Mar 2025 15:12:11 +0800 Subject: [PATCH 1/4] feat: page classify --- .gitignore | 1 + .../html_layout_classify/classify-spot.sh | 142 ++++++++++ llm_web_kit/html_layout_classify/classify.sh | 107 ++++++++ llm_web_kit/html_layout_classify/main.py | 229 ++++++++-------- .../{html_layout_classify.md => readme.md} | 20 +- llm_web_kit/html_layout_classify/server.py | 250 ++++++++++++++++++ requirements/dev.txt | 3 + 7 files changed, 644 insertions(+), 108 deletions(-) create mode 100755 llm_web_kit/html_layout_classify/classify-spot.sh create mode 100755 llm_web_kit/html_layout_classify/classify.sh rename llm_web_kit/html_layout_classify/{html_layout_classify.md => readme.md} (62%) create mode 100644 llm_web_kit/html_layout_classify/server.py diff --git a/.gitignore b/.gitignore index fddaf477..f5183022 100644 --- a/.gitignore +++ b/.gitignore @@ -47,3 +47,4 @@ coverage.xml llm_web_kit.egg-info/* .llm-web-kit.jsonc +.llm-web-kit-pageclassify.jsonc diff --git a/llm_web_kit/html_layout_classify/classify-spot.sh b/llm_web_kit/html_layout_classify/classify-spot.sh new file mode 100755 index 00000000..b679091a --- /dev/null +++ b/llm_web_kit/html_layout_classify/classify-spot.sh @@ -0,0 +1,142 @@ +#!/bin/bash + +command -v proxyoff >/dev/null 2>&1 && proxyoff +command -v proxy_off >/dev/null 2>&1 && proxy_off + + +function count_used_gpus(){ + all_jobs=`squeue --me -p $1` + + gpu_num=0 + for name in $all_jobs + do + if [ "$(echo $name | grep "gpu:")" != "" ];then + num="${name//gpu:/}" + gpu_num=$((($gpu_num+$num))) + fi + done + echo $gpu_num +} + + +# 函数:获取当前用户所有处于PD状态的任务数量 +get_pd_count() { + squeue -u "$USER" -t PD -h |grep spot | wc -l +} + +# 定义一个函数来计算 SPOT_USED 的总和 +calculate_total_spot_used() { + # 执行 svp list 并获取输出 + local svp_output=$(svp list) + + # 使用 awk 解析并计算 SPOT_USED 列的总和 + local total_spot_used=$(echo "$svp_output" | awk ' + NR == 1 {next} # 跳过标题行 + { + sum += $6 # 假设 SPOT_USED 是第6列 + } + END { + print sum + }') + + # 返回结果 + echo $total_spot_used +} + +calculate_total_reserved_idle() { + # 执行 svp list 并获取输出 + local svp_output=$(svp list) + + #总和 + local total_reserved_idle=$(echo "$svp_output" | awk ' + NR == 1 {next} + { + sum += $5 + } + END { + print sum + }') + # 返回结果 + echo $total_reserved_idle +} + +####################################################################################### +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --partation) + PARTATION="$2" + shift 2 + ;; + --tag) + TAG="$2" + shift 2 + ;; + --task-num) + TASK_NUM="$2" + shift 2 + ;; + --debug) + DEBUG=1 + shift 1 + ;; + --server-addr) + SERVER_ADDR="$2" + shift 2 + ;; + --result-save-dir) + RESULT_SAVE_DIR="$2" + shift 2 + ;; + *) + echo "Unknown argument: $1" + exit 1 + ;; + esac +done + +MAX_PENDING_JOBS=10 # 用户pending任务数量,不能超过这个值 +MAX_JOBS=1000 # 用户最大提交任务数量 +MY_NAME="${USER}" # 用户名 + +MY_HOME=$(echo $HOME) +SLURM_LOG_DIR=${MY_HOME}/slum-logs/${TAG} +# 创建日志目录(如果不存在) +mkdir -p ${SLURM_LOG_DIR}/logs +mkdir -p ${SLURM_LOG_DIR}/error +export SLURM_SUBMIT_DIR=${SLURM_LOG_DIR} +export LLM_WEB_KIT_CFG_PATH=/share/xuchao/.llm-web-kit-pageclassify.jsonc +TASK_NUM="${TASK_NUM:-1}" # Default to 1 if not provided + + +# Check required arguments +if [ -z "$PARTATION" ] || [ -z "$TAG" ]; then + echo "Usage: $0 --partation --tag " + exit 1 +fi + +# 核心思路是只要不超过最大的pending任务数量,就一直提交任务 +while true +do + for partation in "${PARTATION[@]}"; do + PD_COUNT=$(get_pd_count) + spot_count=$(squeue -u ${MY_NAME} | grep -i spot |wc -l) + + if [ "$PD_COUNT" -lt "$MAX_PENDING_JOBS" ] && [ $spot_count -lt $MAX_JOBS ]; then + # 如果PD任务数小于最大限制,则提交新任务 + # tt=$(date '+%Y-%m-%d %H:%M:%S') + # total_spot_used=$(calculate_total_spot_used) + # total_reserved_idle=$(calculate_total_reserved_idle) + # echo -e "check $partation spot \n tt:$tt \n total_spot_used: $total_spot_used\n total_reserved_idle: $total_reserved_idle \n PD_COUNT: $PD_COUNT" + if [ $DEBUG -eq 1 ]; then + LOG_LEVEL=ERROR srun -p ${partation} --quotatype=spot --output=${SLURM_LOG_DIR}/logs/output_%j.out --export=ALL --cpus-per-task=${TASK_NUM} --error=${SLURM_LOG_DIR}/error/error_%j.err -N ${TASK_NUM} --gres=gpu:1 python main.py ${SERVER_ADDR} --result-save-dir ${RESULT_SAVE_DIR} + else + LOG_LEVEL=ERROR srun -p ${partation} --quotatype=spot --output=${SLURM_LOG_DIR}/logs/output_%j.out --export=ALL --cpus-per-task=${TASK_NUM} --error=${SLURM_LOG_DIR}/error/error_%j.err -N ${TASK_NUM} --gres=gpu:1 --async python main.py ${SERVER_ADDR} --result-save-dir ${RESULT_SAVE_DIR} + fi + echo "use ${partation} submit job succ, submit next job now..." + rm batchscript* 2>/dev/null + fi + break + done # for + sleep 20 +done # while diff --git a/llm_web_kit/html_layout_classify/classify.sh b/llm_web_kit/html_layout_classify/classify.sh new file mode 100755 index 00000000..32a98156 --- /dev/null +++ b/llm_web_kit/html_layout_classify/classify.sh @@ -0,0 +1,107 @@ +#! /bin/bash + +command -v proxyoff >/dev/null 2>&1 && proxyoff +command -v proxy_off >/dev/null 2>&1 && proxy_off + +function count_used_gpus(){ + all_jobs=`squeue --me -p $1` + + gpu_num=0 + for name in $all_jobs + do + if [ "$(echo $name | grep "gpu:")" != "" ];then + num="${name//gpu:/}" + gpu_num=$((($gpu_num+$num))) + fi + done + echo $gpu_num +} + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --partation) + PARTATION="$2" + shift 2 + ;; + --max-job) + MAX_JOB_TOTAL="$2" + shift 2 + ;; + --tag) + TAG="$2" + shift 2 + ;; + --task-num) + TASK_NUM="$2" + shift 2 + ;; + --debug) + DEBUG=1 + shift 1 + ;; + --result-save-dir) + RESULT_SAVE_DIR="$2" + shift 2 + ;; + --server-addr) + SERVER_ADDR="$2" + shift 2 + ;; + *) + echo "Unknown argument: $1" + exit 1 + ;; + esac +done + + +MY_HOME=$(echo $HOME) +MY_NAME="${USER}" # 用户名 +SLURM_LOG_DIR=${MY_HOME}/slum-logs/${TAG} +# 创建日志目录(如果不存在) +mkdir -p ${SLURM_LOG_DIR}/logs +mkdir -p ${SLURM_LOG_DIR}/error +export SLURM_SUBMIT_DIR=${SLURM_LOG_DIR} +export LLM_WEB_KIT_CFG_PATH=/share/${MY_NAME}/.llm-web-kit-pageclassify.jsonc +TASK_NUM="${TASK_NUM:-1}" # Default to 1 if not provided +SERVER_ADDR="${SERVER_ADDR:-http://127.0.0.1:5000}" +PYTHON=/share/${MY_NAME}/.conda/envs/webkitdev/bin/python + + +# Check required arguments +if [ -z "$PARTATION" ] || [ -z "$MAX_JOB_TOTAL" ] || [ -z "$TAG" ]; then + echo "Usage: $0 --partation --max-job --tag --debug " + exit 1 +fi + + +submited_job_num=0 # 成功提交的任务数 + +while [ $submited_job_num -lt $MAX_JOB_TOTAL ] +do + used_gpu=($(count_used_gpus $PARTATION)) # 分区中自己已使用的GPU数 + avai_gpu=$(svp list -p $PARTATION|grep $PARTATION | awk '{print $5}') # 分区中可用的GPU数 + echo -e "check partation $PARTATION \n used_gpu: $used_gpu\n avai_gpu: $avai_gpu" + + if [ $avai_gpu -gt 0 ]; then + # 提交一个任务,睡眠 + if [ $DEBUG -eq 1 ]; then + LOG_LEVEL=INFO srun -p ${PARTATION} --output=${SLURM_LOG_DIR}/logs/output_%j.out --export=ALL --cpus-per-task=${TASK_NUM} --error=${SLURM_LOG_DIR}/error/error_%j.err --gres=gpu:1 -N ${TASK_NUM} ${PYTHON} main.py --server-addr ${SERVER_ADDR} --result-save-dir ${RESULT_SAVE_DIR} + else + + LOG_LEVEL=ERROR srun -p ${PARTATION} --output=${SLURM_LOG_DIR}/logs/output_%j.out --export=ALL --cpus-per-task=${TASK_NUM} --error=${SLURM_LOG_DIR}/error/error_%j.err --gres=gpu:1 --async -N ${TASK_NUM} ${PYTHON} main.py --server-addr ${SERVER_ADDR} --result-save-dir ${RESULT_SAVE_DIR} + fi + # TODO 判断任务是否提交成功 + submited_job_num=$((submited_job_num+1)) + sleep 2 + echo "use ${PARTATION} submit job succ, submit next job now..." + rm batchscript* 2>/dev/null + else + echo "skip ${PARTATION}, used_GPU = ${used_gpu}, no available GPU" + sleep 2 + fi + +done # while + +echo "任务提交完成" diff --git a/llm_web_kit/html_layout_classify/main.py b/llm_web_kit/html_layout_classify/main.py index fd306c78..59fe046f 100644 --- a/llm_web_kit/html_layout_classify/main.py +++ b/llm_web_kit/html_layout_classify/main.py @@ -1,114 +1,137 @@ -import argparse import json - +import os +import socket +import time +from io import BytesIO +from pathlib import Path + +import click +import requests from loguru import logger +from retry import retry -from llm_web_kit.html_layout_classify.s3.client import list_s3_objects -from llm_web_kit.html_layout_classify.s3.read import read_s3_rows -from llm_web_kit.html_layout_classify.s3.write import S3DocWriter from llm_web_kit.model.html_layout_cls import HTMLLayoutClassifier CLASSIFY_MAP = {'other': 0, 'article': 1, 'forum': 2} INT_CLASSIFY_MAP = {0: 'other', 1: 'article', 2: 'forum'} MODEL_VERESION = '0.0.2' - - -def __list_layout_sample_dir(s3_dir: str) -> list: - """列出所有的layout sample json文件.""" - if s3_dir.endswith('/'): - layout_sample_files = [f for f in list(list_s3_objects(s3_dir, recursive=True)) if f.endswith('.jsonl')] - return layout_sample_files - return [s3_dir] - - -def __parse_predict_res(predict_res: list, layout_samples: list) -> int: - """解析模型分类结果.""" - # [{'pred_prob': '0.626298', 'pred_label': 'other'}] - res = { - 'url_list': [i['url'] for i in layout_samples], - 'layout_id': layout_samples[0]['layout_id'], - 'page_type': INT_CLASSIFY_MAP.get( - __most_frequent_or_zero([CLASSIFY_MAP.get(i['pred_label'], 0) for i in predict_res]), 'other'), - 'max_pred_prod': max([i['pred_prob'] for i in predict_res]), - 'version': MODEL_VERESION, - } - return res - - -def __most_frequent_or_zero(int_elements): - """计算分类结果最多的类型,否则为0.""" - if not int_elements: - return 0 - - elif len(int_elements) == 1: - return int_elements[0] - - elif len(int_elements) == 2: - return int_elements[0] if int_elements[0] == int_elements[1] else 0 - - elif len(int_elements) == 3: - if int_elements[0] == int_elements[1] or int_elements[0] == int_elements[2]: - return int_elements[0] - elif int_elements[1] == int_elements[2]: - return int_elements[1] - else: - return 0 - else: - logger.error(f'most_frequent_or_zero error:{int_elements}') - - -def __process_one_layout_sample(layout_sample_file: str, layout_type_dir: str): - """处理一个layout的代表群体.""" - output_file_path = f"{layout_type_dir}{layout_sample_file.split('/')[-1]}" - writer = S3DocWriter(output_file_path) - - def __get_type_by_layoutid(layout_samples: list): - # html_str_input = [general_simplify_html_str(html['html_source']) for html in layout_samples] - html_str_input = [html['simp_html'] for html in layout_samples] - layout_classify_lst = model.predict(html_str_input) - layout_classify = __parse_predict_res(layout_classify_lst, layout_samples) - return layout_classify - - current_layout_id, samples = None, [] - idx = 0 - for row in read_s3_rows(layout_sample_file): - idx += 1 - detail_data = json.loads(row.value) - if current_layout_id == detail_data['layout_id']: - samples.append(detail_data) - else: - if samples: - classify_res = __get_type_by_layoutid(samples) - writer.write(classify_res) - current_layout_id, samples = detail_data['layout_id'], [detail_data] - if samples: - classify_res = __get_type_by_layoutid(samples) - writer.write(classify_res) - writer.flush() - logger.info(f'read {layout_sample_file} file {idx} rows') - - -def __set_config(): - global model - model = HTMLLayoutClassifier() - - -def main(): - parser = argparse.ArgumentParser(description='Process files with specified function.') - parser.add_argument('layout_sample_dir', help='待分类文件夹路径或文件路径') - parser.add_argument('layout_classify_dir', help='已分类结果输出路径') - - args = parser.parse_args() - - try: - # 加载模型 - __set_config() - layout_sample_files = __list_layout_sample_dir(args.layout_sample_dir) - # 读取每个json文件的数据,根据每个layout_id为一簇,计算每个layout_id 对应的 layout_classify,并将结果写入s3 - for layout_sample_file in layout_sample_files: - __process_one_layout_sample(layout_sample_file, args.layout_classify_dir) - except Exception as e: - logger.error(f'get layout classify fail: {e}') +MODEL = None +GET_FILE_URL = None +UPDATE_STATUS_URL = None + + +def __get_runtime_id(): + # 获取 hostname + hostname = socket.gethostname() + job_id = os.environ.get('SLURM_JOB_ID', 'unknown') + return f'{hostname}_{job_id}' + + +def __read_sample_of_layout_id(to_process_file_path): + """读取to_process_file_path路径的文件,并根据layout_id进行读取,每个layout_id为一组. + + 每次yield一个list, 列表中是layout_id对应的layout_samples. 读取直到layout_id发生变化. + """ + cur_layout_id = None + cur_layout_samples = [] + with open(to_process_file_path, 'r') as f: + for line in f: + data = json.loads(line) + if data['layout_id'] == cur_layout_id: + cur_layout_samples.append(data) + else: + yield cur_layout_samples + cur_layout_id = data['layout_id'] + cur_layout_samples = [data] + yield cur_layout_samples + + +def __do_page_classify(samples:list) -> tuple[str, float]: + """对samples进行分类,返回分类结果标签和最大概率.""" + if len(samples) < 3: + logger.error(f"samples of layout_id {samples[0]['layout_id']} is less than 3") + # 进行2次分类 + html_str_inputs = [html['simp_html'] for html in samples] + classify_res_top_2 = MODEL.predict(html_str_inputs[0:2]) + # 如果分类结果一致,则直接写入结果 + if classify_res_top_2[0]['pred_label'] == classify_res_top_2[1]['pred_label']: # 如果1和2的分类结果一致,则直接返回1的分类结果 + return classify_res_top_2[0]['pred_label'], max(classify_res_top_2[0]['pred_prob'], classify_res_top_2[1]['pred_prob']) + else: # 如果分类结果不一致,则进行第三次分类 + classify_3 = MODEL.predict(html_str_inputs[3]) + if classify_3[0]['pred_label'] == classify_res_top_2[0]['pred_label']: + return classify_3[0]['pred_label'], max(classify_3[0]['pred_prob'], classify_res_top_2[0]['pred_prob']) + elif classify_3[0]['pred_label'] == classify_res_top_2[1]['pred_label']: + return classify_3[0]['pred_label'], max(classify_3[0]['pred_prob'], classify_res_top_2[1]['pred_prob']) + else: # 第三个和前两个任何一个分类结果都不一致,则把类别分到other里 + return 'other', 0.0 + + +def __process_one_layout_file(result_save_dir, to_process_file_path): + """读取to_process_file_path路径的文件,并进行分类,将结果写入result_save_dir路径 + 读取的时候,需要根据layout_id进行读取,每个layout_id为一组,每组内部进行分类。 分类的时候,先进行2次分类,如果分类结果一致,则直接 + 写入结果,如果分类结果不一致,则进行第三次分类,第三次分类的时候,需要根据前两次分类结果,进行分类。""" + result_file_path = result_save_dir + Path(to_process_file_path).name + # 检查如果result_file_path存在,则不进行处理 + if Path(result_file_path).exists(): + logger.info(f'result_file_path {result_file_path} exists, skip') + __report_status(UPDATE_STATUS_URL, to_process_file_path, 'SUCC') + return + + file_buffer = BytesIO() + for samples_of_layout_id in __read_sample_of_layout_id(to_process_file_path): + label, max_score = __do_page_classify(samples_of_layout_id) + classify_res = { + 'url_list': [i['url'] for i in samples_of_layout_id], + 'layout_id': samples_of_layout_id[0]['layout_id'], + 'page_type': label, + 'max_pred_prod': max_score, + 'version': MODEL_VERESION, + } + file_buffer.write(json.dumps(classify_res, ensure_ascii=False) + '\n') + + file_buffer.seek(0) + # 一次性写入到磁盘,降低磁盘IO + with open(result_file_path, 'w') as f: + f.write(file_buffer.getvalue()) + + file_buffer.close() + + +@retry(tries=5, delay=10, max_delay=5) +def __report_status(server_addr, file_path, status): + """更新server上的状态.""" + UPDATE_STATUS_URL = f'{server_addr}/update_status' + requests.post(UPDATE_STATUS_URL, json={'file_path': file_path, 'status': status}) + logger.info(f'report status {status} for file {file_path}') + + +@click.command() +@click.option('--result-save-dir', type=click.Path(exists=True), help='分类结果文件输出路径') +@click.option('--server-addr', type=str, help='server的地址,例如http://127.0.0.1:5000') +def main(result_save_dir: str, server_addr: str): + global GET_FILE_URL, UPDATE_STATUS_URL, MODEL + GET_FILE_URL = f'{server_addr}/get_file' + UPDATE_STATUS_URL = f'{server_addr}/update_status' + logger.info('init model') + MODEL = HTMLLayoutClassifier() + logger.info('init model done') + while True: + try: + # 获取待处理的文件路径 + logger.info(f'get layout classify file from {GET_FILE_URL}') + to_process_file_path = requests.get(GET_FILE_URL).json()['file_path'] + logger.info(f'get layout classify file: {to_process_file_path}') + if not to_process_file_path: + logger.info('no file to process, sleep 10s') + time.sleep(10) + continue + # 处理文件 + __process_one_layout_file(result_save_dir, to_process_file_path) + # 更新状态 + __report_status(UPDATE_STATUS_URL, to_process_file_path, 'SUCC') + except Exception as e: + logger.error(f'get layout classify fail: {e}') + time.sleep(1) if __name__ == '__main__': diff --git a/llm_web_kit/html_layout_classify/html_layout_classify.md b/llm_web_kit/html_layout_classify/readme.md similarity index 62% rename from llm_web_kit/html_layout_classify/html_layout_classify.md rename to llm_web_kit/html_layout_classify/readme.md index 2a7595a5..6ee52469 100644 --- a/llm_web_kit/html_layout_classify/html_layout_classify.md +++ b/llm_web_kit/html_layout_classify/readme.md @@ -3,14 +3,12 @@ ## 环境 配置 .xinghe.yaml + 配置 .llm_web_kit.jsonc ## 入参 -layout_sample_dir: 每个layout_id 随机选取3条的.jsonl文件路径或文件夹路径 -layout_classify_dir:计算每个layout_id 对应的分类结果文件夹路径 - -layout_sample_dir 字段说明: +layout_sample_dir: 一个本地的目录,内含多个jsonl文件,每个文件的结构如下: | 字段 | 类型 | 描述 | 是否必须 | | --------- | ------ | ---------------------------- | -------- | @@ -18,7 +16,7 @@ layout_sample_dir 字段说明: | url | string | 数据url | 是 | | simp_html | string | html原数据经过简化处理的html | 是 | -layout_classify_dir 字段说明: +layout_classify_dir:分类结果的保存目录。输出的jsonl文件,每个文件的结构如下: | 字段 | 类型 | 描述 | 是否必须 | | ------------- | ------ | --------------------------------------------------------------- | -------- | @@ -27,3 +25,15 @@ layout_classify_dir 字段说明: | page_type | string | layout_id 经过分类之后的分类结果('other', 'article', 'forum') | 是 | | max_pred_prod | float | 分类模型的分类可靠度 | 是 | | version | string | 模型版本 | 是 | + +## 执行步骤 + +1. 执行server.py,启动服务,此服务提供2个接口: + +- /get_file:获取待分类的文件路径,每次一个,如果队列中没有文件,则返回空 +- /update_status:更新文件分类状态 +- /index:一个简单的web界面,可以查看当前的分类进度 + +2. 执行classify.sh,此脚本会调用server.py的/get_file接口获取待分类的文件,然后进行分类,并调用server.py的/update_status接口更新文件分类状态。 + +3. 执行classify-spot.sh ,可以利用spot资源。 diff --git a/llm_web_kit/html_layout_classify/server.py b/llm_web_kit/html_layout_classify/server.py new file mode 100644 index 00000000..c96df82d --- /dev/null +++ b/llm_web_kit/html_layout_classify/server.py @@ -0,0 +1,250 @@ +""" +实现一个flask的server,实现: +1. 启动的时候,接受一个命令行参数 --layout_sample_dir,表示layout_sample_dir路径。扫描这个路径下所有.jsonl文件,保存他们的绝对路径。把这些路径放到一个队列里。 +2. 实现一个http get接口,每次返回队列里的一个路径,并从队列里删除该路径。被删除的路径保存到另外一个dict里,value是个当时的时间start_tm。用来记录未来处理是否成功。 +3. 实现一个http post接口,接受一个路径,和对这个路径的处理结果SUCC|FAIL, 和一条msg。把这3条信息存到dict的vlaue里,加上end_tm。 +4. 实现一个http get,返还一个html表格。 显示queue里总路径,dict里的路径,和处理结果。显示处理总进度=dict里处理成功的路径/queue里总路径+dict里总的路径。 + +""" +import json +import os +import sys +from collections import deque +from datetime import datetime +from pathlib import Path + +import click +from flask import Flask, jsonify, render_template_string, request +from loguru import logger + +app = Flask(__name__) + +# Global variables +file_queue = deque() +processed_files = {} +total_files = 0 + +# Queue persistence file path +QUEUE_FILE = os.path.expanduser('~/.page_classify_queue') + +# HTML template for status page +HTML_TEMPLATE = """ + + + + Processing Status + + + +

Processing Status

+

Progress: {{ progress }}%

+

Queue Status:

+

Files remaining in queue: {{ queue_length }}

+

Files currently processing: {{ processing_count }}

+

Processed Files:

+ + + + + + + + + + {% for path, info in processed_files.items() %} + + + + + + + + + {% endfor %} +
File PathStatusMessageStart TimeEnd TimeDuration
{{ path }}{{ info.get('status', '') }}{{ info.get('msg', '') }}{{ info.get('start_tm', '') }}{{ info.get('end_tm', '') }}{{ info.get('duration', '') }}
+ + + +""" + + +def load_processed_files(): + """Load processing files from persistence file.""" + global processed_files + if os.path.exists(QUEUE_FILE): + with open(QUEUE_FILE, 'r') as f: + saved_files = json.load(f) + # Only load files that are still in PROCESSING status + processed_files = { + path: info for path, info in saved_files.items() + if info.get('status') == 'PROCESSING' + } + + +def save_processed_files(): + """Save processed files to persistence file.""" + with open(QUEUE_FILE, 'w') as f: + json.dump(processed_files, f) + + +def __init_queue(layout_sample_dir): + """Initialize queue with .jsonl files from the given directory.""" + global file_queue, total_files + + layout_dir = Path(layout_sample_dir) + if not layout_dir.exists(): + print(f'Error: Directory {layout_sample_dir} does not exist') + sys.exit(1) + + # Load processed files first to exclude processing files + load_processed_files() + + # Get set of files currently being processed + processing_files = {path for path, info in processed_files.items() + if info.get('status') == 'PROCESSING'} + + layout_dir = Path(layout_sample_dir) + for file_path in layout_dir.rglob('*.jsonl'): + file_path_str = str(file_path) + # Only add files that are not currently being processed + if file_path_str not in processing_files: + file_queue.append(file_path_str) + + total_files = len(file_queue) + + +@app.route('/get_file', methods=['GET']) +def get_file(): + global file_queue, processed_files + # Check for timed out files before getting next file + current_time = datetime.now() + timed_out_files = [] + for file_path, info in processed_files.items(): + if info['status'] == 'PROCESSING': + start_time = datetime.strptime(info['start_tm'], '%Y-%m-%d %H:%M:%S') + duration = (current_time - start_time).total_seconds() + if duration > app.config['TIMEOUT']: + logger.info(f'File {file_path} timed out, adding back to queue') + timed_out_files.append(file_path) # 超时的文件,重新加入队列 + file_queue.append(file_path) + + # Remove timed out files from processed_files + for file_path in timed_out_files: + del processed_files[file_path] + + try: + file_path = file_queue.popleft() + except IndexError: + logger.error('No more files in queue') + return jsonify({'file_path': ''}) + + processed_files[file_path] = { + 'start_tm': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), + 'status': 'PROCESSING' + } + + # Save updated processed files + save_processed_files() + + logger.info(f'get layout classify file: {file_path}') + return jsonify({'file_path': file_path}) + + +@app.route('/update_status///', methods=['POST']) +def update_status(file_path, status, msg): + """Update processing status for a file.""" + if file_path not in processed_files: + return jsonify({'error': 'File not found in processed list'}) + + end_time = datetime.now() + start_time = datetime.strptime(processed_files[file_path]['start_tm'], '%Y-%m-%d %H:%M:%S') + duration = end_time - start_time + + processed_files[file_path].update({ + 'status': status, + 'msg': msg, + 'end_tm': end_time.strftime('%Y-%m-%d %H:%M:%S'), + 'duration': str(duration) + }) + + # Save updated processed files + save_processed_files() + logger.info(f'update layout classify status: {file_path} {status} {msg}') + return jsonify({'status': 'success'}) + + +@app.route('/index', methods=['GET']) +def index(): + """Get processing status page.""" + success_count = sum(1 for info in processed_files.values() + if info.get('status') == 'SUCC') + processing_count = sum(1 for info in processed_files.values() + if info.get('status') == 'PROCESSING') + error_count = sum(1 for info in processed_files.values() + if info.get('status') == 'FAIL') + total = total_files + progress = (success_count / total * 100) if total > 0 else 0 + + # Get page parameter from request, default to 1 + page = request.args.get('page', 1, type=int) + per_page = 50 + + # Get paginated list of processed files + items = list(processed_files.items()) + total_pages = (len(items) + per_page - 1) // per_page + start = (page - 1) * per_page + end = start + per_page + current_items = dict(items[start:end]) + + return render_template_string( + HTML_TEMPLATE, + queue_length=len(file_queue), + processed_files=current_items, + progress=round(progress, 2), + page=page, + total_pages=total_pages, + processing_count=processing_count, + error_count=error_count + ) + + +@click.command() +@click.option('--layout_sample_dir', required=True, help='Directory containing layout sample files') +@click.option('--port', default=5000, help='Port to run the server on') +@click.option('--host', default='0.0.0.0', help='Host IP to run the server on') +@click.option('--timeout', default=10, help='timeout to process one file') +def run_server(layout_sample_dir, port, host, timeout): + """Initialize and run the server.""" + __init_queue(layout_sample_dir) + app.config['TIMEOUT'] = timeout + app.run(host=host, port=port) + + +if __name__ == '__main__': + run_server() diff --git a/requirements/dev.txt b/requirements/dev.txt index 9f379ead..19923b09 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -1,6 +1,9 @@ +flask==3.0.2 # for html_layout_classify pre-commit==3.8.0 pydantic==2.10.6 pytest==8.3.3 # coverage tools pytest-cov==6.0.0 pytest-xdist==3.6.1 +requests==2.31.0 # for html_layout_classify +retry==0.9.2 # for html_layout_classify From 3f210f6a4b9622e40eef6763f6c63de59ff7f6bc Mon Sep 17 00:00:00 2001 From: drunkpig Date: Tue, 18 Mar 2025 14:02:35 +0800 Subject: [PATCH 2/4] feat: classify page by html layout use GPU --- llm_web_kit/extractor/html/extractor.py | 3 +- .../html_layout_classify/classify-spot.sh | 6 +- llm_web_kit/html_layout_classify/classify.sh | 5 +- llm_web_kit/html_layout_classify/main.py | 39 +++--- llm_web_kit/html_layout_classify/readme.md | 24 +++- llm_web_kit/html_layout_classify/server.py | 118 +++++++++++------- 6 files changed, 124 insertions(+), 71 deletions(-) diff --git a/llm_web_kit/extractor/html/extractor.py b/llm_web_kit/extractor/html/extractor.py index a3d4a5f6..1194b71c 100644 --- a/llm_web_kit/extractor/html/extractor.py +++ b/llm_web_kit/extractor/html/extractor.py @@ -103,7 +103,7 @@ def _do_extract(self, data_json: DataJson) -> DataJson: return data_json - def _extract_main_html(self, raw_html:str, base_url:str, page_layout_type:str) -> (str, str): + def _extract_main_html(self, raw_html:str, base_url:str, page_layout_type:str) -> Tuple[str, str]: """从html文本中提取主要的内容. Args: @@ -126,7 +126,6 @@ def _extract_code(self, base_url:str, html_lst:List[Tuple[str,str]], raw_html:st base_url (str): html文本的网页地址 html_lst (List[Tuple[str,str]]): html文本 raw_html (str): html文本 - Returns: """ diff --git a/llm_web_kit/html_layout_classify/classify-spot.sh b/llm_web_kit/html_layout_classify/classify-spot.sh index b679091a..5fbd0758 100755 --- a/llm_web_kit/html_layout_classify/classify-spot.sh +++ b/llm_web_kit/html_layout_classify/classify-spot.sh @@ -107,7 +107,7 @@ mkdir -p ${SLURM_LOG_DIR}/error export SLURM_SUBMIT_DIR=${SLURM_LOG_DIR} export LLM_WEB_KIT_CFG_PATH=/share/xuchao/.llm-web-kit-pageclassify.jsonc TASK_NUM="${TASK_NUM:-1}" # Default to 1 if not provided - +DEBUG="${DEBUG:-0}" # Check required arguments if [ -z "$PARTATION" ] || [ -z "$TAG" ]; then @@ -129,9 +129,9 @@ do # total_reserved_idle=$(calculate_total_reserved_idle) # echo -e "check $partation spot \n tt:$tt \n total_spot_used: $total_spot_used\n total_reserved_idle: $total_reserved_idle \n PD_COUNT: $PD_COUNT" if [ $DEBUG -eq 1 ]; then - LOG_LEVEL=ERROR srun -p ${partation} --quotatype=spot --output=${SLURM_LOG_DIR}/logs/output_%j.out --export=ALL --cpus-per-task=${TASK_NUM} --error=${SLURM_LOG_DIR}/error/error_%j.err -N ${TASK_NUM} --gres=gpu:1 python main.py ${SERVER_ADDR} --result-save-dir ${RESULT_SAVE_DIR} + LOG_LEVEL=ERROR srun -p ${partation} --quotatype=spot --output=${SLURM_LOG_DIR}/logs/output_%j.out --export=ALL --error=${SLURM_LOG_DIR}/error/error_%j.err -N 1 -n${TASK_NUM} --gres=gpu:1 python main.py ${SERVER_ADDR} --result-save-dir ${RESULT_SAVE_DIR} else - LOG_LEVEL=ERROR srun -p ${partation} --quotatype=spot --output=${SLURM_LOG_DIR}/logs/output_%j.out --export=ALL --cpus-per-task=${TASK_NUM} --error=${SLURM_LOG_DIR}/error/error_%j.err -N ${TASK_NUM} --gres=gpu:1 --async python main.py ${SERVER_ADDR} --result-save-dir ${RESULT_SAVE_DIR} + LOG_LEVEL=ERROR srun -p ${partation} --quotatype=spot --output=${SLURM_LOG_DIR}/logs/output_%j.out --export=ALL --error=${SLURM_LOG_DIR}/error/error_%j.err -N 1 -n ${TASK_NUM} --gres=gpu:1 --async python main.py ${SERVER_ADDR} --result-save-dir ${RESULT_SAVE_DIR} fi echo "use ${partation} submit job succ, submit next job now..." rm batchscript* 2>/dev/null diff --git a/llm_web_kit/html_layout_classify/classify.sh b/llm_web_kit/html_layout_classify/classify.sh index 32a98156..e59b2e75 100755 --- a/llm_web_kit/html_layout_classify/classify.sh +++ b/llm_web_kit/html_layout_classify/classify.sh @@ -65,6 +65,7 @@ mkdir -p ${SLURM_LOG_DIR}/error export SLURM_SUBMIT_DIR=${SLURM_LOG_DIR} export LLM_WEB_KIT_CFG_PATH=/share/${MY_NAME}/.llm-web-kit-pageclassify.jsonc TASK_NUM="${TASK_NUM:-1}" # Default to 1 if not provided +DEBUG="${DEBUG:-0}" SERVER_ADDR="${SERVER_ADDR:-http://127.0.0.1:5000}" PYTHON=/share/${MY_NAME}/.conda/envs/webkitdev/bin/python @@ -87,10 +88,10 @@ do if [ $avai_gpu -gt 0 ]; then # 提交一个任务,睡眠 if [ $DEBUG -eq 1 ]; then - LOG_LEVEL=INFO srun -p ${PARTATION} --output=${SLURM_LOG_DIR}/logs/output_%j.out --export=ALL --cpus-per-task=${TASK_NUM} --error=${SLURM_LOG_DIR}/error/error_%j.err --gres=gpu:1 -N ${TASK_NUM} ${PYTHON} main.py --server-addr ${SERVER_ADDR} --result-save-dir ${RESULT_SAVE_DIR} + LOG_LEVEL=INFO srun -p ${PARTATION} --output=${SLURM_LOG_DIR}/logs/output_%j.out --export=ALL --error=${SLURM_LOG_DIR}/error/error_%j.err --gres=gpu:1 -N 1 -n ${TASK_NUM} ${PYTHON} main.py --server-addr ${SERVER_ADDR} --result-save-dir ${RESULT_SAVE_DIR} else - LOG_LEVEL=ERROR srun -p ${PARTATION} --output=${SLURM_LOG_DIR}/logs/output_%j.out --export=ALL --cpus-per-task=${TASK_NUM} --error=${SLURM_LOG_DIR}/error/error_%j.err --gres=gpu:1 --async -N ${TASK_NUM} ${PYTHON} main.py --server-addr ${SERVER_ADDR} --result-save-dir ${RESULT_SAVE_DIR} + LOG_LEVEL=ERROR srun -p ${PARTATION} --output=${SLURM_LOG_DIR}/logs/output_%j.out --export=ALL --error=${SLURM_LOG_DIR}/error/error_%j.err --gres=gpu:1 --async -N 1 -n ${TASK_NUM} ${PYTHON} main.py --server-addr ${SERVER_ADDR} --result-save-dir ${RESULT_SAVE_DIR} fi # TODO 判断任务是否提交成功 submited_job_num=$((submited_job_num+1)) diff --git a/llm_web_kit/html_layout_classify/main.py b/llm_web_kit/html_layout_classify/main.py index 59fe046f..ff9a1d8c 100644 --- a/llm_web_kit/html_layout_classify/main.py +++ b/llm_web_kit/html_layout_classify/main.py @@ -37,8 +37,9 @@ def __read_sample_of_layout_id(to_process_file_path): with open(to_process_file_path, 'r') as f: for line in f: data = json.loads(line) - if data['layout_id'] == cur_layout_id: + if data['layout_id'] == cur_layout_id or cur_layout_id is None: cur_layout_samples.append(data) + cur_layout_id = data['layout_id'] else: yield cur_layout_samples cur_layout_id = data['layout_id'] @@ -48,8 +49,10 @@ def __read_sample_of_layout_id(to_process_file_path): def __do_page_classify(samples:list) -> tuple[str, float]: """对samples进行分类,返回分类结果标签和最大概率.""" - if len(samples) < 3: - logger.error(f"samples of layout_id {samples[0]['layout_id']} is less than 3") + if len(samples) <= 1: + logger.error(f"samples of layout_id {samples[0]['layout_id']} is less than 1 or empty") + return 'other', 0.0 + # 进行2次分类 html_str_inputs = [html['simp_html'] for html in samples] classify_res_top_2 = MODEL.predict(html_str_inputs[0:2]) @@ -57,12 +60,15 @@ def __do_page_classify(samples:list) -> tuple[str, float]: if classify_res_top_2[0]['pred_label'] == classify_res_top_2[1]['pred_label']: # 如果1和2的分类结果一致,则直接返回1的分类结果 return classify_res_top_2[0]['pred_label'], max(classify_res_top_2[0]['pred_prob'], classify_res_top_2[1]['pred_prob']) else: # 如果分类结果不一致,则进行第三次分类 - classify_3 = MODEL.predict(html_str_inputs[3]) - if classify_3[0]['pred_label'] == classify_res_top_2[0]['pred_label']: - return classify_3[0]['pred_label'], max(classify_3[0]['pred_prob'], classify_res_top_2[0]['pred_prob']) - elif classify_3[0]['pred_label'] == classify_res_top_2[1]['pred_label']: - return classify_3[0]['pred_label'], max(classify_3[0]['pred_prob'], classify_res_top_2[1]['pred_prob']) - else: # 第三个和前两个任何一个分类结果都不一致,则把类别分到other里 + if len(samples) > 2: + classify_3 = MODEL.predict([html_str_inputs[2]]) + if classify_3[0]['pred_label'] == classify_res_top_2[0]['pred_label']: + return classify_3[0]['pred_label'], max(classify_3[0]['pred_prob'], classify_res_top_2[0]['pred_prob']) + elif classify_3[0]['pred_label'] == classify_res_top_2[1]['pred_label']: + return classify_3[0]['pred_label'], max(classify_3[0]['pred_prob'], classify_res_top_2[1]['pred_prob']) + else: # 第三个和前两个任何一个分类结果都不一致,则把类别分到other里 + return 'other', 0.0 + else: return 'other', 0.0 @@ -70,7 +76,7 @@ def __process_one_layout_file(result_save_dir, to_process_file_path): """读取to_process_file_path路径的文件,并进行分类,将结果写入result_save_dir路径 读取的时候,需要根据layout_id进行读取,每个layout_id为一组,每组内部进行分类。 分类的时候,先进行2次分类,如果分类结果一致,则直接 写入结果,如果分类结果不一致,则进行第三次分类,第三次分类的时候,需要根据前两次分类结果,进行分类。""" - result_file_path = result_save_dir + Path(to_process_file_path).name + result_file_path = os.path.join(result_save_dir , Path(to_process_file_path).name) # 检查如果result_file_path存在,则不进行处理 if Path(result_file_path).exists(): logger.info(f'result_file_path {result_file_path} exists, skip') @@ -87,21 +93,21 @@ def __process_one_layout_file(result_save_dir, to_process_file_path): 'max_pred_prod': max_score, 'version': MODEL_VERESION, } - file_buffer.write(json.dumps(classify_res, ensure_ascii=False) + '\n') + logger.info(f"{samples_of_layout_id[0]['layout_id']}, {label}, {max_score}, {samples_of_layout_id[0]['url']}") + file_buffer.write(json.dumps(classify_res, ensure_ascii=False).encode('utf-8') + b'\n') file_buffer.seek(0) # 一次性写入到磁盘,降低磁盘IO - with open(result_file_path, 'w') as f: + with open(result_file_path, 'wb') as f: f.write(file_buffer.getvalue()) - + logger.info(f'finished process {to_process_file_path}, write result to {result_file_path}') file_buffer.close() @retry(tries=5, delay=10, max_delay=5) -def __report_status(server_addr, file_path, status): +def __report_status(server_url, file_path, status, msg=''): """更新server上的状态.""" - UPDATE_STATUS_URL = f'{server_addr}/update_status' - requests.post(UPDATE_STATUS_URL, json={'file_path': file_path, 'status': status}) + requests.post(server_url, json={'file_path': file_path, 'status': status, 'msg': msg}) logger.info(f'report status {status} for file {file_path}') @@ -131,6 +137,7 @@ def main(result_save_dir: str, server_addr: str): __report_status(UPDATE_STATUS_URL, to_process_file_path, 'SUCC') except Exception as e: logger.error(f'get layout classify fail: {e}') + logger.exception(e) time.sleep(1) diff --git a/llm_web_kit/html_layout_classify/readme.md b/llm_web_kit/html_layout_classify/readme.md index 6ee52469..a278a2a3 100644 --- a/llm_web_kit/html_layout_classify/readme.md +++ b/llm_web_kit/html_layout_classify/readme.md @@ -30,10 +30,24 @@ layout_classify_dir:分类结果的保存目录。输出的jsonl文件,每 1. 执行server.py,启动服务,此服务提供2个接口: -- /get_file:获取待分类的文件路径,每次一个,如果队列中没有文件,则返回空 -- /update_status:更新文件分类状态 -- /index:一个简单的web界面,可以查看当前的分类进度 - -2. 执行classify.sh,此脚本会调用server.py的/get_file接口获取待分类的文件,然后进行分类,并调用server.py的/update_status接口更新文件分类状态。 + - /get_file:获取待分类的文件路径,每次一个,如果队列中没有文件,则返回空 + - /update_status:更新文件分类状态 + - /index:一个简单的web界面,可以查看当前的分类进度 + - 启动参数为: + - --layout_sample_dir:layout样本的保存目录,这里面每个文件会被server分发出去。 + - --port:服务端口 + - --host:服务地址 + - --timeout:客户端处理一个文件的超时时间,如果超时会被重新分配。 + - --reset:是否重置。会清空当前的分类状态,不保存重启前的任务状态。 + +2. 执行classify.sh,此脚本会向slurm集群提交任务。这些任务常驻GPU,每个任务调用server.py的/get_file接口获取待分类的文件,然后进行分类,并调用server.py的/update_status接口更新文件分类状态。 + + - --partation:slurm的partation,例如:xinghe-gpu + - --max-job:最大提交任务数 + - --tag:slurm的tag,例如:html_layout_classify,用于同一个管理节点启动区分不同的任务隔离开日志的输出 + - --task-num:每个GPU上开启多少个任务实例,为了充分提高GPU的利用率 + - --debug:是否开启debug模式 + - --result-save-dir:分类结果的保存目录 + - --server-addr:server的地址,例如:http://127.0.0.1:5000 3. 执行classify-spot.sh ,可以利用spot资源。 diff --git a/llm_web_kit/html_layout_classify/server.py b/llm_web_kit/html_layout_classify/server.py index c96df82d..70e31bc5 100644 --- a/llm_web_kit/html_layout_classify/server.py +++ b/llm_web_kit/html_layout_classify/server.py @@ -8,10 +8,12 @@ """ import json import os +import queue import sys -from collections import deque from datetime import datetime from pathlib import Path +from queue import Queue +from threading import Lock import click from flask import Flask, jsonify, render_template_string, request @@ -20,9 +22,13 @@ app = Flask(__name__) # Global variables -file_queue = deque() + +file_queue = Queue() processed_files = {} total_files = 0 +processed_files_lock = Lock() +succ_count = 0 # 处理成功的计数 + # Queue persistence file path QUEUE_FILE = os.path.expanduser('~/.page_classify_queue') @@ -55,6 +61,7 @@

Processing Status

Progress: {{ progress }}%

Queue Status:

+

Files succ processing: {{ succ_count }}

Files remaining in queue: {{ queue_length }}

Files currently processing: {{ processing_count }}

Processed Files:

@@ -96,7 +103,7 @@ def load_processed_files(): """Load processing files from persistence file.""" - global processed_files + global processed_files, succ_count if os.path.exists(QUEUE_FILE): with open(QUEUE_FILE, 'r') as f: saved_files = json.load(f) @@ -105,6 +112,15 @@ def load_processed_files(): path: info for path, info in saved_files.items() if info.get('status') == 'PROCESSING' } + for _, info in saved_files.items(): + if info.get('status') == 'SUCC': + succ_count += 1 + + +def clear_processed_files(): + """Clear processed files from persistence file.""" + if os.path.exists(QUEUE_FILE): + os.remove(QUEUE_FILE) def save_processed_files(): @@ -113,7 +129,7 @@ def save_processed_files(): json.dump(processed_files, f) -def __init_queue(layout_sample_dir): +def __init_queue(layout_sample_dir, reset): """Initialize queue with .jsonl files from the given directory.""" global file_queue, total_files @@ -123,6 +139,8 @@ def __init_queue(layout_sample_dir): sys.exit(1) # Load processed files first to exclude processing files + if reset: + clear_processed_files() load_processed_files() # Get set of files currently being processed @@ -134,50 +152,59 @@ def __init_queue(layout_sample_dir): file_path_str = str(file_path) # Only add files that are not currently being processed if file_path_str not in processing_files: - file_queue.append(file_path_str) + file_queue.put(file_path_str) - total_files = len(file_queue) + total_files = file_queue.qsize() + + +# Add lock as a global variable at module level @app.route('/get_file', methods=['GET']) def get_file(): - global file_queue, processed_files - # Check for timed out files before getting next file - current_time = datetime.now() - timed_out_files = [] - for file_path, info in processed_files.items(): - if info['status'] == 'PROCESSING': - start_time = datetime.strptime(info['start_tm'], '%Y-%m-%d %H:%M:%S') - duration = (current_time - start_time).total_seconds() - if duration > app.config['TIMEOUT']: - logger.info(f'File {file_path} timed out, adding back to queue') - timed_out_files.append(file_path) # 超时的文件,重新加入队列 - file_queue.append(file_path) - - # Remove timed out files from processed_files - for file_path in timed_out_files: - del processed_files[file_path] - - try: - file_path = file_queue.popleft() - except IndexError: - logger.error('No more files in queue') - return jsonify({'file_path': ''}) - - processed_files[file_path] = { - 'start_tm': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), - 'status': 'PROCESSING' - } + global file_queue, processed_files, processed_files_lock + + with processed_files_lock: + # Check for timed out files before getting next file + current_time = datetime.now() + timed_out_files = [] + for file_path, info in processed_files.items(): + if info['status'] == 'PROCESSING': + start_time = datetime.strptime(info['start_tm'], '%Y-%m-%d %H:%M:%S') + duration = (current_time - start_time).total_seconds() + if duration > app.config['TIMEOUT']: + logger.info(f'File {file_path} timed out, adding back to queue') + timed_out_files.append(file_path) # 超时的文件,重新加入队列 + file_queue.put(file_path) + + # Remove timed out files from processed_files + for file_path in timed_out_files: + del processed_files[file_path] + + try: + file_path = file_queue.get(block=False) + except queue.Empty: # queue.get() raises queue.Empty when empty, not IndexError + logger.error('No more files in queue') + return jsonify({'file_path': ''}) + + processed_files[file_path] = { + 'start_tm': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), + 'status': 'PROCESSING' + } - # Save updated processed files - save_processed_files() + # Save updated processed files + save_processed_files() - logger.info(f'get layout classify file: {file_path}') - return jsonify({'file_path': file_path}) + logger.info(f'get layout classify file: {file_path}') + return jsonify({'file_path': file_path}) -@app.route('/update_status///', methods=['POST']) -def update_status(file_path, status, msg): +@app.route('/update_status', methods=['POST']) +def update_status(): + data = request.get_json() + file_path = data['file_path'] + status = data['status'] + msg = data.get('msg', '') # Optional message parameter """Update processing status for a file.""" if file_path not in processed_files: return jsonify({'error': 'File not found in processed list'}) @@ -202,12 +229,15 @@ def update_status(file_path, status, msg): @app.route('/index', methods=['GET']) def index(): """Get processing status page.""" + global succ_count success_count = sum(1 for info in processed_files.values() if info.get('status') == 'SUCC') processing_count = sum(1 for info in processed_files.values() if info.get('status') == 'PROCESSING') error_count = sum(1 for info in processed_files.values() if info.get('status') == 'FAIL') + _succ_count = succ_count + sum(1 for info in processed_files.values() + if info.get('status') == 'SUCC') total = total_files progress = (success_count / total * 100) if total > 0 else 0 @@ -224,13 +254,14 @@ def index(): return render_template_string( HTML_TEMPLATE, - queue_length=len(file_queue), + queue_length=file_queue.qsize(), processed_files=current_items, progress=round(progress, 2), page=page, total_pages=total_pages, processing_count=processing_count, - error_count=error_count + error_count=error_count, + succ_count=_succ_count ) @@ -239,9 +270,10 @@ def index(): @click.option('--port', default=5000, help='Port to run the server on') @click.option('--host', default='0.0.0.0', help='Host IP to run the server on') @click.option('--timeout', default=10, help='timeout to process one file') -def run_server(layout_sample_dir, port, host, timeout): +@click.option('--reset', is_flag=True, default=False, help='Reset cached files') +def run_server(layout_sample_dir, port, host, timeout, reset): """Initialize and run the server.""" - __init_queue(layout_sample_dir) + __init_queue(layout_sample_dir, reset) app.config['TIMEOUT'] = timeout app.run(host=host, port=port) From efda13ed67b749c5f1876d0f9c75eafbc4e97400 Mon Sep 17 00:00:00 2001 From: drunkpig Date: Tue, 18 Mar 2025 15:14:24 +0800 Subject: [PATCH 3/4] feat: raise CleanModelUnsupportedLanguageException in clean module --- llm_web_kit/exception/exception.jsonc | 4 ++++ llm_web_kit/exception/exception.py | 13 +++++++++++++ llm_web_kit/model/quality_model.py | 9 ++++----- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/llm_web_kit/exception/exception.jsonc b/llm_web_kit/exception/exception.jsonc index 24600d60..c72a62f2 100644 --- a/llm_web_kit/exception/exception.jsonc +++ b/llm_web_kit/exception/exception.jsonc @@ -142,6 +142,10 @@ "CleanModelException": { "code": 46000000, "message": "Clean model exception" + }, + "CleanModelUnsupportedLanguageException": { + "code": 46100000, + "message": "Clean model unsupported language exception" } } } diff --git a/llm_web_kit/exception/exception.py b/llm_web_kit/exception/exception.py index c3f4f5d1..f6b92cf7 100644 --- a/llm_web_kit/exception/exception.py +++ b/llm_web_kit/exception/exception.py @@ -358,3 +358,16 @@ def __init__(self, custom_message: str | None = None, error_code: int | None = N if error_code is None: error_code = ErrorMsg.get_error_code('Model', 'CleanModelException') super().__init__(custom_message, error_code) + + +############################################################################## +# +# Model Exceptions +# +############################################################################## +class CleanModelUnsupportedLanguageException(CleanModelException): + """Exception raised for clean model unsupported language.""" + def __init__(self, custom_message: str | None = None, error_code: int | None = None): + if error_code is None: + error_code = ErrorMsg.get_error_code('Model', 'CleanModelUnsupportedLanguageException') + super().__init__(custom_message, error_code) diff --git a/llm_web_kit/model/quality_model.py b/llm_web_kit/model/quality_model.py index f6d95bd1..d0839b84 100644 --- a/llm_web_kit/model/quality_model.py +++ b/llm_web_kit/model/quality_model.py @@ -8,7 +8,8 @@ import llm_web_kit.model.basic_functions as bfuncs from llm_web_kit.config.cfg_reader import load_config -from llm_web_kit.exception.exception import ModelInputException +from llm_web_kit.exception.exception import ( + CleanModelUnsupportedLanguageException, ModelInputException) from llm_web_kit.input.datajson import DataJson from llm_web_kit.libs.logger import mylogger as logger from llm_web_kit.model.basic_functions.features import ( @@ -386,15 +387,13 @@ def filter( content_style (str): the content style of the content Raises: - TODO use custom exception instead of - ValueError: raise ValueError if the language and content_style are not supported + CleanModelUnsupportedLanguageException: raise if the language and content_style are not supported Returns: bool: True if the content should remain, False if the content should be filtered out """ if not self.check_supported(language, content_style): - # TODO move the exception to the upper level - raise ValueError( + raise CleanModelUnsupportedLanguageException( f"Unsupport language '{language}' with content_style '{content_style}'" ) else: From 4fdad12f956c48a8ea95d41626c669eac703b766 Mon Sep 17 00:00:00 2001 From: drunkpig Date: Tue, 18 Mar 2025 15:30:02 +0800 Subject: [PATCH 4/4] fix: clean model exception --- tests/llm_web_kit/model/test_quality_model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/llm_web_kit/model/test_quality_model.py b/tests/llm_web_kit/model/test_quality_model.py index 4172be95..14418909 100644 --- a/tests/llm_web_kit/model/test_quality_model.py +++ b/tests/llm_web_kit/model/test_quality_model.py @@ -4,7 +4,8 @@ from unittest import TestCase from unittest.mock import MagicMock, mock_open, patch -from llm_web_kit.exception.exception import ModelInputException # noqa: E402 +from llm_web_kit.exception.exception import ( # noqa: E402 + CleanModelUnsupportedLanguageException, ModelInputException) from llm_web_kit.model.quality_model import QualityModel # noqa: E402 from llm_web_kit.model.quality_model import get_quality_model # noqa: E402 from llm_web_kit.model.quality_model import quality_prober # noqa: E402 @@ -320,7 +321,7 @@ def test_filter_supported_low_score(self, mock_get_model): @patch.dict('llm_web_kit.model.quality_model._model_resource_map', {}, clear=True) def test_filter_unsupported_combination(self): """测试不支持的语言风格组合.""" - with self.assertRaises(ValueError) as context: + with self.assertRaises(CleanModelUnsupportedLanguageException) as context: self.filter.filter('content', 'jp', 'details', 'novel') self.assertIn(