Commit f8c2231

committed: update code format
Signed-off-by: guangli.bao <guangli.bao@daocloud.io>
1 parent 5d68449 commit f8c2231

2 files changed: +45 -49 lines changed
Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
 #!/bin/bash

 wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
-python3 shareGPT_data_preprocessing.py --parse $1
+python3 sharegpt_data_preprocessing.py --parse $1
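
The wrapper script forwards its single positional argument to the preprocessing step as --parse (the fraction of the dataset to process). Below is a minimal sketch of invoking the renamed script directly from Python, assuming it and the downloaded ShareGPT_V3_unfiltered_cleaned_split.json sit in the working directory; the 0.1 fraction is only an illustrative value.

import subprocess

# Run the preprocessing script on 10% of the conversations; equivalent to
# passing 0.1 as the wrapper script's positional argument.
subprocess.run(
    ["python3", "sharegpt_data_preprocessing.py", "--parse", "0.1"],
    check=True,
)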
Lines changed: 44 additions & 48 deletions

@@ -1,47 +1,54 @@
-# SPDX-License-Identifier: Apache-2.0
-# Standard
 import argparse
 import json
 import os
+import re
+from pathlib import Path

-# Third Party
-from transformers import AutoTokenizer
 import numpy as np
-
 from datasets import load_dataset
-import re
+from transformers import AutoTokenizer

-def extract_and_save_with_filtering():
+MIN_CHAR = 10
+MAX_CHAR = 1000
+
+
+def extract_and_save_with_filtering(file):
     """extract human prompts and apply filtering conditions"""
-
-    dataset = load_dataset('json', data_files='./ShareGPT.json', split='train')
-
+    dataset = load_dataset("json", data_files=file, split="train")
     filtered_prompts = []
-
+
     for example in dataset:
-        conversations = example.get('conversations', [])
-
+        conversations = example.get("conversations", [])
         if isinstance(conversations, list):
             for turn in conversations:
-                if turn.get('from') in ['human', 'user']:
-                    prompt_text = turn['value'].strip()
-
-                    # apply filtering conditions
-                    if (len(prompt_text) >= 10 and  # at least 10 characters
-                        len(prompt_text) <= 1000 and  # at most 1000 characters
-                        not prompt_text.startswith(('http://', 'https://')) and  # exclude URLs
-                        not re.search(r'[<>{}[\]\\]', prompt_text) and  # exclude special characters
-                        not prompt_text.isdigit()):  # exclude pure numbers
-
-                        filtered_prompts.append({
-                            'from': turn.get('from'),
-                            'text': prompt_text,
-                            'char_count': len(prompt_text),
-                            'word_count': len(prompt_text.split())
-                        })
-
+                if turn.get("from") in ["human", "user"]:
+                    prompt_text = turn["value"].strip()
+                    # apply filter conditions: at least 10 characters
+                    if (
+                        len(prompt_text) >= MIN_CHAR
+                        and
+                        # at most 1000 characters
+                        len(prompt_text) <= MAX_CHAR
+                        and
+                        # exclude URLs
+                        not prompt_text.startswith(("http://", "https://"))
+                        and
+                        # exclude special characters
+                        not re.search(r"[<>{}[\]\\]", prompt_text)
+                        and not prompt_text.isdigit()
+                    ):  # exclude pure numbers
+                        filtered_prompts.append(
+                            {
+                                "from": turn.get("from"),
+                                "text": prompt_text,
+                                "char_count": len(prompt_text),
+                                "word_count": len(prompt_text.split()),
+                            }
+                        )
+
     return filtered_prompts
-
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Process data percentage.")
     parser.add_argument(
@@ -50,13 +57,12 @@ def extract_and_save_with_filtering():
         default=1,
         help="The percentage of data to process (0 to 1). Default is 1 (100%).",
     )
-
     args = parser.parse_args()

-    with open("ShareGPT_V3_unfiltered_cleaned_split.json", "r", encoding="utf-8") as file:
+    sharegpt_file = "ShareGPT_V3_unfiltered_cleaned_split.json"
+    with Path(sharegpt_file).open("r", encoding="utf-8") as file:
        data = json.load(file)

-
    def estimate_num_tokens(text: str) -> int:
        if not hasattr(estimate_num_tokens, "tokenizer"):
            os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -65,15 +71,10 @@ def estimate_num_tokens(text: str) -> int:
        )
        return len(estimate_num_tokens.tokenizer.tokenize(text))

-
    num_of_ids = len(data)
-    print(f"Number of IDs: {num_of_ids}")
    data = data[: int(num_of_ids * args.parse)]
-
-    count = 0
-
    for d in data:
-        d["num_round"] = len(d["conversations"])  # human is one round, gpt is another round
+        d["num_round"] = len(d["conversations"])
        human_tokens = []
        gpt_tokens = []
        for conv in d["conversations"]:
@@ -96,15 +97,10 @@ def estimate_num_tokens(text: str) -> int:
        d["average_gpt_token"] = float(np.mean(gpt_tokens))
        d["max_gpt_token"] = float(np.max(gpt_tokens))

-        count += 1
-        print(f"Finished {count}")
-
    # save the unfiltered dataset to ShareGPT.json
-    with open("ShareGPT.json", "w", encoding="utf-8") as file:
+    with Path("ShareGPT.json").open("w", encoding="utf-8") as file:
        json.dump(data, file, ensure_ascii=False, indent=2)
    # filter human prompts and save again
-    filtered_result = extract_and_save_with_filtering()
-    with open("ShareGPT.json", "w", encoding="utf-8") as file:
+    filtered_result = extract_and_save_with_filtering("ShareGPT.json")
+    with Path("ShareGPT.json").open("w", encoding="utf-8") as file:
        json.dump(filtered_result, file, ensure_ascii=False, indent=2)
-
-

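For reference, the filtering rules introduced in extract_and_save_with_filtering can be exercised on their own. The sketch below restates the same conditions as a standalone predicate; the helper name keeps_prompt and the sample prompts are illustrative and not part of the repository.

import re

MIN_CHAR = 10
MAX_CHAR = 1000


def keeps_prompt(prompt_text: str) -> bool:
    """Mirror the filter conditions applied to each human/user turn."""
    prompt_text = prompt_text.strip()
    return (
        MIN_CHAR <= len(prompt_text) <= MAX_CHAR
        and not prompt_text.startswith(("http://", "https://"))
        and not re.search(r"[<>{}[\]\\]", prompt_text)
        and not prompt_text.isdigit()
    )


# Hypothetical prompts, one per rule.
samples = [
    "Explain the difference between TCP and UDP.",  # kept
    "hi",                                           # dropped: fewer than 10 characters
    "https://example.com/some/page",                # dropped: starts with a URL
    "Use <b>bold</b> text here",                    # dropped: special characters
    "1234567890",                                   # dropped: digits only
]
for s in samples:
    print(keeps_prompt(s), s)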