diff --git a/llm_web_kit/html_layout_classify/__init__.py b/llm_web_kit/html_layout_classify/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/llm_web_kit/html_layout_classify/html_layout_classify.md b/llm_web_kit/html_layout_classify/html_layout_classify.md
new file mode 100644
index 00000000..2a7595a5
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/html_layout_classify.md
@@ -0,0 +1,29 @@
+# html layout classify layout分类
+
+## 环境
+
+配置 .xinghe.yaml
+配置 .llm_web_kit.jsonc
+
+## 入参
+
+layout_sample_dir: 每个layout_id 随机选取3条的.jsonl文件路径或文件夹路径
+layout_classify_dir:计算每个layout_id 对应的分类结果文件夹路径
+
+layout_sample_dir 字段说明:
+
+| 字段 | 类型 | 描述 | 是否必须 |
+| --------- | ------ | ---------------------------- | -------- |
+| layout_id | string | layout id | 是 |
+| url | string | 数据url | 是 |
+| simp_html | string | html原数据经过简化处理的html | 是 |
+
+layout_classify_dir 字段说明:
+
+| 字段 | 类型 | 描述 | 是否必须 |
+| ------------- | ------ | --------------------------------------------------------------- | -------- |
+| url_list | list | layout id 对应的url | 是 |
+| layout_id | string | layout id | 是 |
+| page_type | string | layout_id 经过分类之后的分类结果('other', 'article', 'forum') | 是 |
+| max_pred_prod | float | 分类模型的分类可靠度 | 是 |
+| version | string | 模型版本 | 是 |
diff --git a/llm_web_kit/html_layout_classify/main.py b/llm_web_kit/html_layout_classify/main.py
new file mode 100644
index 00000000..b1b4f4d8
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/main.py
@@ -0,0 +1,117 @@
+import argparse
+import json
+
+from loguru import logger
+
+from llm_web_kit.html_layout_classify.s3.client import list_s3_objects
+from llm_web_kit.html_layout_classify.s3.read import read_s3_rows
+from llm_web_kit.html_layout_classify.s3.write import S3DocWriter
+from llm_web_kit.model.html_layout_cls import HTMLLayoutClassifier
+
+CLASSIFY_MAP = {'other': 0, 'article': 1, 'forum': 2}
+INT_CLASSIFY_MAP = {0: 'other', 1: 'article', 2: 'forum'}
+MODEL_VERESION = '0.0.2'
+
+
+def __list_layout_sample_dir(s3_dir: str) -> list:
+    """List every layout-sample .jsonl file under the given s3 dir (or the single file)."""
+    if s3_dir.endswith('/'):
+        layout_sample_files = [f for f in list(list_s3_objects(s3_dir, recursive=True)) if f.endswith('.jsonl')]
+        return layout_sample_files
+    return [s3_dir]
+
+
+def __parse_predict_res(predict_res: list, layout_samples: list) -> dict:
+    """Aggregate model predictions for one layout_id group into a result row."""
+    # predict_res example: [{'pred_prob': '0.626298', 'pred_label': 'other'}]
+    res = {
+        'url_list': [i['url'] for i in layout_samples],
+        'layout_id': layout_samples[0]['layout_id'],
+        'page_type': INT_CLASSIFY_MAP.get(
+            __most_frequent_or_zero([CLASSIFY_MAP.get(i['pred_label'], 0) for i in predict_res]), 'other'),
+        # fix: pred_prob is a string — compare numerically, not lexicographically
+        'max_pred_prod': max(float(i['pred_prob']) for i in predict_res),
+        'version': MODEL_VERESION,
+    }
+    return res
+
+
+def __most_frequent_or_zero(int_elements):
+    """Return the majority element among up to 3 labels; 0 when there is no majority."""
+    if not int_elements:
+        return 0
+
+    elif len(int_elements) == 1:
+        return int_elements[0]
+
+    elif len(int_elements) == 2:
+        return int_elements[0] if int_elements[0] == int_elements[1] else 0
+
+    elif len(int_elements) == 3:
+        if int_elements[0] == int_elements[1] or int_elements[0] == int_elements[2]:
+            return int_elements[0]
+        elif int_elements[1] == int_elements[2]:
+            return int_elements[1]
+        else:
+            return 0
+    else:
+        logger.error(f"most_frequent_or_zero error:{int_elements}")
+        return 0  # fix: previously fell through returning None
+
+
+def __process_one_layout_sample(layout_sample_file: str, layout_type_dir: str):
+    """Classify one layout-sample file (rows grouped by consecutive layout_id) and write results."""
+    output_file_path = f"{layout_type_dir}{layout_sample_file.split('/')[-1]}"
+    writer = S3DocWriter(output_file_path)
+
+    def __get_type_by_layoutid(layout_samples: list):
+        # html_str_input = [general_simplify_html_str(html['html_source']) for html in layout_samples]
+        html_str_input = [html['simp_html'] for html in layout_samples]
+        layout_classify_lst = model.predict(html_str_input)
+        layout_classify = __parse_predict_res(layout_classify_lst, layout_samples)
+        return layout_classify
+
+    current_layout_id, samples = None, []
+    idx = 0
+    for row in read_s3_rows(layout_sample_file):
+        idx += 1
+        detail_data = json.loads(row.value)
+        if current_layout_id == detail_data['layout_id']:
+            samples.append(detail_data)
+        else:
+            if samples:
+                classify_res = __get_type_by_layoutid(samples)
+                writer.write(classify_res)
+            current_layout_id, samples = detail_data['layout_id'], [detail_data]
+    if samples:
+        classify_res = __get_type_by_layoutid(samples)
+        writer.write(classify_res)
+    writer.flush()
+    logger.info(f'read {layout_sample_file} file {idx} rows')
+
+
+def __set_config():
+    global model
+    model = HTMLLayoutClassifier()
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Process files with specified function.')
+    parser.add_argument('layout_sample_dir', help='待分类文件夹路径或文件路径')
+    parser.add_argument('layout_classify_dir', help='已分类结果输出路径')
+
+    args = parser.parse_args()
+
+    try:
+        # load the model once before processing any file
+        __set_config()
+        layout_sample_files = __list_layout_sample_dir(args.layout_sample_dir)
+        # one cluster per layout_id: classify each cluster and write the result to s3
+        for layout_sample_file in layout_sample_files:
+            __process_one_layout_sample(layout_sample_file, args.layout_classify_dir)
+    except Exception as e:
+        logger.error(f'get layout classify fail: {e}')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/llm_web_kit/html_layout_classify/s3/__init__.py b/llm_web_kit/html_layout_classify/s3/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/llm_web_kit/html_layout_classify/s3/client.py b/llm_web_kit/html_layout_classify/s3/client.py
new file mode 100644
index 00000000..3450ffd6
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/s3/client.py
@@ -0,0 +1,139 @@
+import time
+from typing import Dict, List, Union
+
+import boto3
+from boto3.s3.transfer import TransferConfig
+from botocore.client import Config
+from botocore.exceptions import ClientError
+
+from .conf import get_s3_config
+from .path import split_s3_path
+
+
+def get_s3_client(path: Union[str, List[str]], outside=False):
+    s3_config = get_s3_config(path, outside)
+    try:
+        return boto3.client(
+            's3',
+            aws_access_key_id=s3_config['ak'],
+            aws_secret_access_key=s3_config['sk'],
+            endpoint_url=s3_config['endpoint'],
+            config=Config(
+                s3={'addressing_style': 'path'},
+                retries={'max_attempts': 8, 'mode': 'standard'},
+                connect_timeout=600,
+                read_timeout=600,
+            ),
+        )
+    except Exception:
+        # older boto3 do not support retries.mode param.
+        return boto3.client(
+            's3',
+            aws_access_key_id=s3_config['ak'],
+            aws_secret_access_key=s3_config['sk'],
+            endpoint_url=s3_config['endpoint'],
+            config=Config(s3={'addressing_style': 'path'}, retries={'max_attempts': 8}),
+        )
+
+
+def is_s3_404_error(e: Exception):
+    if not isinstance(e, ClientError):
+        return False
+    return (
+        e.response.get('Error', {}).get('Code') in ['404', 'NoSuchKey']
+        or e.response.get('Error', {}).get('Message') == 'Not Found'
+        or e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') == 404
+    )
+
+
+def head_s3_object(path: str, raise_404=False, client=None) -> Union[Dict, None]:
+    client = client or get_s3_client(path)
+    bucket, key = split_s3_path(path)
+    try:
+        resp = client.head_object(Bucket=bucket, Key=key)
+        return resp
+    except ClientError as e:
+        if not raise_404 and is_s3_404_error(e):
+            return None
+        raise
+
+
+def _restore_and_wait(client, bucket: str, key: str, path: str):
+    while True:
+        head = client.head_object(Bucket=bucket, Key=key)
+        restore = head.get('Restore', '')
+        if not restore:
+            req = {'Days': 1, 'GlacierJobParameters': {'Tier': 'Standard'}}
+            client.restore_object(Bucket=bucket, Key=key, RestoreRequest=req)
+            print(f"restoration-started: {path}")
+        elif 'ongoing-request="true"' in restore:
+            print(f"restoration-ongoing: {path}")
+        elif 'ongoing-request="false"' in restore:
+            print(f"restoration-complete: {path}")
+            break
+        time.sleep(3)
+
+
+def get_s3_object(path: str, client=None, **kwargs) -> dict:
+    client = client or get_s3_client(path)
+    bucket, key = split_s3_path(path)
+    try:
+        return client.get_object(Bucket=bucket, Key=key, **kwargs)
+    except ClientError as e:
+        if e.response.get('Error', {}).get('Code') == 'GlacierObjectNotRestore':
+            _restore_and_wait(client, bucket, key, path)
+            return client.get_object(Bucket=bucket, Key=key, **kwargs)
+        raise
+
+
+def list_s3_objects(path: str, recursive=False, is_prefix=False, limit=0, client=None):
+    for content in list_s3_objects_detailed(path, recursive, is_prefix, limit, client=client):
+        yield content[0]
+
+
+def list_s3_objects_detailed(path: str, recursive=False, is_prefix=False, limit=0, client=None):
+    client = client or get_s3_client(path)
+    if limit > 1000:
+        raise Exception('limit greater than 1000 is not supported.')
+    if not path.endswith('/') and not is_prefix:
+        path += '/'
+    bucket, prefix = split_s3_path(path)
+    marker = None
+    while True:
+        list_kwargs = dict(MaxKeys=1000, Bucket=bucket, Prefix=prefix)
+        if limit > 0:
+            list_kwargs['MaxKeys'] = limit
+        if not recursive:
+            list_kwargs['Delimiter'] = '/'
+        if marker:
+            list_kwargs['Marker'] = marker
+        response = client.list_objects(**list_kwargs)
+        marker = None
+        if not recursive:
+            common_prefixes = response.get('CommonPrefixes', [])
+            for cp in common_prefixes:
+                yield (f"s3://{bucket}/{cp['Prefix']}", cp)
+            if common_prefixes:
+                marker = common_prefixes[-1]['Prefix']
+        contents = response.get('Contents', [])
+        for content in contents:
+            if not content['Key'].endswith('/'):
+                yield (f"s3://{bucket}/{content['Key']}", content)
+        if contents:
+            last_key = contents[-1]['Key']
+            if not marker or last_key > marker:
+                marker = last_key
+        if limit or not response.get('IsTruncated') or not marker:
+            break
+
+
+def upload_s3_object(path: str, local_file_path: str, client=None):
+    client = client or get_s3_client(path)
+    # upload
+    MB = 1024 ** 2
+    config = TransferConfig(
+        multipart_threshold=128 * MB,
+        multipart_chunksize=16 * MB,  # 156.25GiB maximum
+    )
+    bucket, key = split_s3_path(path)
+    client.upload_file(local_file_path, bucket, key, Config=config)
diff --git a/llm_web_kit/html_layout_classify/s3/cmd.py b/llm_web_kit/html_layout_classify/s3/cmd.py
new file mode 100644
index 00000000..580a3040
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/s3/cmd.py
@@ -0,0 +1,146 @@
+import ast
+import json
+import re
+from typing import Union
+
+try:
+    import orjson
+except Exception:
+    orjson = None
+
+_surrogates_re = r'[\ud800-\udfff]'
+
+
+def json_dumps(d: dict, **kwargs) -> str:
+    if not kwargs and orjson:
+        try:
+            return orjson.dumps(d).decode('utf-8')
+        except Exception:
+            pass
+    return json.dumps(d, ensure_ascii=False, **kwargs)
+
+
+def json_loads(s: Union[str, bytes], **kwargs) -> dict:
+    if not kwargs and orjson:
+        try:
+            return orjson.loads(s)
+        except Exception:
+            pass
+    try:
+        return json.loads(s, **kwargs)
+    except Exception as e:
+        if 'enclosed in double quotes' not in str(e):
+            raise e
+        if isinstance(s, bytes):
+            s = s.decode('utf-8')
+        else:
+            s = str(s)
+        return ast.literal_eval(s)
+
+
+def json_encode(d: dict, end='\n', **kwargs) -> bytes:
+    return str_encode(json_dumps(d, **kwargs) + end)
+
+
+def str_encode(s: str) -> bytes:
+    # try remote special characters
+    s = re.sub(_surrogates_re, '\ufffd', s)
+
+    try:
+        return s.encode('utf-8')
+    except UnicodeEncodeError as e:
+        debug_start = max(0, e.start - 1000)
+        debug_end = min(len(s), e.end + 1000)
+        print(f"{debug_start=}, {debug_end=}, debug_s={s[debug_start:debug_end]}")
+        raise
+
+# def json_print(obj):
+#     if isinstance(obj, list) and len(obj):
+#         return json_print(obj[0])
+#     if isinstance(obj, bytes):
+#         return json_print(obj.decode("utf-8"))
+#     if isinstance(obj, str):
+#         return json_print(json_loads(obj))
+#     if isinstance(obj, dict):
+#         return print(json_dumps(obj, indent=2))
+#
+#     from .row_fallback import Row
+#
+#     if isinstance(obj, Row) and "value" in obj:
+#         return json_print(obj.value)
+#
+#     print(obj)
+#
+# def _format_datetime(dt):
+#     if not dt:
+#         return ""
+#     dt = 
dt.replace(tzinfo=timezone.utc).astimezone(tz=None)  # localtime
+#     return dt.strftime("%y-%m-%d %H:%M:%S %Z")
+#
+#
+# def _format_size(size):
+#     if size is None:
+#         return ""
+#     size = str(size)
+#     parts = []
+#     while len(size):
+#         part_size = 3
+#         if not parts and len(size) % part_size:
+#             part_size = len(size) % part_size
+#         parts.append(size[:part_size])
+#         size = size[part_size:]
+#     return ",".join(parts)
+#
+#
+# def _format_detail(detail):
+#     path, obj = detail
+#     if path.endswith("/"):
+#         return f"{'DIR'.rjust(53)} {path}"
+#     tm = _format_datetime(obj.get("LastModified"))
+#     sz = _format_size(obj.get("Size") or obj.get("ContentLength", 0))
+#     owner = obj.get("Owner", {}).get("ID", "")
+#     return f"{tm} {sz.rjust(15)} {owner.rjust(15)} {path}"
+#
+#
+# def head(path):
+#     obj_head = head_s3_object_with_retry(path)
+#     if obj_head is not None:
+#         print(json_dumps(obj_head, indent=2, default=str))
+#
+#
+# def cat(path, limit=1, show_loc=False):
+#     if "?bytes=" in path:
+#         row = read_s3_row(path)
+#         if row is not None:
+#             if show_loc:
+#                 print(row.loc)
+#             json_print(row)
+#         return
+#     for row in read_s3_rows(path, use_stream=True, limit=limit):
+#         if show_loc:
+#             print(row.loc)
+#         json_print(row)
+#
+#
+# def ls(path, limit=100):
+#     for obj in list_s3_objects(path, limit=limit):
+#         print(obj)
+#
+#
+# def ls_r(path, limit=100):
+#     for item in list_s3_objects(path, True, True, limit):
+#         print(item)
+#
+#
+# def ll(path, limit=100):
+#     for detail in list_s3_objects_detailed(path, limit=limit):
+#         print(_format_detail(detail))
+#
+#
+# def ll_r(path, limit=100):
+#     for detail in list_s3_objects_detailed(path, True, True, limit):
+#         print(_format_detail(detail))
+#
+#
+# def download(path):
+#     print(get_s3_presigned_url(path))
diff --git a/llm_web_kit/html_layout_classify/s3/conf.py b/llm_web_kit/html_layout_classify/s3/conf.py
new file mode 100644
index 00000000..02115bd7
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/s3/conf.py
@@ -0,0 +1,92 @@
+import fnmatch
+import random
+import socket
+from typing import List, Union
+
+from .config import s3_bucket_prefixes, s3_buckets, s3_profiles
+from .path import split_s3_path
+
+
+def _is_inside_cluster(cluster_config: dict):
+    inside_hosts = cluster_config.get('inside_hosts')
+    if not (isinstance(inside_hosts, list) and inside_hosts):
+        return False
+    inside_hosts = [str(pat).lower() for pat in inside_hosts]
+    try:
+        host = socket.gethostname().lower()
+    except Exception:
+        return False
+    for host_pattern in inside_hosts:
+        if fnmatch.fnmatch(host, host_pattern):
+            return True
+    return False
+
+
+def _get_s3_bucket_config(path: str):
+    bucket = split_s3_path(path)[0] if path else ''
+    bucket_config = s3_buckets.get(bucket)
+    if not bucket_config:
+        for prefix, c in s3_bucket_prefixes.items():
+            if bucket.startswith(prefix):
+                bucket_config = c
+                break
+    if not bucket_config:
+        bucket_config = s3_profiles.get(bucket)
+    if not bucket_config:
+        bucket_config = s3_buckets.get('[default]')
+    assert bucket_config is not None
+    return bucket_config
+
+
+def _get_s3_config(
+    bucket_config,
+    outside: bool,
+    prefer_ip=False,
+    prefer_auto=False,
+):
+    cluster = bucket_config['cluster']
+    assert isinstance(cluster, dict)
+
+    if outside:
+        endpoint_key = 'outside'
+    elif prefer_auto:
+        endpoint_key = 'auto'
+    elif _is_inside_cluster(cluster):
+        endpoint_key = 'inside'
+    else:
+        endpoint_key = 'outside'
+
+    if endpoint_key not in cluster:
+        endpoint_key = 'outside'
+
+    if prefer_ip and f"{endpoint_key}_ips" in cluster:
+        endpoint_key = f"{endpoint_key}_ips"
+
+    endpoints = cluster[endpoint_key]
+
+    if isinstance(endpoints, str):
+        endpoint = endpoints
+    elif isinstance(endpoints, list):
+        endpoint = random.choice(endpoints)
+    else:
+        raise Exception(f"invalid endpoint for [{cluster}]")
+
+    return {
+        'endpoint': endpoint,
+        'ak': bucket_config['ak'],
+        'sk': bucket_config['sk'],
+    }
+
+
+def get_s3_config(path: Union[str, List[str]], outside=False):
+    paths = [path] if isinstance(path, str) else path  # fix: isinstance over type()==str
+    bucket_config = None
+    for p in paths:
+        bc = _get_s3_bucket_config(p)
+        if bucket_config in [bc, None]:
+            bucket_config = bc
+            continue
+        raise Exception(f"{paths} have different s3 config, cannot read together.")
+    if not bucket_config:
+        raise Exception('path is empty.')
+    return _get_s3_config(bucket_config, outside, prefer_ip=True)
diff --git a/llm_web_kit/html_layout_classify/s3/config.py b/llm_web_kit/html_layout_classify/s3/config.py
new file mode 100644
index 00000000..0499b96e
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/s3/config.py
@@ -0,0 +1,62 @@
+from .reader import read_config
+
+config = read_config()
+
+spark_clusters: dict = config.get('spark', {}).get('clusters', {})
+kafka_clusters: dict = config.get('kafka', {}).get('clusters', {})
+es_clusters: dict = config.get('es', {}).get('clusters', {})
+kudu_clusters: dict = config.get('kudu', {}).get('clusters', {})
+hive_clusters: dict = config.get('hive', {}).get('clusters', {})
+
+_s3_buckets: dict = config.get('s3', {}).get('buckets', {})
+_s3_profiles: dict = config.get('s3', {}).get('profiles', {})
+s3_clusters: dict = config.get('s3', {}).get('endpoints', {})
+
+_kc = list(s3_clusters.keys())
+_kc.sort(key=lambda s: -len(s))
+
+s3_profiles = {}
+for name, profile in _s3_profiles.items():
+    assert isinstance(profile, dict)
+    ak = profile.get('aws_access_key_id')
+    sk = profile.get('aws_secret_access_key')
+    if not (ak and sk):  # fix: skip when either key is missing (was `not ak and sk`)
+        continue
+    c = next((c for c in _kc if name.startswith(c)), None)
+    cluster = s3_clusters[c] if c else None
+    if not cluster and profile.get('endpoint_url'):
+        cluster = {'outside': profile['endpoint_url']}
+    if not cluster:
+        continue
+    s3_profiles[name] = {
+        'profile': name,
+        'ak': ak,
+        'sk': sk,
+        'cluster': cluster,
+    }
+
+s3_buckets = {}
+s3_bucket_prefixes = {}
+for bucket, profile_name in _s3_buckets.items():
+    profile = s3_profiles.get(profile_name)
+    if not profile:
+        continue
+    if '*' not in bucket:
+        s3_buckets[bucket] = profile
+        continue
+    bucket_prefix = bucket.rstrip('*')
+    if '*' not in bucket_prefix:
+        s3_bucket_prefixes[bucket_prefix] = profile
+
+__all__ = [
+    'config',
+    's3_buckets',
+    's3_bucket_prefixes',
+    's3_profiles',
+    's3_clusters',
+    'spark_clusters',
+    'kafka_clusters',
+    'es_clusters',
+    'kudu_clusters',
+    'hive_clusters',
+]
diff --git a/llm_web_kit/html_layout_classify/s3/const.py b/llm_web_kit/html_layout_classify/s3/const.py
new file mode 100644
index 00000000..f6cfb124
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/s3/const.py
@@ -0,0 +1,18 @@
+SUCCESS_MARK_FILE = '_SUCCESS'
+SUCCESS_MARK_FILE2 = '.SUCCESS'
+
+FAILURE_MARK_FILE = '_FAILURE'
+RESERVE_MARK_FILE = '_RESERVE'
+SUMMARY_MARK_FILE = '_SUMMARY'
+DELETED_MARK_FILE = '_DELETED'
+
+FIELD_ID = 'id'
+FIELD_SUB_PATH = 'sub_path'
+
+
+def is_flag_field(f: str):
+    return f.startswith('is_') or f.startswith('has_')
+
+
+def is_acc_field(f: str):
+    return f.startswith('acc_')
diff --git a/llm_web_kit/html_layout_classify/s3/path.py b/llm_web_kit/html_layout_classify/s3/path.py
new file mode 100644
index 00000000..8a869517
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/s3/path.py
@@ -0,0 +1,27 @@
+import re
+
+__re_s3_path = re.compile('^s3a?://([^/]+)(?:/(.*))?$')
+
+
+def is_s3_path(path: str) -> bool:
+    return path.startswith('s3://') or path.startswith('s3a://')
+
+
+def ensure_s3a_path(path: str) -> str:
+    if not path.startswith('s3://'):
+        return path
+    return 's3a://' + path[len('s3://'):]
+
+
+def ensure_s3_path(path: str) -> str:
+    if not path.startswith('s3a://'):
+        return path
+    return 's3://' + path[len('s3a://'):]
+
+
+def split_s3_path(path: str):
+    """split bucket and key from path."""
+    m = __re_s3_path.match(path)
+    if m is None:
+        return '', ''
+    return m.group(1), (m.group(2) or '')
diff --git a/llm_web_kit/html_layout_classify/s3/read.py b/llm_web_kit/html_layout_classify/s3/read.py
new file mode 100644
index 00000000..ec637698
--- /dev/null
+++ 
b/llm_web_kit/html_layout_classify/s3/read.py
@@ -0,0 +1,164 @@
+import io
+import re
+from typing import Tuple, Union
+
+from botocore.exceptions import ClientError
+from botocore.response import StreamingBody
+
+from .client import get_s3_object
+from .read_resume import ResumableS3Stream
+
+__re_bytes = re.compile('^([0-9]+)([,-])([0-9]+)$')
+__re_bytes_1 = re.compile('^([0-9]+),([0-9]+)$')
+
+SIZE_1M = 1 << 20
+
+
+def read_s3_object_detailed(
+    path: str,
+    bytes: Union[str, None] = None,
+    client=None,
+) -> Tuple[StreamingBody, dict]:
+    """
+    ### Usage
+    ```
+    obj = read_object("s3://bkt/path/to/file.txt")
+    for line in obj.iter_lines():
+        handle(line)
+    ```
+    """
+    kwargs = {}
+    if bytes:
+        m = __re_bytes.match(bytes)
+        if m is not None:
+            frm = int(m.group(1))
+            to = int(m.group(3))
+            sep = m.group(2)
+            if sep == ',':
+                to = frm + to - 1
+            if to >= frm:
+                kwargs['Range'] = f"bytes={frm}-{to}"
+            elif frm > 0:
+                kwargs['Range'] = f"bytes={frm}-"
+
+    obj = get_s3_object(path, client=client, **kwargs)
+    return obj.pop('Body'), obj
+
+
+def read_s3_object_bytes_detailed(path: str, size_limit=0, client=None):
+    """This method cache all content in memory, avoid large file."""
+    import time
+
+    retries = 0
+    last_e = None
+    while True:
+        if retries > 5:
+            msg = f"Retry exhausted for reading [{path}]"
+            raise Exception(msg) from last_e
+        try:
+            stream, obj = read_s3_object_detailed(path, client=client)
+            with stream:
+                amt = size_limit if size_limit > 0 else None
+                buf = stream.read(amt)
+            break
+        except ClientError:
+            raise
+        except Exception as e:
+            last_e = e
+            retries += 1
+            time.sleep(3)
+
+    assert isinstance(buf, bytes)
+    return buf, obj
+
+
+def read_s3_object_io_detailed(path: str, size_limit=0, client=None):
+    """This method cache all content in memory, avoid large file."""
+    import io
+
+    buf, obj = read_s3_object_bytes_detailed(path, size_limit, client=client)
+    return io.BytesIO(buf), obj
+
+
+def read_s3_object_io(path: str, size_limit=0, client=None):
+    """This method cache all content in memory, avoid large file."""
+    return read_s3_object_io_detailed(path, size_limit, client=client)[0]
+
+
+def buffered_stream(stream, buffer_size: int, **kwargs):
+    from warcio.bufferedreaders import BufferedReader
+    return BufferedReader(stream, buffer_size, **kwargs)
+
+
+def read_records(path: str, stream: io.IOBase, buffer_size: int):
+    """do not handle stream.close()"""
+    offset = 0
+
+    # if path.endswith(".warc") or path.endswith(".warc.gz"):
+    #     from .read_warc import read_warc_records
+    #
+    #     yield from read_warc_records(path, stream)
+
+    if path.endswith('.gz'):
+        r = buffered_stream(stream, buffer_size, decomp_type='gzip')
+        while True:
+            line = r.readline()
+            if not line:
+                if r.read_next_member():
+                    continue
+                break
+            tell = stream.tell() - r.rem_length()
+            yield (line.decode('utf-8'), offset, tell - offset)
+            offset = tell
+
+    elif path.endswith('.bz2'):
+        raise Exception('bz2 is not supported yet.')
+
+    # elif path.endswith(".7z"):
+    #     from .read_7z import SevenZipReadStream
+    #
+    #     stream1 = SevenZipReadStream(stream)
+    #     stream2 = buffered_stream(stream1, buffer_size)
+    #
+    #     while True:
+    #         line = stream2.readline()
+    #         if not line:
+    #             break
+    #         yield (line.decode("utf-8"), int(-1), int(0))
+
+    else:  # plaintext
+        stream1 = stream
+        if isinstance(stream, ResumableS3Stream):
+            stream1 = buffered_stream(stream, buffer_size)
+
+        while True:
+            line = stream1.readline()
+            if not line:
+                break
+            yield (line.decode('utf-8'), offset, len(line))
+            offset += len(line)
+
+
+def read_s3_rows(path: str, use_stream=False, limit=0, size_limit=0, client=None):
+    from .row_fallback import Row
+
+    if use_stream:
+        stream = ResumableS3Stream(path, size_limit, client=client)
+    else:
+        stream = read_s3_object_io(path, size_limit, client=client)
+
+    with stream:
+        cnt = 0
+        for record in read_records(path, stream, SIZE_1M):
+            value, offset, length = record
+
+            if offset >= 0:
+                loc = f"{path}?bytes={offset},{length}"
+            else:
+                loc = path
+
+            yield Row(value=value, loc=loc)
+
+            cnt += 1
+            if limit > 0 and cnt >= limit:
+                break
diff --git a/llm_web_kit/html_layout_classify/s3/read_resume.py b/llm_web_kit/html_layout_classify/s3/read_resume.py
new file mode 100644
index 00000000..94fc172e
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/s3/read_resume.py
@@ -0,0 +1,105 @@
+import io
+import time
+
+from botocore.exceptions import (IncompleteReadError, ReadTimeoutError,
+                                 ResponseStreamingError)
+from botocore.response import StreamingBody
+
+from .retry import get_s3_object_with_retry, head_s3_object_with_retry
+
+
+class _EmptyStream:
+    def read(self, n):
+        return b''
+
+    def close(self):
+        pass
+
+
+class ResumableS3Stream(io.IOBase):
+    def __init__(self, path: str, size_limit=0, client=None):
+        self.path = path
+        self.size_limit = size_limit
+        self.client = client
+        self.pos = 0
+        self.size = -1
+        self.stream = self.new_stream()
+
+    def new_stream(self) -> StreamingBody:
+        if self.size < 0 and self.size_limit > 0:
+            head = head_s3_object_with_retry(self.path, True, self.client)
+            assert head is not None
+            if int(head['ContentLength']) <= self.size_limit:
+                self.size_limit = 0
+                self.size = int(head['ContentLength'])
+            else:
+                self.size = self.size_limit
+
+        if self.size_limit > 0:
+            kwargs = {'Range': f"bytes={self.pos}-{self.size_limit - 1}"}
+        else:
+            kwargs = {'Range': f"bytes={self.pos}-"} if self.pos > 0 else {}
+
+        obj = get_s3_object_with_retry(self.path, client=self.client, **kwargs)
+
+        if self.size < 0:
+            self.size = int(obj['ContentLength'])
+
+        return obj['Body']
+
+    def readable(self):
+        return True
+
+    def read(self, n=None):
+        if self.pos >= self.size:
+            return b''
+
+        retries = 0
+        last_e = None
+        while True:
+            if retries > 5:
+                msg = f"Retry exhausted for reading [{self.path}]"
+                raise Exception(msg) from last_e
+            try:
+                data = self.stream.read(n)
+                self.pos += len(data)
+                return data
+            except (ReadTimeoutError, ResponseStreamingError, IncompleteReadError) as e:
+                try:
+                    self.stream.close()
+                except Exception:
+                    pass
+                last_e = e
+                retries += 1
+                time.sleep(3)
+                self.stream = self.new_stream()
+
+    def seekable(self):
+        return True
+
+    def seek(self, offset, whence=io.SEEK_SET):
+        if whence == io.SEEK_SET:
+            pos = offset
+        elif whence == io.SEEK_CUR:
+            pos = self.pos + offset
+        elif whence == io.SEEK_END:
+            pos = self.size + offset
+        else:
+            raise ValueError('Invalid argument: whence')
+        if pos != self.pos:
+            self.pos = pos
+            try:
+                self.stream.close()
+            except Exception:
+                pass
+            if self.pos < self.size:
+                self.stream = self.new_stream()
+            else:
+                self.stream = _EmptyStream()
+        return pos
+
+    def tell(self):
+        return self.pos
+
+    def close(self):
+        self.stream.close()
diff --git a/llm_web_kit/html_layout_classify/s3/reader.py b/llm_web_kit/html_layout_classify/s3/reader.py
new file mode 100644
index 00000000..2edf8d6d
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/s3/reader.py
@@ -0,0 +1,127 @@
+"""read config from:
+
+- user config [specific s3 configs]
+- user ~/.aws/ [for ak/sk of s3 clusters]
+- default config [same for all users]
+"""
+
+import configparser
+import json
+import os
+import re
+from pathlib import Path
+
+import yaml
+
+_USER_CONFIG_FILES = [
+    '.xinghe.yaml',
+    '.xinghe.yml',
+    '.code-clean.yaml',
+    '.code-clean.yml',
+]
+
+
+def _get_home_dir():
+    spark_user = os.environ.get('SPARK_USER')
+    if spark_user:
+        return os.path.join('/share', spark_user)  # hard code
+    return Path.home()
+
+
+def _read_ini_s3_section(s: str):
+    config = {}
+    for line in s.split('\n'):
+        ml = re.match(r'^\s*([^=\s]+)\s*=\s*(.+?)\s*$', line)
+        if ml:
+            config[ml.group(1)] = ml.group(2)
+    return config
+
+
+def _read_ini_file(file: str):
+    parser = configparser.ConfigParser()
+    parser.read(file)
+    config = {}
+    for name, section in parser.items():
+        name = re.sub(r'^\s*profile\s+', '', name)
+        name = name.strip().strip('"')
+        for key, val in section.items():
+            if key == 's3':
+                val = _read_ini_s3_section(val)
+                config.setdefault(name, {}).update(val)
+            else:
+                config.setdefault(name, {})[key] = val
+    return config
+
+
+def _merge_config(old: dict, new: dict):
+    ret = {}
+    for key, old_val in old.items():
+        if key not in new:
+            ret[key] = old_val
+            continue
+        new_val = new.pop(key)
+        if isinstance(old_val, dict) and isinstance(new_val, dict):
+            ret[key] = _merge_config(old_val, new_val)
+        else:
+            ret[key] = new_val
+    for key, new_val in new.items():
+        ret[key] = new_val
+    return ret
+
+
+def _read_s3_config():
+    home = _get_home_dir()
+    conf_file = os.path.join(home, '.aws', 'config')
+    cred_file = os.path.join(home, '.aws', 'credentials')
+    config = {}
+    if os.path.isfile(conf_file):
+        config = _read_ini_file(conf_file)
+    if os.path.isfile(cred_file):
+        config = _merge_config(config, _read_ini_file(cred_file))
+    return {'s3': {'profiles': config}}
+
+
+def _read_user_config():
+    home = _get_home_dir()
+    for filename in _USER_CONFIG_FILES:
+        conf_file = os.path.join(home, filename)
+        if os.path.isfile(conf_file):
+            break
+    else:
+        return {}
+    with open(conf_file, 'r') as f:
+        config = yaml.safe_load(f)
+    assert isinstance(config, dict)
+    return config
+
+
+def get_conf_dir():
+    conf_dir = os.getenv('XINGHE_CONF_DIR')
+    if not conf_dir and os.getenv('BASE_DIR'):
+        conf_dir = os.path.join(os.environ['BASE_DIR'], 'conf')
+    if not conf_dir:
+        raise Exception('XINGHE_CONF_DIR not set.')
+    return conf_dir
+
+
+def _read_default_config():
+    conf_file = os.path.join(get_conf_dir(), 'config.yaml')
+    if not os.path.isfile(conf_file):
+        raise Exception(f"config file [{conf_file}] not found.")
+    with open(conf_file, 'r') as f:
+        config = yaml.safe_load(f)
+    assert isinstance(config, dict)
+    return config
+
+
+def read_config():
+    # config = _read_default_config()
+    config = {}
+    config = _merge_config(config, _read_s3_config())
+    config = _merge_config(config, _read_user_config())
+    return config
+
+
+if __name__ == '__main__':
+    c = read_config()
+    print(json.dumps(c, indent=2))
diff 
--git a/llm_web_kit/html_layout_classify/s3/retry.py b/llm_web_kit/html_layout_classify/s3/retry.py
new file mode 100644
index 00000000..3458f319
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/s3/retry.py
@@ -0,0 +1,39 @@
+from botocore.exceptions import ClientError
+
+from .client import get_s3_object, head_s3_object, upload_s3_object
+from .retry_utils import with_retry
+
+
+@with_retry
+def _get_s3_object_or_ex(path: str, client, **kwargs):
+    try:
+        return get_s3_object(path, client=client, **kwargs)
+    except ClientError as e:
+        return e
+
+
+def get_s3_object_with_retry(path: str, client=None, **kwargs):
+    ret = _get_s3_object_or_ex(path, client, **kwargs)
+    if isinstance(ret, ClientError):
+        raise ret
+    return ret
+
+
+@with_retry
+def _head_s3_object_or_ex(path: str, raise_404: bool, client):
+    try:
+        return head_s3_object(path, raise_404, client=client)
+    except ClientError as e:
+        return e
+
+
+def head_s3_object_with_retry(path: str, raise_404=False, client=None):
+    ret = _head_s3_object_or_ex(path, raise_404, client)
+    if isinstance(ret, ClientError):
+        raise ret
+    return ret
+
+
+@with_retry(sleep_time=180)
+def upload_s3_object_with_retry(path: str, local_file_path: str, client=None):
+    upload_s3_object(path, local_file_path, client=client)
diff --git a/llm_web_kit/html_layout_classify/s3/retry_utils.py b/llm_web_kit/html_layout_classify/s3/retry_utils.py
new file mode 100644
index 00000000..bf52ed33
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/s3/retry_utils.py
@@ -0,0 +1,45 @@
+import functools
+import time
+
+
+def get_func_path(func) -> str:
+    if not callable(func):
+        return func
+    return f"{func.__module__}.{func.__name__}"
+
+
+def with_retry(func=None, max_retries=5, sleep_time=3):
+    def try_sleep():
+        try:
+            time.sleep(sleep_time)
+        except Exception:
+            pass
+
+    def get_msg(func, args, kwargs):
+        msg = f"Retry exhausted for [{get_func_path(func)}]"
+        msg += f", args={args}" if args else ''
+        msg += f", kwargs={kwargs}" if kwargs else ''
+        return msg
+
+    def handle(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            retries = 0
+            last_e = None
+            while True:
+                if retries > max_retries:
+                    msg = get_msg(func, args, kwargs)
+                    raise Exception(msg) from last_e
+                try:
+                    return func(*args, **kwargs)
+                except Exception as e:
+                    retries += 1
+                    last_e = e
+                    try_sleep()
+
+        return wrapper
+
+    if func is not None:
+        return handle(func)
+
+    return handle
diff --git a/llm_web_kit/html_layout_classify/s3/row_fallback.py b/llm_web_kit/html_layout_classify/s3/row_fallback.py
new file mode 100644
index 00000000..25728e83
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/s3/row_fallback.py
@@ -0,0 +1,191 @@
+# copied from pyspark/sql/types.py
+
+from typing import Any, Dict, List, Optional, Tuple, Union, overload
+
+
+def _create_row(fields: Union['Row', List[str]], values: Union[Tuple[Any, ...], List[Any]]) -> 'Row':
+    row = Row(*values)
+    row.__fields__ = fields
+    return row
+
+
+class Row(tuple):
+    """A row in :class:`DataFrame`. The fields in it can be accessed:
+
+    * like attributes (``row.key``)
+    * like dictionary values (``row[key]``)
+
+    ``key in row`` will search through row keys.
+
+    Row can be used to create a row object by using named arguments.
+    It is not allowed to omit a named argument to represent that the value is
+    None or missing. This should be explicitly set to None in this case.
+
+    .. versionchanged:: 3.0.0
+        Rows created from named arguments no longer have
+        field names sorted alphabetically and will be ordered in the position as
+        entered.
# copied from pyspark/sql/types.py

from typing import Any, Dict, List, Optional, Tuple, Union, overload


def _create_row(fields: Union['Row', List[str]], values: Union[Tuple[Any, ...], List[Any]]) -> 'Row':
    """Build a :class:`Row` carrying explicit field names (pickle/`__call__` helper)."""
    row = Row(*values)
    row.__fields__ = fields
    return row


class Row(tuple):
    """A row in :class:`DataFrame`. The fields in it can be accessed:

    * like attributes (``row.key``)
    * like dictionary values (``row[key]``)

    ``key in row`` will search through row keys.

    Row can be used to create a row object by using named arguments.
    It is not allowed to omit a named argument to represent that the value is
    None or missing. This should be explicitly set to None in this case.

    .. versionchanged:: 3.0.0
        Rows created from named arguments no longer have
        field names sorted alphabetically and will be ordered in the position as
        entered.

    Examples
    --------
    >>> from pyspark.sql import Row
    >>> row = Row(name="Alice", age=11)
    >>> row
    Row(name='Alice', age=11)
    >>> row['name'], row['age']
    ('Alice', 11)
    >>> row.name, row.age
    ('Alice', 11)
    >>> 'name' in row
    True
    >>> 'wrong_key' in row
    False

    Row also can be used to create another Row like class, then it
    could be used to create Row objects, such as

    >>> Person = Row("name", "age")
    >>> Person
    <Row('name', 'age')>
    >>> 'name' in Person
    True
    >>> 'wrong_key' in Person
    False
    >>> Person("Alice", 11)
    Row(name='Alice', age=11)

    This form can also be used to create rows as tuple values, i.e. with unnamed
    fields.

    >>> row1 = Row("Alice", 11)
    >>> row2 = Row(name="Alice", age=11)
    >>> row1 == row2
    True
    """

    @overload
    def __new__(cls, *args: str) -> 'Row':
        ...

    @overload
    def __new__(cls, **kwargs: Any) -> 'Row':
        ...

    def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> 'Row':
        if args and kwargs:
            raise ValueError('Can not use both args ' 'and kwargs to create Row')
        if kwargs:
            # create row objects
            row = tuple.__new__(cls, list(kwargs.values()))
            row.__fields__ = list(kwargs.keys())
            return row
        else:
            # create row class or objects
            return tuple.__new__(cls, args)

    def asDict(self, recursive: bool = False) -> Dict[str, Any]:
        """Return as a dict.

        Parameters
        ----------
        recursive : bool, optional
            turns the nested Rows to dict (default: False).

        Notes
        -----
        If a row contains duplicate field names, e.g., the rows of a join
        between two :class:`DataFrame` that both have the fields of same names,
        one of the duplicate fields will be selected by ``asDict``. ``__getitem__``
        will also return one of the duplicate fields, however returned value might
        be different to ``asDict``.

        Examples
        --------
        >>> from pyspark.sql import Row
        >>> Row(name="Alice", age=11).asDict() == {'name': 'Alice', 'age': 11}
        True
        >>> row = Row(key=1, value=Row(name='a', age=2))
        >>> row.asDict() == {'key': 1, 'value': Row(name='a', age=2)}
        True
        >>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}}
        True
        """
        if not hasattr(self, '__fields__'):
            raise TypeError('Cannot convert a Row class into dict')

        if recursive:

            def conv(obj: Any) -> Any:
                if isinstance(obj, Row):
                    return obj.asDict(True)
                elif isinstance(obj, list):
                    return [conv(o) for o in obj]
                elif isinstance(obj, dict):
                    return dict((k, conv(v)) for k, v in obj.items())
                else:
                    return obj

            return dict(zip(self.__fields__, (conv(o) for o in self)))
        else:
            return dict(zip(self.__fields__, self))

    def __contains__(self, item: Any) -> bool:
        if hasattr(self, '__fields__'):
            return item in self.__fields__
        else:
            return super(Row, self).__contains__(item)

    # let object acts like class
    def __call__(self, *args: Any) -> 'Row':
        """create new Row object."""
        if len(args) > len(self):
            raise ValueError(
                'Can not create Row with fields %s, expected %d values ' 'but got %s' % (self, len(self), args))
        return _create_row(self, args)

    def __getitem__(self, item: Any) -> Any:
        if isinstance(item, (int, slice)):
            return super(Row, self).__getitem__(item)
        try:
            # it will be slow when it has many fields,
            # but this will not be used in normal cases
            idx = self.__fields__.index(item)
            return super(Row, self).__getitem__(idx)
        except IndexError:
            raise KeyError(item)
        except ValueError:
            raise ValueError(item)

    def __getattr__(self, item: str) -> Any:
        if item.startswith('__'):
            raise AttributeError(item)
        try:
            # it will be slow when it has many fields,
            # but this will not be used in normal cases
            idx = self.__fields__.index(item)
            return self[idx]
        except IndexError:
            raise AttributeError(item)
        except ValueError:
            raise AttributeError(item)

    def __setattr__(self, key: Any, value: Any) -> None:
        # Rows are immutable except for attaching the field-name list itself.
        if key != '__fields__':
            raise RuntimeError('Row is read-only')
        self.__dict__[key] = value

    def __reduce__(
        self,
    ) -> Union[str, Tuple[Any, ...]]:
        """Returns a tuple so Python knows how to pickle Row."""
        if hasattr(self, '__fields__'):
            return (_create_row, (self.__fields__, tuple(self)))
        else:
            return tuple.__reduce__(self)

    def __repr__(self) -> str:
        """Printable representation of Row used in Python REPL."""
        if hasattr(self, '__fields__'):
            return 'Row(%s)' % ', '.join('%s=%r' % (k, v) for k, v in zip(self.__fields__, tuple(self)))
        else:
            # BUGFIX: the vendored copy lost the '<Row(%s)>' literal (HTML
            # stripping), leaving `'' % ...` which raises TypeError. Restored
            # to match upstream pyspark's repr for an unnamed Row factory.
            return '<Row(%s)>' % ', '.join('%r' % field for field in self)
import bz2
import gzip
import io
import os
import uuid

from .cmd import json_encode
from .const import FIELD_ID
from .retry import upload_s3_object_with_retry

# Accepted compression aliases, normalized to the canonical file extension.
_compressions = {
    'gz': 'gz',
    'gzip': 'gz',
    'bz2': 'bz2',
    'bzip': 'bz2',
    'bzip2': 'bz2',
    'raw': 'raw',
    'none': 'raw',
}


def s3_upload_tmp_dir():
    """Return a writable local staging directory for S3 uploads.

    Creates ``/tmp/s3_upload`` and probes it with a throwaway file.
    NOTE(review): the fallback path in the except-branch is byte-identical
    to the primary path, so it cannot recover from a failure — presumably a
    different location (e.g. cwd-relative) was intended; confirm with the
    author. Behavior kept as-is.
    """
    tmp_dir = '/tmp/s3_upload'
    try:
        os.makedirs(tmp_dir, exist_ok=True)
        test_file = os.path.join(tmp_dir, '__test_file')
        try:
            open(test_file, 'a').close()
        finally:
            try:
                os.remove(test_file)
            except Exception:
                pass
    except Exception:
        tmp_dir = os.path.join('/tmp', 's3_upload')
        os.makedirs(tmp_dir, exist_ok=True)
    return tmp_dir


class S3DocWriter:
    """Buffers JSON documents to a local temp file, then uploads to S3 on flush.

    Each ``write`` appends one JSON-encoded document (optionally gzip/bz2
    compressed) and tracks byte offsets so a ``doc_loc`` of the form
    ``s3://path?bytes=offset,length`` can be embedded in each document.
    """

    def __init__(
        self,
        path: str,
        client=None,
        tmp_dir=None,
        skip_loc=False,
        compression='',
    ) -> None:
        if not path.startswith('s3://'):
            raise Exception(f"invalid s3 path [{path}].")

        # Normalize compression and cross-check it against the path suffix;
        # if not given explicitly, infer it from the suffix.
        compression = _compressions.get(compression)
        if compression and not path.endswith(f".{compression}"):
            raise Exception(f"path must endswith [.{compression}]")
        if not compression and path.endswith('.gz'):
            compression = 'gz'
        if not compression and path.endswith('.bz2'):
            compression = 'bz2'

        self.path = path
        self.client = client
        self.skip_loc = skip_loc
        self.compression = compression

        if not tmp_dir:
            tmp_dir = s3_upload_tmp_dir()
        os.makedirs(tmp_dir, exist_ok=True)

        ext = self.__get_ext(path)
        # Random local staging file; uploaded and removed by flush().
        self.tmp_file = os.path.join(tmp_dir, f"{str(uuid.uuid4())}.{ext}")
        self.tmp_fh = open(self.tmp_file, 'ab')
        self.offset = 0

    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        self.flush()

    @staticmethod
    def __get_ext(path: str):
        """Extension of *path*'s basename, or 'txt' for dotless/hidden names."""
        filename = os.path.basename(path)
        parts = filename.split('.')
        if len(parts) > 1 and parts[0]:
            return parts[-1]
        return 'txt'

    def write(self, d: dict):
        """Append one document; returns the number of bytes written.

        When the doc carries an id (FIELD_ID) and ``skip_loc`` is off, a
        ``doc_loc`` pointing at this doc's byte range in the target file is
        embedded. For compressed output the length is recorded as 0 (the
        compressed length cannot be known before encoding); for raw output a
        fixed-point loop re-encodes until the embedded length matches the
        actual encoded length.
        """
        d = d.copy()

        # Preserve the previous location in the doc's location history.
        if not self.skip_loc and 'doc_loc' in d:
            track_loc = d.get('track_loc') or []
            track_loc.append(d['doc_loc'])
            d['track_loc'] = track_loc

        has_id = not self.skip_loc and FIELD_ID in d

        if self.compression == 'gz':
            if has_id:
                d['doc_loc'] = f"{self.path}?bytes={self.offset},0"
            buf = io.BytesIO()
            with gzip.GzipFile(fileobj=buf, mode='wb') as f:
                f.write(json_encode(d))
            doc_bytes = buf.getvalue()

        elif self.compression == 'bz2':
            if has_id:
                d['doc_loc'] = f"{self.path}?bytes={self.offset},0"
            buf = io.BytesIO()
            with bz2.BZ2File(buf, mode='wb') as f:
                f.write(json_encode(d))
            doc_bytes = buf.getvalue()

        else:
            doc_bytes = json_encode(d)
            # BUGFIX: this fixed-point doc_loc loop previously ran for ALL
            # compression modes after the branches above, re-encoding with
            # plain json_encode() and thereby replacing the compressed bytes
            # with uncompressed JSON (corrupting .gz/.bz2 output). It only
            # makes sense — and now only runs — for raw output.
            if has_id:
                doc_len, last_len = len(doc_bytes), 0
                while doc_len != last_len:
                    d['doc_loc'] = f"{self.path}?bytes={self.offset},{doc_len}"
                    doc_bytes = json_encode(d)
                    doc_len, last_len = len(doc_bytes), doc_len

        self.tmp_fh.write(doc_bytes)
        self.offset += len(doc_bytes)

        return len(doc_bytes)

    def flush(self):
        """Close the staging file, upload it to S3, and remove it locally."""
        try:
            self.tmp_fh.close()
            upload_s3_object_with_retry(self.path, self.tmp_file, self.client)
        finally:
            try:
                os.remove(self.tmp_file)
            except Exception:
                pass