-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathsplit.py
More file actions
31 lines (25 loc) · 903 Bytes
/
split.py
File metadata and controls
31 lines (25 loc) · 903 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import fire
import os
import pandas as pd
def split(input_path, output_path, cuda_list, num_samples: int = -1):
    """Split a CSV file into contiguous, near-equal shards, one per device id.

    The CSV at ``input_path`` is read with pandas, optionally truncated to its
    first ``num_samples`` rows, and cut into ``len(cuda_list)`` consecutive
    slices. Slice ``i`` is written to ``<output_path>/<cuda_list[i]>.csv``
    with the DataFrame index included as the first column.

    Args:
        input_path: Path of the CSV file to read.
        output_path: Directory that receives the shard files; created
            (including parents) when it does not exist yet.
        cuda_list: Device identifiers used as shard file names. May be a
            single int, an iterable of identifiers, or a comma-separated
            string such as ``"0,1,2"``.
        num_samples: When ``>= 0``, only the first ``num_samples`` rows are
            kept before splitting. ``None`` or a negative value keeps all
            rows.

    Returns:
        None. The shards are written to disk as a side effect. An empty
        ``cuda_list`` writes no files.
    """
    # Normalise cuda_list into a list of shard names.
    if isinstance(cuda_list, int):
        cuda_list = [cuda_list]
    elif isinstance(cuda_list, str):
        # list("0,1") would yield ['0', ',', '1']; parse the ids instead.
        cuda_list = [part.strip() for part in cuda_list.split(",") if part.strip()]
    else:
        cuda_list = list(cuda_list)

    df = pd.read_csv(input_path)

    # Deterministically take the first num_samples rows when requested.
    if num_samples is not None:
        num_samples = int(num_samples)
        if num_samples >= 0:
            # head(0) yields an empty frame that still carries the columns.
            df = df.head(num_samples)

    # exist_ok avoids the check-then-create race of os.path.exists().
    os.makedirs(output_path, exist_ok=True)

    df_len = len(df)
    cuda_num = len(cuda_list)
    for i, device in enumerate(cuda_list):
        # Integer arithmetic keeps the shards contiguous and non-overlapping,
        # with sizes that differ by at most one row.
        start = i * df_len // cuda_num
        end = (i + 1) * df_len // cuda_num
        shard_path = os.path.join(output_path, f"{device}.csv")
        df.iloc[start:end].to_csv(shard_path, index=True)
if __name__ == "__main__":
    # Script entry point: hand `split` to Python Fire so its parameters
    # become command-line arguments.
    fire.Fire(component=split)