diff --git a/llm_web_kit/html_layout_classify/__init__.py b/llm_web_kit/html_layout_classify/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/llm_web_kit/html_layout_classify/html_layout_classify.md b/llm_web_kit/html_layout_classify/html_layout_classify.md
new file mode 100644
index 00000000..2a7595a5
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/html_layout_classify.md
@@ -0,0 +1,29 @@
+# html layout classify layout分类
+
+## 环境
+
+配置 .xinghe.yaml
+配置 .llm_web_kit.jsonc
+
+## 入参
+
layout_sample_dir:每个 layout_id 随机选取 3 条样本组成的 .jsonl 文件路径,或包含此类文件的文件夹路径
layout_classify_dir:每个 layout_id 对应的分类结果的输出文件夹路径
+
+layout_sample_dir 字段说明:
+
+| 字段 | 类型 | 描述 | 是否必须 |
+| --------- | ------ | ---------------------------- | -------- |
+| layout_id | string | layout id | 是 |
+| url | string | 数据url | 是 |
+| simp_html | string | html原数据经过简化处理的html | 是 |
+
+layout_classify_dir 字段说明:
+
+| 字段 | 类型 | 描述 | 是否必须 |
+| ------------- | ------ | --------------------------------------------------------------- | -------- |
+| url_list | list | layout id 对应的url | 是 |
+| layout_id | string | layout id | 是 |
+| page_type | string | layout_id 经过分类之后的分类结果('other', 'article', 'forum') | 是 |
+| max_pred_prod | float | 分类模型的分类可靠度 | 是 |
+| version | string | 模型版本 | 是 |
diff --git a/llm_web_kit/html_layout_classify/main.py b/llm_web_kit/html_layout_classify/main.py
new file mode 100644
index 00000000..b1b4f4d8
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/main.py
@@ -0,0 +1,115 @@
+import argparse
+import json
+
+from loguru import logger
+
+from llm_web_kit.html_layout_classify.s3.client import list_s3_objects
+from llm_web_kit.html_layout_classify.s3.read import read_s3_rows
+from llm_web_kit.html_layout_classify.s3.write import S3DocWriter
+from llm_web_kit.model.html_layout_cls import HTMLLayoutClassifier
+
# Page-type label -> integer class id used for majority voting.
CLASSIFY_MAP = {'other': 0, 'article': 1, 'forum': 2}
# Inverse mapping: integer class id -> page-type label.
INT_CLASSIFY_MAP = {0: 'other', 1: 'article', 2: 'forum'}
# Version string stamped into every classify result record.
# NOTE(review): "VERESION" is a typo for "VERSION"; renaming would touch every
# use site, so the name is kept as-is.
MODEL_VERESION = '0.0.2'
+
+
def __list_layout_sample_dir(s3_dir: str) -> list:
    """Enumerate the layout-sample ``.jsonl`` files under *s3_dir*.

    A path ending in ``/`` is treated as a directory and listed recursively,
    keeping only ``.jsonl`` objects; any other path is assumed to be a single
    sample file and is returned as a one-element list.
    """
    if not s3_dir.endswith('/'):
        return [s3_dir]
    return [obj for obj in list_s3_objects(s3_dir, recursive=True) if obj.endswith('.jsonl')]
+
+
def __parse_predict_res(predict_res: list, layout_samples: list) -> dict:
    """Aggregate per-sample predictions into one classify record for a layout_id.

    predict_res items look like {'pred_prob': '0.626298', 'pred_label': 'other'};
    layout_samples are the raw sample dicts sharing one layout_id.
    Returns the record written to layout_classify_dir (fields documented in
    html_layout_classify.md).
    """
    # Majority vote over predicted labels; unknown labels count as 'other' (0).
    votes = [CLASSIFY_MAP.get(item['pred_label'], 0) for item in predict_res]
    return {
        'url_list': [sample['url'] for sample in layout_samples],
        'layout_id': layout_samples[0]['layout_id'],
        'page_type': INT_CLASSIFY_MAP.get(__most_frequent_or_zero(votes), 'other'),
        # pred_prob arrives as a string: compare numerically and emit a float,
        # as documented for max_pred_prod (the previous lexicographic max over
        # strings was wrong in general and emitted a str).
        'max_pred_prod': max(float(item['pred_prob']) for item in predict_res),
        'version': MODEL_VERESION,
    }
+
+
def __most_frequent_or_zero(int_elements):
    """Return the element with a strict plurality of votes, otherwise 0.

    Generalizes the original hand-rolled logic (which only handled up to three
    elements and silently returned None beyond that): any list length is
    supported. An empty list or a tie between the top candidates yields 0;
    behavior for 0-3 elements is unchanged.
    """
    from collections import Counter

    if not int_elements:
        return 0
    ranked = Counter(int_elements).most_common(2)
    if len(ranked) == 1:
        # Only one distinct value: it trivially wins.
        return ranked[0][0]
    (top, top_count), (_, runner_up_count) = ranked
    return top if top_count > runner_up_count else 0
+
+
def __process_one_layout_sample(layout_sample_file: str, layout_type_dir: str):
    """Classify every layout_id cluster in one sample file and write the results.

    Rows whose layout_id equals the previous row's join its batch, so samples
    sharing a layout_id are assumed consecutive in the file (a layout_id that
    reappears later would start a new batch).  The output object keeps the
    input file's basename under *layout_type_dir*.
    """
    output_file_path = f"{layout_type_dir}{layout_sample_file.split('/')[-1]}"
    writer = S3DocWriter(output_file_path)

    def __get_type_by_layoutid(layout_samples: list):
        # Batch-classify the simplified HTML strings with the module-global model
        # (loaded by __set_config) and fold the predictions into one record.
        # html_str_input = [general_simplify_html_str(html['html_source']) for html in layout_samples]
        html_str_input = [html['simp_html'] for html in layout_samples]
        layout_classify_lst = model.predict(html_str_input)
        layout_classify = __parse_predict_res(layout_classify_lst, layout_samples)
        return layout_classify

    current_layout_id, samples = None, []
    idx = 0
    for row in read_s3_rows(layout_sample_file):
        idx += 1
        detail_data = json.loads(row.value)
        if current_layout_id == detail_data['layout_id']:
            samples.append(detail_data)
        else:
            # layout_id changed: flush the finished cluster, start a new one.
            if samples:
                classify_res = __get_type_by_layoutid(samples)
                writer.write(classify_res)
            current_layout_id, samples = detail_data['layout_id'], [detail_data]
    # Flush the trailing cluster left over after the loop.
    if samples:
        classify_res = __get_type_by_layoutid(samples)
        writer.write(classify_res)
    writer.flush()
    logger.info(f'read {layout_sample_file} file {idx} rows')
+
+
def __set_config():
    """Load the HTML layout classification model into the module-global ``model``."""
    global model
    model = HTMLLayoutClassifier()
+
+
def main():
    """CLI entry point: classify every layout sample file and write results to s3.

    Usage: ``main.py <layout_sample_dir> <layout_classify_dir>``.
    """
    parser = argparse.ArgumentParser(description='Process files with specified function.')
    parser.add_argument('layout_sample_dir', help='待分类文件夹路径或文件路径')
    parser.add_argument('layout_classify_dir', help='已分类结果输出路径')

    args = parser.parse_args()

    try:
        # Load the classification model once before processing any file.
        __set_config()
        layout_sample_files = __list_layout_sample_dir(args.layout_sample_dir)
        # Each .jsonl file holds rows grouped by layout_id; classify each
        # cluster and write one result record per layout_id to s3.
        for layout_sample_file in layout_sample_files:
            __process_one_layout_sample(layout_sample_file, args.layout_classify_dir)
    except Exception:
        # logger.exception keeps the traceback; the old f-string message lost it.
        logger.exception('get layout classify fail')


if __name__ == '__main__':
    main()
diff --git a/llm_web_kit/html_layout_classify/s3/__init__.py b/llm_web_kit/html_layout_classify/s3/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/llm_web_kit/html_layout_classify/s3/client.py b/llm_web_kit/html_layout_classify/s3/client.py
new file mode 100644
index 00000000..3450ffd6
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/s3/client.py
@@ -0,0 +1,139 @@
+import time
+from typing import Dict, List, Union
+
+import boto3
+from boto3.s3.transfer import TransferConfig
+from botocore.client import Config
+from botocore.exceptions import ClientError
+
+from .conf import get_s3_config
+from .path import split_s3_path
+
+
def get_s3_client(path: Union[str, List[str]], outside=False):
    """Build a boto3 S3 client for *path*, with credentials/endpoint from get_s3_config."""
    conf = get_s3_config(path, outside)
    credential_kwargs = dict(
        aws_access_key_id=conf['ak'],
        aws_secret_access_key=conf['sk'],
        endpoint_url=conf['endpoint'],
    )
    try:
        return boto3.client(
            's3',
            config=Config(
                s3={'addressing_style': 'path'},
                retries={'max_attempts': 8, 'mode': 'standard'},
                connect_timeout=600,
                read_timeout=600,
            ),
            **credential_kwargs,
        )
    except Exception:
        # Older boto3 releases reject the retries "mode" option; fall back to
        # a minimal retry config.
        return boto3.client(
            's3',
            config=Config(s3={'addressing_style': 'path'}, retries={'max_attempts': 8}),
            **credential_kwargs,
        )
+
+
def is_s3_404_error(e: Exception):
    """True when *e* is a botocore ClientError describing a missing object (HTTP 404)."""
    if not isinstance(e, ClientError):
        return False
    error = e.response.get('Error', {})
    if error.get('Code') in ('404', 'NoSuchKey'):
        return True
    if error.get('Message') == 'Not Found':
        return True
    return e.response.get('ResponseMetadata', {}).get('HTTPStatusCode') == 404
+
+
def head_s3_object(path: str, raise_404=False, client=None) -> Union[Dict, None]:
    """HEAD an s3 object; returns the response dict, or None for a missing key
    (unless *raise_404* forces the ClientError to propagate)."""
    s3 = client or get_s3_client(path)
    bucket, key = split_s3_path(path)
    try:
        return s3.head_object(Bucket=bucket, Key=key)
    except ClientError as e:
        if raise_404 or not is_s3_404_error(e):
            raise
        return None
+
+
def _restore_and_wait(client, bucket: str, key: str, path: str):
    """Trigger restoration of a Glacier-archived object and block until readable.

    Polls head_object every 3 seconds. No "Restore" header means the restore
    request still has to be issued; ongoing-request="true" means it is in
    progress; ongoing-request="false" means the restored copy is available.
    """
    while True:
        head = client.head_object(Bucket=bucket, Key=key)
        restore = head.get('Restore', '')
        if not restore:
            # Ask S3 to keep the restored copy for 1 day, standard retrieval tier.
            req = {'Days': 1, 'GlacierJobParameters': {'Tier': 'Standard'}}
            client.restore_object(Bucket=bucket, Key=key, RestoreRequest=req)
            print(f"restoration-started: {path}")
        elif 'ongoing-request="true"' in restore:
            print(f"restoration-ongoing: {path}")
        elif 'ongoing-request="false"' in restore:
            print(f"restoration-complete: {path}")
            break
        time.sleep(3)
+
+
def get_s3_object(path: str, client=None, **kwargs) -> dict:
    """GET an s3 object, transparently restoring Glacier-archived keys first."""
    s3 = client or get_s3_client(path)
    bucket, key = split_s3_path(path)
    try:
        return s3.get_object(Bucket=bucket, Key=key, **kwargs)
    except ClientError as e:
        code = e.response.get('Error', {}).get('Code')
        if code != 'GlacierObjectNotRestore':
            raise
        # Kick off (and wait for) restoration, then retry the read once.
        _restore_and_wait(s3, bucket, key, path)
        return s3.get_object(Bucket=bucket, Key=key, **kwargs)
+
+
def list_s3_objects(path: str, recursive=False, is_prefix=False, limit=0, client=None):
    """Yield only the s3:// paths from list_s3_objects_detailed, dropping metadata."""
    for full_path, _detail in list_s3_objects_detailed(path, recursive, is_prefix, limit, client=client):
        yield full_path
+
+
def list_s3_objects_detailed(path: str, recursive=False, is_prefix=False, limit=0, client=None):
    """Yield (s3://path, metadata) tuples for objects under *path*.

    Non-recursive listings also yield common prefixes (sub-"directories").
    *limit* caps the page size (max 1000) and restricts the listing to a
    single page. Paging uses the v1 list_objects Marker protocol.
    """
    client = client or get_s3_client(path)
    if limit > 1000:
        raise Exception('limit greater than 1000 is not supported.')
    if not path.endswith('/') and not is_prefix:
        path += '/'
    bucket, prefix = split_s3_path(path)
    marker = None
    while True:
        list_kwargs = dict(MaxKeys=1000, Bucket=bucket, Prefix=prefix)
        if limit > 0:
            list_kwargs['MaxKeys'] = limit
        if not recursive:
            # Delimiter makes S3 fold keys into CommonPrefixes per "directory".
            list_kwargs['Delimiter'] = '/'
        if marker:
            list_kwargs['Marker'] = marker
        response = client.list_objects(**list_kwargs)
        marker = None
        if not recursive:
            common_prefixes = response.get('CommonPrefixes', [])
            for cp in common_prefixes:
                yield (f"s3://{bucket}/{cp['Prefix']}", cp)
            if common_prefixes:
                marker = common_prefixes[-1]['Prefix']
        contents = response.get('Contents', [])
        for content in contents:
            # Skip zero-byte "directory" placeholder keys.
            if not content['Key'].endswith('/'):
                yield (f"s3://{bucket}/{content['Key']}", content)
        if contents:
            last_key = contents[-1]['Key']
            # Resume after whichever of (last prefix, last key) sorts later.
            if not marker or last_key > marker:
                marker = last_key
        # A positive limit means single-page; otherwise stop at the final page.
        if limit or not response.get('IsTruncated') or not marker:
            break
+
+
def upload_s3_object(path: str, local_file_path: str, client=None):
    """Multipart-upload *local_file_path* to the s3 *path*."""
    s3 = client or get_s3_client(path)
    bucket, key = split_s3_path(path)
    one_mb = 1024 ** 2
    transfer_config = TransferConfig(
        multipart_threshold=128 * one_mb,
        multipart_chunksize=16 * one_mb,  # 156.25GiB maximum
    )
    s3.upload_file(local_file_path, bucket, key, Config=transfer_config)
diff --git a/llm_web_kit/html_layout_classify/s3/cmd.py b/llm_web_kit/html_layout_classify/s3/cmd.py
new file mode 100644
index 00000000..580a3040
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/s3/cmd.py
@@ -0,0 +1,146 @@
+import ast
+import json
+import re
+from typing import Union
+
+try:
+ import orjson
+except Exception:
+ orjson = None
+
+_surrogates_re = r'[\ud800-\udfff]'
+
+
def json_dumps(d: dict, **kwargs) -> str:
    """Serialize *d* to a JSON string, preferring orjson when no kwargs are given."""
    if orjson is not None and not kwargs:
        try:
            return orjson.dumps(d).decode('utf-8')
        except Exception:
            pass  # fall back to the stdlib encoder below
    return json.dumps(d, ensure_ascii=False, **kwargs)
+
+
def json_loads(s: Union[str, bytes], **kwargs) -> dict:
    """Parse JSON text, rescuing Python-literal dicts via ast.literal_eval."""
    if orjson and not kwargs:
        try:
            return orjson.loads(s)
        except Exception:
            pass
    try:
        return json.loads(s, **kwargs)
    except Exception as e:
        # Only rescue the "single-quoted keys" flavor of malformed JSON.
        if 'enclosed in double quotes' not in str(e):
            raise e
        text = s.decode('utf-8') if isinstance(s, bytes) else str(s)
        return ast.literal_eval(text)
+
+
def json_encode(d: dict, end='\n', **kwargs) -> bytes:
    """Serialize *d* and encode it (plus the *end* terminator) to UTF-8 bytes."""
    text = json_dumps(d, **kwargs) + end
    return str_encode(text)
+
+
def str_encode(s: str) -> bytes:
    """UTF-8 encode *s*, first replacing lone surrogates with U+FFFD."""
    cleaned = re.sub(r'[\ud800-\udfff]', '\ufffd', s)
    try:
        return cleaned.encode('utf-8')
    except UnicodeEncodeError as e:
        # Dump a window around the offending span to aid debugging, then re-raise.
        debug_start = max(0, e.start - 1000)
        debug_end = min(len(cleaned), e.end + 1000)
        print(f"{debug_start=}, {debug_end=}, debug_s={cleaned[debug_start:debug_end]}")
        raise
+
+# def json_print(obj):
+# if isinstance(obj, list) and len(obj):
+# return json_print(obj[0])
+# if isinstance(obj, bytes):
+# return json_print(obj.decode("utf-8"))
+# if isinstance(obj, str):
+# return json_print(json_loads(obj))
+# if isinstance(obj, dict):
+# return print(json_dumps(obj, indent=2))
+#
+# from .row_fallback import Row
+#
+# if isinstance(obj, Row) and "value" in obj:
+# return json_print(obj.value)
+#
+# print(obj)
+#
+# def _format_datetime(dt):
+# if not dt:
+# return ""
+# dt = dt.replace(tzinfo=timezone.utc).astimezone(tz=None) # localtime
+# return dt.strftime("%y-%m-%d %H:%M:%S %Z")
+#
+#
+# def _format_size(size):
+# if size is None:
+# return ""
+# size = str(size)
+# parts = []
+# while len(size):
+# part_size = 3
+# if not parts and len(size) % part_size:
+# part_size = len(size) % part_size
+# parts.append(size[:part_size])
+# size = size[part_size:]
+# return ",".join(parts)
+#
+#
+# def _format_detail(detail):
+# path, obj = detail
+# if path.endswith("/"):
+# return f"{'DIR'.rjust(53)} {path}"
+# tm = _format_datetime(obj.get("LastModified"))
+# sz = _format_size(obj.get("Size") or obj.get("ContentLength", 0))
+# owner = obj.get("Owner", {}).get("ID", "")
+# return f"{tm} {sz.rjust(15)} {owner.rjust(15)} {path}"
+#
+#
+# def head(path):
+# obj_head = head_s3_object_with_retry(path)
+# if obj_head is not None:
+# print(json_dumps(obj_head, indent=2, default=str))
+#
+#
+# def cat(path, limit=1, show_loc=False):
+# if "?bytes=" in path:
+# row = read_s3_row(path)
+# if row is not None:
+# if show_loc:
+# print(row.loc)
+# json_print(row)
+# return
+# for row in read_s3_rows(path, use_stream=True, limit=limit):
+# if show_loc:
+# print(row.loc)
+# json_print(row)
+#
+#
+# def ls(path, limit=100):
+# for obj in list_s3_objects(path, limit=limit):
+# print(obj)
+#
+#
+# def ls_r(path, limit=100):
+# for item in list_s3_objects(path, True, True, limit):
+# print(item)
+#
+#
+# def ll(path, limit=100):
+# for detail in list_s3_objects_detailed(path, limit=limit):
+# print(_format_detail(detail))
+#
+#
+# def ll_r(path, limit=100):
+# for detail in list_s3_objects_detailed(path, True, True, limit):
+# print(_format_detail(detail))
+#
+#
+# def download(path):
+# print(get_s3_presigned_url(path))
diff --git a/llm_web_kit/html_layout_classify/s3/conf.py b/llm_web_kit/html_layout_classify/s3/conf.py
new file mode 100644
index 00000000..02115bd7
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/s3/conf.py
@@ -0,0 +1,92 @@
+import fnmatch
+import random
+import socket
+from typing import List, Union
+
+from .config import s3_bucket_prefixes, s3_buckets, s3_profiles
+from .path import split_s3_path
+
+
def _is_inside_cluster(cluster_config: dict):
    """True when this hostname matches one of the cluster's inside_hosts globs."""
    patterns = cluster_config.get('inside_hosts')
    if not isinstance(patterns, list) or not patterns:
        return False
    try:
        hostname = socket.gethostname().lower()
    except Exception:
        return False
    return any(fnmatch.fnmatch(hostname, str(p).lower()) for p in patterns)
+
+
def _get_s3_bucket_config(path: str):
    """Resolve the credential/cluster config for *path*'s bucket.

    Lookup order: exact bucket name, then bucket-name prefix, then profile
    name, then the catch-all "[default]" bucket entry.
    """
    bucket = split_s3_path(path)[0] if path else ''
    candidate = s3_buckets.get(bucket)
    if not candidate:
        candidate = next(
            (cfg for prefix, cfg in s3_bucket_prefixes.items() if bucket.startswith(prefix)),
            None,
        )
    if not candidate:
        candidate = s3_profiles.get(bucket)
    if not candidate:
        candidate = s3_buckets.get('[default]')
    assert candidate is not None
    return candidate
+
+
def _get_s3_config(
    bucket_config,
    outside: bool,
    prefer_ip=False,
    prefer_auto=False,
):
    """Pick an endpoint from the bucket's cluster config; return endpoint + creds.

    Preference order: explicit *outside* wins; then *prefer_auto*; then
    "inside" when this host is inside the cluster; else "outside". With
    *prefer_ip*, the "<key>_ips" variant is used when present. A list-valued
    endpoint entry is sampled at random to spread load.
    """
    cluster = bucket_config['cluster']
    assert isinstance(cluster, dict)

    if outside:
        key = 'outside'
    elif prefer_auto:
        key = 'auto'
    elif _is_inside_cluster(cluster):
        key = 'inside'
    else:
        key = 'outside'

    if key not in cluster:
        key = 'outside'
    if prefer_ip and f"{key}_ips" in cluster:
        key = f"{key}_ips"

    candidates = cluster[key]
    if isinstance(candidates, str):
        chosen = candidates
    elif isinstance(candidates, list):
        chosen = random.choice(candidates)
    else:
        raise Exception(f"invalid endpoint for [{cluster}]")

    return {
        'endpoint': chosen,
        'ak': bucket_config['ak'],
        'sk': bucket_config['sk'],
    }
+
+
def get_s3_config(path: Union[str, List[str]], outside=False):
    """Return the s3 endpoint/credential config shared by *path* (or a list of paths).

    Raises when the paths resolve to different bucket configs (they cannot be
    read together) or when the path list is empty.
    """
    # isinstance instead of `type(path) == str`: idiomatic and subclass-safe.
    paths = [path] if isinstance(path, str) else path
    bucket_config = None
    for p in paths:
        bc = _get_s3_bucket_config(p)
        if bucket_config not in (bc, None):
            raise Exception(f"{paths} have different s3 config, cannot read together.")
        bucket_config = bc
    if not bucket_config:
        raise Exception('path is empty.')
    return _get_s3_config(bucket_config, outside, prefer_ip=True)
diff --git a/llm_web_kit/html_layout_classify/s3/config.py b/llm_web_kit/html_layout_classify/s3/config.py
new file mode 100644
index 00000000..0499b96e
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/s3/config.py
@@ -0,0 +1,62 @@
+from .reader import read_config
+
# Effective configuration: user config overlaid on ~/.aws profiles (see reader).
config = read_config()

# Per-service cluster maps; empty dicts when the section is absent.
spark_clusters: dict = config.get('spark', {}).get('clusters', {})
kafka_clusters: dict = config.get('kafka', {}).get('clusters', {})
es_clusters: dict = config.get('es', {}).get('clusters', {})
kudu_clusters: dict = config.get('kudu', {}).get('clusters', {})
hive_clusters: dict = config.get('hive', {}).get('clusters', {})

# Raw s3 sections; the underscored ones are normalized into s3_profiles and
# s3_buckets below.
_s3_buckets: dict = config.get('s3', {}).get('buckets', {})
_s3_profiles: dict = config.get('s3', {}).get('profiles', {})
s3_clusters: dict = config.get('s3', {}).get('endpoints', {})

# Cluster names sorted longest-first so the most specific prefix wins when
# matching a profile name to a cluster.
_kc = list(s3_clusters.keys())
_kc.sort(key=lambda s: -len(s))
+
# Normalized profile table: name -> {profile, ak, sk, cluster}.
s3_profiles = {}
for name, profile in _s3_profiles.items():
    assert isinstance(profile, dict)
    ak = profile.get('aws_access_key_id')
    sk = profile.get('aws_secret_access_key')
    # Skip profiles lacking either half of the credential pair.
    # (The previous `if not ak and sk` only skipped when ak was missing AND sk
    # present — an operator-precedence bug that let half-configured profiles
    # through with a None key.)
    if not (ak and sk):
        continue
    c = next((c for c in _kc if name.startswith(c)), None)
    cluster = s3_clusters[c] if c else None
    # A profile with its own endpoint_url forms a one-endpoint cluster.
    if not cluster and profile.get('endpoint_url'):
        cluster = {'outside': profile['endpoint_url']}
    if not cluster:
        continue
    s3_profiles[name] = {
        'profile': name,
        'ak': ak,
        'sk': sk,
        'cluster': cluster,
    }
+
# Bucket lookup tables: exact names in s3_buckets, trailing-star globs (as
# plain prefixes) in s3_bucket_prefixes.
s3_buckets = {}
s3_bucket_prefixes = {}
for bucket, profile_name in _s3_buckets.items():
    profile = s3_profiles.get(profile_name)
    if not profile:
        continue
    if '*' not in bucket:
        s3_buckets[bucket] = profile
        continue
    # Only trailing-star patterns are supported; an interior '*' is ignored.
    bucket_prefix = bucket.rstrip('*')
    if '*' not in bucket_prefix:
        s3_bucket_prefixes[bucket_prefix] = profile
+
# Public surface of this config module.
__all__ = [
    'config',
    's3_buckets',
    's3_bucket_prefixes',
    's3_profiles',
    's3_clusters',
    'spark_clusters',
    'kafka_clusters',
    'es_clusters',
    'kudu_clusters',
    'hive_clusters',
]
diff --git a/llm_web_kit/html_layout_classify/s3/const.py b/llm_web_kit/html_layout_classify/s3/const.py
new file mode 100644
index 00000000..f6cfb124
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/s3/const.py
@@ -0,0 +1,18 @@
# Marker files written alongside dataset outputs to record job status.
SUCCESS_MARK_FILE = '_SUCCESS'
SUCCESS_MARK_FILE2 = '.SUCCESS'

FAILURE_MARK_FILE = '_FAILURE'
RESERVE_MARK_FILE = '_RESERVE'
SUMMARY_MARK_FILE = '_SUMMARY'
DELETED_MARK_FILE = '_DELETED'

# Well-known record field names.
FIELD_ID = 'id'
FIELD_SUB_PATH = 'sub_path'
+
+
def is_flag_field(f: str):
    """True for boolean-style field names ("is_*" / "has_*")."""
    return f.startswith(('is_', 'has_'))
+
+
def is_acc_field(f: str):
    """True for accumulator field names ("acc_*")."""
    return f[:4] == 'acc_'
diff --git a/llm_web_kit/html_layout_classify/s3/path.py b/llm_web_kit/html_layout_classify/s3/path.py
new file mode 100644
index 00000000..8a869517
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/s3/path.py
@@ -0,0 +1,27 @@
+import re
+
+__re_s3_path = re.compile('^s3a?://([^/]+)(?:/(.*))?$')
+
+
def is_s3_path(path: str) -> bool:
    """True for paths using the s3:// or s3a:// scheme."""
    return path.startswith(('s3://', 's3a://'))
+
+
def ensure_s3a_path(path: str) -> str:
    """Rewrite an s3:// path to s3a://; other paths pass through unchanged."""
    prefix = 's3://'
    if path.startswith(prefix):
        return 's3a://' + path[len(prefix):]
    return path
+
+
def ensure_s3_path(path: str) -> str:
    """Rewrite an s3a:// path to s3://; other paths pass through unchanged."""
    prefix = 's3a://'
    if path.startswith(prefix):
        return 's3://' + path[len(prefix):]
    return path
+
+
def split_s3_path(path: str):
    """Split an s3/s3a path into (bucket, key); returns ('', '') for non-s3 paths."""
    match = re.match(r'^s3a?://([^/]+)(?:/(.*))?$', path)
    if not match:
        return '', ''
    bucket, key = match.group(1), match.group(2)
    return bucket, key or ''
diff --git a/llm_web_kit/html_layout_classify/s3/read.py b/llm_web_kit/html_layout_classify/s3/read.py
new file mode 100644
index 00000000..ec637698
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/s3/read.py
@@ -0,0 +1,164 @@
+import io
+import re
+from typing import Tuple, Union
+
+from botocore.exceptions import ClientError
+from botocore.response import StreamingBody
+
+from .client import get_s3_object
+from .read_resume import ResumableS3Stream
+
# "start,length" or "start-end" (inclusive) byte-range spec, e.g. "0,1024" / "0-1023".
__re_bytes = re.compile('^([0-9]+)([,-])([0-9]+)$')
# NOTE(review): __re_bytes_1 appears unused in this module.
__re_bytes_1 = re.compile('^([0-9]+),([0-9]+)$')

# 1 MiB buffer size used when scanning records.
SIZE_1M = 1 << 20
+
+
def read_s3_object_detailed(
    path: str,
    bytes: Union[str, None] = None,
    client=None,
) -> Tuple[StreamingBody, dict]:
    """Open an s3 object (optionally a byte range) and return (body stream, metadata).

    *bytes* accepts "start,length" or "start-end" (end inclusive).

    ### Usage
    ```
    obj = read_object("s3://bkt/path/to/file.txt")
    for line in obj.iter_lines():
        handle(line)
    ```
    """
    kwargs = {}
    if bytes:
        m = re.match('^([0-9]+)([,-])([0-9]+)$', bytes)
        if m is not None:
            start = int(m.group(1))
            end = int(m.group(3))
            if m.group(2) == ',':
                # "start,length" form: convert length to an inclusive end offset.
                end = start + end - 1
            if end >= start:
                kwargs['Range'] = f"bytes={start}-{end}"
            elif start > 0:
                # Empty/negative span: read from start to EOF instead.
                kwargs['Range'] = f"bytes={start}-"

    obj = get_s3_object(path, client=client, **kwargs)
    return obj.pop('Body'), obj
+
+
def read_s3_object_bytes_detailed(path: str, size_limit=0, client=None):
    """Read an entire s3 object into memory (at most *size_limit* bytes), with retries.

    This method caches all content in memory; avoid large files.
    """
    import time

    attempt = 0
    last_e = None
    while True:
        if attempt > 5:
            msg = f"Retry exhausted for reading [{path}]"
            raise Exception(msg) from last_e
        try:
            stream, obj = read_s3_object_detailed(path, client=client)
            with stream:
                buf = stream.read(size_limit if size_limit > 0 else None)
            break
        except ClientError:
            # S3-level errors (404 etc.) are not transient; don't retry them.
            raise
        except Exception as e:
            last_e = e
            attempt += 1
            time.sleep(3)

    assert isinstance(buf, bytes)
    return buf, obj
+
+
def read_s3_object_io_detailed(path: str, size_limit=0, client=None):
    """Read an s3 object fully into a BytesIO plus its metadata.

    This method caches all content in memory; avoid large files.
    """
    # The redundant function-local `import io` was removed: io is already
    # imported at module level.
    buf, obj = read_s3_object_bytes_detailed(path, size_limit, client=client)
    return io.BytesIO(buf), obj
+
+
def read_s3_object_io(path: str, size_limit=0, client=None):
    """BytesIO of the whole object, discarding metadata (in-memory; avoid large files)."""
    stream, _meta = read_s3_object_io_detailed(path, size_limit, client=client)
    return stream
+
+
def buffered_stream(stream, buffer_size: int, **kwargs):
    """Wrap *stream* in warcio's BufferedReader (supports gzip via decomp_type)."""
    from warcio.bufferedreaders import BufferedReader
    return BufferedReader(stream, buffer_size, **kwargs)
+
+
def read_records(path: str, stream: io.IOBase, buffer_size: int):
    """Yield (decoded line, start offset, byte length) records from *stream*.

    Dispatch is by file extension: ".gz" goes through a gzip-aware buffered
    reader (multi-member archives handled), ".bz2" is rejected, anything else
    is read as plain text. Does not handle stream.close().
    """
    offset = 0

    # if path.endswith(".warc") or path.endswith(".warc.gz"):
    #     from .read_warc import read_warc_records
    #
    #     yield from read_warc_records(path, stream)

    if path.endswith('.gz'):
        r = buffered_stream(stream, buffer_size, decomp_type='gzip')
        while True:
            line = r.readline()
            if not line:
                # End of one gzip member; continue if another member follows.
                if r.read_next_member():
                    continue
                break
            # Compressed-stream position of the byte just after this line.
            tell = stream.tell() - r.rem_length()
            yield (line.decode('utf-8'), offset, tell - offset)
            offset = tell

    elif path.endswith('.bz2'):
        raise Exception('bz2 is not supported yet.')

    # elif path.endswith(".7z"):
    #     from .read_7z import SevenZipReadStream
    #
    #     stream1 = SevenZipReadStream(stream)
    #     stream2 = buffered_stream(stream1, buffer_size)
    #
    #     while True:
    #         line = stream2.readline()
    #         if not line:
    #             break
    #         yield (line.decode("utf-8"), int(-1), int(0))

    else:  # plaintext
        stream1 = stream
        if isinstance(stream, ResumableS3Stream):
            # Buffer network streams so readline() doesn't issue tiny reads.
            stream1 = buffered_stream(stream, buffer_size)

        while True:
            line = stream1.readline()
            if not line:
                break
            yield (line.decode('utf-8'), offset, len(line))
            offset += len(line)
+
+
def read_s3_rows(path: str, use_stream=False, limit=0, size_limit=0, client=None):
    """Yield Row(value, loc) for each record of an s3 file.

    With *use_stream* the object is read incrementally (resumable); otherwise
    it is fully buffered in memory first. *limit* > 0 stops after that many
    rows; *size_limit* caps the bytes read from the object.
    """
    from .row_fallback import Row

    if use_stream:
        stream = ResumableS3Stream(path, size_limit, client=client)
    else:
        stream = read_s3_object_io(path, size_limit, client=client)

    with stream:
        emitted = 0
        for value, offset, length in read_records(path, stream, SIZE_1M):
            # A known offset lets the row be re-read directly via "?bytes=".
            loc = f"{path}?bytes={offset},{length}" if offset >= 0 else path
            yield Row(value=value, loc=loc)
            emitted += 1
            if 0 < limit <= emitted:
                break
diff --git a/llm_web_kit/html_layout_classify/s3/read_resume.py b/llm_web_kit/html_layout_classify/s3/read_resume.py
new file mode 100644
index 00000000..94fc172e
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/s3/read_resume.py
@@ -0,0 +1,105 @@
+import io
+import time
+
+from botocore.exceptions import (IncompleteReadError, ReadTimeoutError,
+ ResponseStreamingError)
+from botocore.response import StreamingBody
+
+from .retry import get_s3_object_with_retry, head_s3_object_with_retry
+
+
class _EmptyStream:
    """Null stream substituted once the logical position has reached EOF."""

    def read(self, size):
        """Always report end-of-stream."""
        return b''

    def close(self):
        """Nothing to release."""
        pass
+
+
class ResumableS3Stream(io.IOBase):
    """Read-only s3 stream that transparently reopens after transport errors.

    Tracks the logical position (``pos``) and, on a streaming failure,
    reissues a ranged GET starting at ``pos``. ``size_limit`` > 0 caps the
    number of bytes exposed; ``size`` is the effective stream length.
    """

    def __init__(self, path: str, size_limit=0, client=None):
        self.path = path
        self.size_limit = size_limit
        self.client = client
        # Logical read position; used to resume via an HTTP Range on reopen.
        self.pos = 0
        # Effective stream length; -1 until learned from HEAD/GET.
        self.size = -1
        self.stream = self.new_stream()

    def new_stream(self) -> StreamingBody:
        """Open the underlying GET at the current position (honoring size_limit)."""
        if self.size < 0 and self.size_limit > 0:
            # First open with a limit: HEAD the object to decide whether the
            # limit actually truncates it.
            head = head_s3_object_with_retry(self.path, True, self.client)
            assert head is not None
            if int(head['ContentLength']) <= self.size_limit:
                self.size_limit = 0
                self.size = int(head['ContentLength'])
            else:
                self.size = self.size_limit

        if self.size_limit > 0:
            kwargs = {'Range': f"bytes={self.pos}-{self.size_limit - 1}"}
        else:
            # Only request a range when resuming mid-object.
            kwargs = {'Range': f"bytes={self.pos}-"} if self.pos > 0 else {}

        obj = get_s3_object_with_retry(self.path, client=self.client, **kwargs)

        if self.size < 0:
            self.size = int(obj['ContentLength'])

        return obj['Body']

    def readable(self):
        return True

    def read(self, n=None):
        """Read up to *n* bytes, retrying (with reopen) on streaming errors."""
        if self.pos >= self.size:
            return b''

        retries = 0
        last_e = None
        while True:
            if retries > 5:
                msg = f"Retry exhausted for reading [{self.path}]"
                raise Exception(msg) from last_e
            try:
                data = self.stream.read(n)
                self.pos += len(data)
                return data
            except (ReadTimeoutError, ResponseStreamingError, IncompleteReadError) as e:
                # Transient transport failure: drop the connection and resume
                # from self.pos with a fresh ranged GET.
                try:
                    self.stream.close()
                except Exception:
                    pass
                last_e = e
                retries += 1
                time.sleep(3)
                self.stream = self.new_stream()

    def seekable(self):
        return True

    def seek(self, offset, whence=io.SEEK_SET):
        """Seek by reopening the ranged GET at the new position."""
        if whence == io.SEEK_SET:
            pos = offset
        elif whence == io.SEEK_CUR:
            pos = self.pos + offset
        elif whence == io.SEEK_END:
            pos = self.size + offset
        else:
            raise ValueError('Invalid argument: whence')
        if pos != self.pos:
            self.pos = pos
            try:
                self.stream.close()
            except Exception:
                pass
            if self.pos < self.size:
                self.stream = self.new_stream()
            else:
                # At/past EOF: serve empty reads without holding a connection.
                self.stream = _EmptyStream()
        return pos

    def tell(self):
        return self.pos

    def close(self):
        self.stream.close()
diff --git a/llm_web_kit/html_layout_classify/s3/reader.py b/llm_web_kit/html_layout_classify/s3/reader.py
new file mode 100644
index 00000000..2edf8d6d
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/s3/reader.py
@@ -0,0 +1,127 @@
+"""read config from:
+
+- user config [specific s3 configs]
+- user ~/.aws/ [for ak/sk of s3 clusters]
+- default config [same for all users]
+"""
+
+import configparser
+import json
+import os
+import re
+from pathlib import Path
+
+import yaml
+
# Candidate user config filenames, searched in this order in the home directory.
_USER_CONFIG_FILES = [
    '.xinghe.yaml',
    '.xinghe.yml',
    '.code-clean.yaml',
    '.code-clean.yml',
]
+
+
def _get_home_dir():
    """Home directory, redirected to /share/<user> when running under Spark."""
    spark_user = os.environ.get('SPARK_USER')
    if not spark_user:
        return Path.home()
    return os.path.join('/share', spark_user)  # hard code
+
+
def _read_ini_s3_section(s: str):
    """Parse the "key = value" lines of an embedded s3 sub-section into a dict."""
    pattern = re.compile(r'^\s*([^=\s]+)\s*=\s*(.+?)\s*$')
    parsed = {}
    for line in s.split('\n'):
        m = pattern.match(line)
        if m:
            parsed[m.group(1)] = m.group(2)
    return parsed
+
+
def _read_ini_file(file: str):
    """Read an AWS-style ini file into {section_name: {key: value}}.

    Section names drop a leading "profile " prefix and surrounding quotes.
    A nested "s3" key is parsed as an embedded key=value block and merged
    into the section's dict.
    """
    parser = configparser.ConfigParser()
    parser.read(file)
    result = {}
    for raw_name, section in parser.items():
        name = re.sub(r'^\s*profile\s+', '', raw_name).strip().strip('"')
        target = result.setdefault(name, {})
        for key, val in section.items():
            if key == 's3':
                target.update(_read_ini_s3_section(val))
            else:
                target[key] = val
    return result
+
+
def _merge_config(old: dict, new: dict):
    """Recursively merge *new* over *old* and return a fresh dict.

    Values from *new* win; dict values are merged key-by-key. Unlike the
    previous implementation, neither input is mutated (it used to pop()
    consumed keys out of the caller's *new* dict as a side effect).
    """
    merged = {}
    for key, old_val in old.items():
        if key in new:
            new_val = new[key]
            if isinstance(old_val, dict) and isinstance(new_val, dict):
                merged[key] = _merge_config(old_val, new_val)
            else:
                merged[key] = new_val
        else:
            merged[key] = old_val
    # Keys only present in *new* are appended in their original order.
    for key, new_val in new.items():
        if key not in merged:
            merged[key] = new_val
    return merged
+
+
def _read_s3_config():
    """Load s3 profiles from ~/.aws/config overlaid with ~/.aws/credentials."""
    home = _get_home_dir()
    profiles = {}
    for filename in ('config', 'credentials'):
        ini_path = os.path.join(home, '.aws', filename)
        if os.path.isfile(ini_path):
            profiles = _merge_config(profiles, _read_ini_file(ini_path))
    return {'s3': {'profiles': profiles}}
+
+
def _read_user_config():
    """Load the first user config file (.xinghe.yaml etc.) found in the home dir."""
    home = _get_home_dir()
    candidates = (os.path.join(home, name) for name in _USER_CONFIG_FILES)
    conf_file = next((p for p in candidates if os.path.isfile(p)), None)
    if conf_file is None:
        return {}
    with open(conf_file, 'r') as f:
        config = yaml.safe_load(f)
    assert isinstance(config, dict)
    return config
+
+
def get_conf_dir():
    """Resolve the config directory: $XINGHE_CONF_DIR, else $BASE_DIR/conf."""
    conf_dir = os.getenv('XINGHE_CONF_DIR')
    if conf_dir:
        return conf_dir
    if os.getenv('BASE_DIR'):
        return os.path.join(os.environ['BASE_DIR'], 'conf')
    raise Exception('XINGHE_CONF_DIR not set.')
+
+
def _read_default_config():
    """Load <conf_dir>/config.yaml; raise when the file is missing."""
    conf_file = os.path.join(get_conf_dir(), 'config.yaml')
    if not os.path.isfile(conf_file):
        raise Exception(f"config file [{conf_file}] not found.")
    with open(conf_file, 'r') as f:
        loaded = yaml.safe_load(f)
    assert isinstance(loaded, dict)
    return loaded
+
+
def read_config():
    """Compose the effective config: ~/.aws s3 profiles overlaid with the user config.

    (Loading the shared default config is currently disabled.)
    """
    # config = _read_default_config()
    config = {}
    for layer in (_read_s3_config(), _read_user_config()):
        config = _merge_config(config, layer)
    return config


if __name__ == '__main__':
    c = read_config()
    print(json.dumps(c, indent=2))
diff --git a/llm_web_kit/html_layout_classify/s3/retry.py b/llm_web_kit/html_layout_classify/s3/retry.py
new file mode 100644
index 00000000..3458f319
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/s3/retry.py
@@ -0,0 +1,39 @@
+from botocore.exceptions import ClientError
+
+from .client import get_s3_object, head_s3_object, upload_s3_object
+from .retry_utils import with_retry
+
+
@with_retry
def _get_s3_object_or_ex(path: str, client, **kwargs):
    """get_s3_object wrapped for retry; ClientErrors are returned, not raised.

    Returning the ClientError keeps with_retry from re-trying S3-level
    failures (404 etc.) that can never succeed.
    """
    try:
        return get_s3_object(path, client=client, **kwargs)
    except ClientError as e:
        return e


def get_s3_object_with_retry(path: str, client=None, **kwargs):
    """get_s3_object with transient-failure retries; re-raises ClientError."""
    result = _get_s3_object_or_ex(path, client, **kwargs)
    if isinstance(result, ClientError):
        raise result
    return result
+
+
@with_retry
def _head_s3_object_or_ex(path: str, raise_404: bool, client):
    """head_s3_object wrapped for retry; ClientErrors are returned, not raised."""
    try:
        return head_s3_object(path, raise_404, client=client)
    except ClientError as e:
        return e


def head_s3_object_with_retry(path: str, raise_404=False, client=None):
    """head_s3_object with transient-failure retries; re-raises ClientError."""
    result = _head_s3_object_or_ex(path, raise_404, client)
    if isinstance(result, ClientError):
        raise result
    return result
+
+
@with_retry(sleep_time=180)
def upload_s3_object_with_retry(path: str, local_file_path: str, client=None):
    """upload_s3_object with retries; 180s between attempts since uploads are slow to redo."""
    upload_s3_object(path, local_file_path, client=client)
diff --git a/llm_web_kit/html_layout_classify/s3/retry_utils.py b/llm_web_kit/html_layout_classify/s3/retry_utils.py
new file mode 100644
index 00000000..bf52ed33
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/s3/retry_utils.py
@@ -0,0 +1,45 @@
+import functools
+import time
+
+
def get_func_path(func) -> str:
    """Dotted "module.name" for a callable; non-callables pass through unchanged."""
    if callable(func):
        return f"{func.__module__}.{func.__name__}"
    return func
+
+
def with_retry(func=None, max_retries=5, sleep_time=3):
    """Decorator that retries the wrapped callable up to *max_retries* extra times.

    Usable bare (``@with_retry``) or with arguments
    (``@with_retry(max_retries=2, sleep_time=0)``). Sleeps *sleep_time*
    seconds after each failure; once attempts are exhausted, the last
    exception is chained onto a summary error.
    """

    def _sleep_quietly():
        # An interrupted sleep must not mask the real failure.
        try:
            time.sleep(sleep_time)
        except Exception:
            pass

    def _exhausted_msg(wrapped, args, kwargs):
        msg = f"Retry exhausted for [{get_func_path(wrapped)}]"
        if args:
            msg += f", args={args}"
        if kwargs:
            msg += f", kwargs={kwargs}"
        return msg

    def decorate(wrapped):
        @functools.wraps(wrapped)
        def wrapper(*args, **kwargs):
            failures = 0
            last_e = None
            # max_retries + 1 total attempts, matching the original counter logic.
            while failures <= max_retries:
                try:
                    return wrapped(*args, **kwargs)
                except Exception as e:
                    failures += 1
                    last_e = e
                    _sleep_quietly()
            raise Exception(_exhausted_msg(wrapped, args, kwargs)) from last_e

        return wrapper

    return decorate(func) if func is not None else decorate
diff --git a/llm_web_kit/html_layout_classify/s3/row_fallback.py b/llm_web_kit/html_layout_classify/s3/row_fallback.py
new file mode 100644
index 00000000..25728e83
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/s3/row_fallback.py
@@ -0,0 +1,191 @@
+# copied from pyspark/sql/types.py
+
+from typing import Any, Dict, List, Optional, Tuple, Union, overload
+
+
def _create_row(fields: Union['Row', List[str]], values: Union[Tuple[Any, ...], List[Any]]) -> 'Row':
    """Build a Row holding *values*, labelled with *fields*.

    Shared by ``Row.__call__`` (instantiating a Row "template") and
    ``Row.__reduce__`` (unpickling).
    """
    row = Row(*values)
    row.__fields__ = fields
    return row


class Row(tuple):
    """A row in :class:`DataFrame`. The fields in it can be accessed:

    * like attributes (``row.key``)
    * like dictionary values (``row[key]``)

    ``key in row`` will search through row keys.

    Row can be used to create a row object by using named arguments.
    It is not allowed to omit a named argument to represent that the value is
    None or missing. This should be explicitly set to None in this case.

    .. versionchanged:: 3.0.0
        Rows created from named arguments no longer have
        field names sorted alphabetically and will be ordered in the position as
        entered.

    Examples
    --------
    >>> row = Row(name="Alice", age=11)
    >>> row
    Row(name='Alice', age=11)
    >>> row['name'], row['age']
    ('Alice', 11)
    >>> row.name, row.age
    ('Alice', 11)
    >>> 'name' in row
    True
    >>> 'wrong_key' in row
    False

    Row also can be used to create another Row like class, then it
    could be used to create Row objects, such as

    >>> Person = Row("name", "age")
    >>> Person
    <Row('name', 'age')>
    >>> 'name' in Person
    True
    >>> 'wrong_key' in Person
    False
    >>> Person("Alice", 11)
    Row(name='Alice', age=11)

    This form can also be used to create rows as tuple values, i.e. with unnamed
    fields.

    >>> row1 = Row("Alice", 11)
    >>> row2 = Row(name="Alice", age=11)
    >>> row1 == row2
    True
    """

    @overload
    def __new__(cls, *args: str) -> 'Row':
        ...

    @overload
    def __new__(cls, **kwargs: Any) -> 'Row':
        ...

    def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> 'Row':
        if args and kwargs:
            raise ValueError('Can not use both args and kwargs to create Row')
        if kwargs:
            # named form: the values become the tuple, the names go to
            # __fields__ in insertion order
            row = tuple.__new__(cls, list(kwargs.values()))
            row.__fields__ = list(kwargs.keys())
            return row
        else:
            # positional form: a plain tuple — a Row "template" when the
            # arguments are field-name strings (no __fields__ attribute)
            return tuple.__new__(cls, args)

    def asDict(self, recursive: bool = False) -> Dict[str, Any]:
        """Return as a dict.

        Parameters
        ----------
        recursive : bool, optional
            turns the nested Rows to dict (default: False).

        Notes
        -----
        If a row contains duplicate field names, e.g., the rows of a join
        between two :class:`DataFrame` that both have the fields of same names,
        one of the duplicate fields will be selected by ``asDict``. ``__getitem__``
        will also return one of the duplicate fields, however returned value might
        be different to ``asDict``.

        Examples
        --------
        >>> Row(name="Alice", age=11).asDict() == {'name': 'Alice', 'age': 11}
        True
        >>> row = Row(key=1, value=Row(name='a', age=2))
        >>> row.asDict() == {'key': 1, 'value': Row(name='a', age=2)}
        True
        >>> row.asDict(True) == {'key': 1, 'value': {'name': 'a', 'age': 2}}
        True
        """
        if not hasattr(self, '__fields__'):
            raise TypeError('Cannot convert a Row class into dict')

        if recursive:

            def conv(obj: Any) -> Any:
                # recurse through Rows, lists and dicts; leave scalars alone
                if isinstance(obj, Row):
                    return obj.asDict(True)
                elif isinstance(obj, list):
                    return [conv(o) for o in obj]
                elif isinstance(obj, dict):
                    return dict((k, conv(v)) for k, v in obj.items())
                else:
                    return obj

            return dict(zip(self.__fields__, (conv(o) for o in self)))
        else:
            return dict(zip(self.__fields__, self))

    def __contains__(self, item: Any) -> bool:
        # membership tests field NAMES when this is a named row,
        # tuple VALUES otherwise (Row template / unnamed row)
        if hasattr(self, '__fields__'):
            return item in self.__fields__
        else:
            return super(Row, self).__contains__(item)

    # let object acts like class
    def __call__(self, *args: Any) -> 'Row':
        """create new Row object."""
        if len(args) > len(self):
            raise ValueError(
                'Can not create Row with fields %s, expected %d values but got %s' % (self, len(self), args))
        return _create_row(self, args)

    def __getitem__(self, item: Any) -> Any:
        if isinstance(item, (int, slice)):
            return super(Row, self).__getitem__(item)
        try:
            # it will be slow when it has many fields,
            # but this will not be used in normal cases
            idx = self.__fields__.index(item)
            return super(Row, self).__getitem__(idx)
        except IndexError:
            raise KeyError(item)
        except ValueError:
            raise ValueError(item)

    def __getattr__(self, item: str) -> Any:
        # reject dunder lookups early so hasattr(self, '__fields__') works
        if item.startswith('__'):
            raise AttributeError(item)
        try:
            # it will be slow when it has many fields,
            # but this will not be used in normal cases
            idx = self.__fields__.index(item)
            return self[idx]
        except IndexError:
            raise AttributeError(item)
        except ValueError:
            raise AttributeError(item)

    def __setattr__(self, key: Any, value: Any) -> None:
        # Rows are immutable except for attaching field names
        if key != '__fields__':
            raise RuntimeError('Row is read-only')
        self.__dict__[key] = value

    def __reduce__(
        self,
    ) -> Union[str, Tuple[Any, ...]]:
        """Returns a tuple so Python knows how to pickle Row."""
        if hasattr(self, '__fields__'):
            return (_create_row, (self.__fields__, tuple(self)))
        else:
            return tuple.__reduce__(self)

    def __repr__(self) -> str:
        """Printable representation of Row used in Python REPL."""
        if hasattr(self, '__fields__'):
            return 'Row(%s)' % ', '.join('%s=%r' % (k, v) for k, v in zip(self.__fields__, tuple(self)))
        else:
            # Bug fix: the format string here was '' (the '<Row(%s)>' literal
            # was lost, evidently to markup stripping), so repr() of a
            # field-less Row raised "TypeError: not all arguments converted".
            # Restored to match upstream pyspark.sql.types.Row.
            return '<Row(%s)>' % ', '.join('%r' % field for field in self)
diff --git a/llm_web_kit/html_layout_classify/s3/utils.py b/llm_web_kit/html_layout_classify/s3/utils.py
new file mode 100644
index 00000000..d71af4f0
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/s3/utils.py
@@ -0,0 +1,34 @@
"""Aggregated public API for the s3 helper package.

Re-exports the path/config helpers, client operations, readers, retrying
wrappers and ``S3DocWriter`` so callers can import everything from one place.
"""
from .client import (get_s3_client, get_s3_object, head_s3_object,
                     is_s3_404_error, list_s3_objects,
                     list_s3_objects_detailed, upload_s3_object)
from .conf import get_s3_config
from .path import ensure_s3_path, ensure_s3a_path, is_s3_path, split_s3_path
from .read import (read_s3_object_bytes_detailed, read_s3_object_detailed,
                   read_s3_object_io, read_s3_object_io_detailed, read_s3_rows)
from .retry import (get_s3_object_with_retry, head_s3_object_with_retry,
                    upload_s3_object_with_retry)
from .write import S3DocWriter

# Explicit public surface — keep in sync with the imports above.
__all__ = [
    'is_s3_path',
    'ensure_s3a_path',
    'ensure_s3_path',
    'split_s3_path',
    'get_s3_config',
    'get_s3_client',
    'head_s3_object',
    'get_s3_object',
    'upload_s3_object',
    'list_s3_objects',
    'list_s3_objects_detailed',
    'is_s3_404_error',
    'read_s3_object_detailed',
    'read_s3_object_bytes_detailed',
    'read_s3_object_io',
    'read_s3_object_io_detailed',
    'read_s3_rows',
    'get_s3_object_with_retry',
    'head_s3_object_with_retry',
    'upload_s3_object_with_retry',
    'S3DocWriter',
]
diff --git a/llm_web_kit/html_layout_classify/s3/write.py b/llm_web_kit/html_layout_classify/s3/write.py
new file mode 100644
index 00000000..52094116
--- /dev/null
+++ b/llm_web_kit/html_layout_classify/s3/write.py
@@ -0,0 +1,136 @@
+import bz2
+import gzip
+import io
+import os
+import uuid
+
+from .cmd import json_encode
+from .const import FIELD_ID
+from .retry import upload_s3_object_with_retry
+
# Canonical compression names: user-facing aliases ('gzip', 'bzip', 'none',
# ...) collapse to 'gz', 'bz2' or 'raw'.  Unknown keys (including the ''
# default) yield None via dict.get, which S3DocWriter treats as "infer the
# compression from the path suffix".
_compressions = {
    'gz': 'gz',
    'gzip': 'gz',
    'bz2': 'bz2',
    'bzip': 'bz2',
    'bzip2': 'bz2',
    'raw': 'raw',
    'none': 'raw',
}
+
+
def s3_upload_tmp_dir():
    """Return a writable local directory used to stage files before S3 upload.

    Tries ``/tmp/s3_upload`` first, proving writability by creating and
    removing a probe file.  If that fails (read-only or missing /tmp,
    permission problems), falls back to ``<tempdir>/s3_upload`` under the
    platform temp directory.

    Returns:
        str: path of an existing, writable directory.
    """
    import tempfile

    tmp_dir = '/tmp/s3_upload'
    try:
        os.makedirs(tmp_dir, exist_ok=True)
        # probe for write access; remove the probe afterwards, best effort
        probe = os.path.join(tmp_dir, '__test_file')
        try:
            open(probe, 'a').close()
        finally:
            try:
                os.remove(probe)
            except Exception:
                pass
    except Exception:
        # Bug fix: the original fallback rebuilt os.path.join('/tmp',
        # 's3_upload') — the exact same path that just failed — so it could
        # only fail again.  Use the platform temp dir instead.
        tmp_dir = os.path.join(tempfile.gettempdir(), 's3_upload')
        os.makedirs(tmp_dir, exist_ok=True)
    return tmp_dir
+
+
class S3DocWriter:
    """Stages JSON-encoded documents in a local temp file and uploads it to S3.

    Typical usage is ``with S3DocWriter(path) as w: w.write(doc)`` — the
    upload happens in :meth:`flush`, which ``__exit__`` calls.

    Each written document may get a ``doc_loc`` field of the form
    ``<s3 path>?bytes=<offset>,<length>`` recording where its encoded bytes
    sit inside the target object, so readers can random-access it later.
    """

    def __init__(
        self,
        path: str,
        client=None,
        tmp_dir=None,
        skip_loc=False,
        compression='',
    ) -> None:
        """Create a writer targeting the S3 object at ``path``.

        Args:
            path: destination ``s3://`` key; a '.gz'/'.bz2' suffix implies
                compression when ``compression`` is not given explicitly.
            client: optional S3 client passed through to the upload helper.
            tmp_dir: local staging directory; defaults to s3_upload_tmp_dir().
            skip_loc: when True, never touch ``doc_loc``/``track_loc``.
            compression: one of '', 'gz'/'gzip', 'bz2'/'bzip'/'bzip2',
                'raw'/'none'.  NOTE(review): 'raw'/'none' maps to 'raw' and
                then requires a '.raw' path suffix — confirm that is intended.

        Raises:
            Exception: if ``path`` is not an s3 path, or its suffix does not
                match the requested compression.
        """
        if not path.startswith('s3://'):
            raise Exception(f'invalid s3 path [{path}].')

        compression = _compressions.get(compression)
        if compression and not path.endswith(f'.{compression}'):
            raise Exception(f'path must endswith [.{compression}]')
        # no explicit compression: infer it from the path suffix
        if not compression and path.endswith('.gz'):
            compression = 'gz'
        if not compression and path.endswith('.bz2'):
            compression = 'bz2'

        self.path = path
        self.client = client
        self.skip_loc = skip_loc
        self.compression = compression

        if not tmp_dir:
            tmp_dir = s3_upload_tmp_dir()
        os.makedirs(tmp_dir, exist_ok=True)

        # stage into a uniquely-named temp file sharing the target's extension
        ext = self.__get_ext(path)
        self.tmp_file = os.path.join(tmp_dir, f'{str(uuid.uuid4())}.{ext}')
        self.tmp_fh = open(self.tmp_file, 'ab')
        self.offset = 0  # bytes written so far; basis for doc_loc offsets

    def __enter__(self):
        return self

    def __exit__(self, type, value, tb):
        self.flush()

    @staticmethod
    def __get_ext(path: str):
        """Return the last extension of ``path``'s basename, or 'txt'."""
        filename = os.path.basename(path)
        parts = filename.split('.')
        if len(parts) > 1 and parts[0]:
            return parts[-1]
        return 'txt'

    def write(self, d: dict):
        """Append one document and return the number of bytes written.

        For raw output, ``doc_loc`` is set to the record's exact byte span
        via a fixed-point loop (writing the length may change the encoded
        length).  For gz/bz2 the compressed size cannot be known before
        compressing, so the length part is recorded as 0.
        """
        d = d.copy()  # never mutate the caller's dict

        # preserve the previous location history in track_loc
        if not self.skip_loc and 'doc_loc' in d:
            track_loc = d.get('track_loc') or []
            track_loc.append(d['doc_loc'])
            d['track_loc'] = track_loc

        if self.compression == 'gz':
            if not self.skip_loc and FIELD_ID in d:
                d['doc_loc'] = f'{self.path}?bytes={self.offset},0'
            buf = io.BytesIO()
            with gzip.GzipFile(fileobj=buf, mode='wb') as f:
                f.write(json_encode(d))
            doc_bytes = buf.getvalue()

        elif self.compression == 'bz2':
            if not self.skip_loc and FIELD_ID in d:
                d['doc_loc'] = f'{self.path}?bytes={self.offset},0'
            buf = io.BytesIO()
            with bz2.BZ2File(buf, mode='wb') as f:
                f.write(json_encode(d))
            doc_bytes = buf.getvalue()

        else:
            doc_bytes = json_encode(d)
            # add doc_loc if doc has id.  Bug fix: this fixed-point block
            # previously ran for compressed output as well, replacing the
            # compressed doc_bytes with plain JSON and corrupting the
            # resulting .gz/.bz2 object; it now applies to raw output only.
            if not self.skip_loc and FIELD_ID in d:
                doc_len, last_len = len(doc_bytes), 0
                while doc_len != last_len:
                    d['doc_loc'] = f'{self.path}?bytes={self.offset},{doc_len}'
                    doc_bytes = json_encode(d)
                    doc_len, last_len = len(doc_bytes), doc_len

        self.tmp_fh.write(doc_bytes)
        self.offset += len(doc_bytes)

        return len(doc_bytes)

    def flush(self):
        """Close the temp file, upload it to ``self.path``, then delete it.

        The temp file is removed even when the upload fails.
        """
        try:
            self.tmp_fh.close()
            upload_s3_object_with_retry(self.path, self.tmp_file, self.client)
        finally:
            try:
                os.remove(self.tmp_file)
            except Exception:
                pass
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index cec5055c..6ff7ee03 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -1,3 +1,4 @@
+beautifulsoup4>=4.12.2
boto3==1.28.43
cairosvg==2.7.1
click==8.1.8