-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsplit.py
More file actions
21 lines (19 loc) · 926 Bytes
/
split.py
File metadata and controls
21 lines (19 loc) · 926 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from sklearn.model_selection import train_test_split
import os
def split_data(data_dir, test_size=0.2, random_state=42):
    """Randomly split the entries of *data_dir* into train and test name lists.

    The entry names (not full paths) of each split are written one per line
    to ``train.txt`` and ``test.txt`` inside *data_dir*, overwriting any
    existing files with those names.

    Args:
        data_dir: Directory whose entries are split.
        test_size: Forwarded to ``sklearn.model_selection.train_test_split``;
            a float in (0, 1) is the fraction of entries placed in the test
            split, an int is the absolute number of test entries.
        random_state: Seed forwarded to ``train_test_split`` so the split is
            reproducible. Defaults to 42, the value previously hard-coded.

    Returns:
        A ``(train_files, test_files)`` tuple of lists of entry names.
    """
    train_name, test_name = "train.txt", "test.txt"

    # Exclude the split manifests this function writes: otherwise re-running
    # it would treat the previous run's train.txt/test.txt as data samples.
    # Sort because os.listdir order is arbitrary and filesystem-dependent;
    # without it the same seed can produce different splits on different
    # machines.
    all_files = sorted(
        name
        for name in os.listdir(data_dir)
        if name not in (train_name, test_name)
    )

    train_files, test_files = train_test_split(
        all_files, test_size=test_size, random_state=random_state
    )

    for list_name, names in ((train_name, train_files), (test_name, test_files)):
        with open(os.path.join(data_dir, list_name), "w", encoding="utf-8") as f:
            f.writelines(f"{name}\n" for name in names)

    print(f"Train files: {len(train_files)}, Test files: {len(test_files)}")
    return train_files, test_files
if __name__ == "__main__":
    # Script entry point: split a dataset directory into train.txt / test.txt.
    # Both arguments are optional; their defaults reproduce the previously
    # hard-coded behavior.
    import argparse

    parser = argparse.ArgumentParser(
        description="Randomly split a directory's entries into train.txt and test.txt."
    )
    parser.add_argument(
        "data_dir",
        nargs="?",
        default="/home/waleed/Documents/3DLearning/margin-line/final_02/context_margin_colors_faces_classes",
        help="directory whose entries are split (default: the original dataset path)",
    )
    parser.add_argument(
        "--test-size",
        type=float,
        default=0.2,
        help="fraction of entries placed in the test split (default: 0.2)",
    )
    args = parser.parse_args()

    train_files, test_files = split_data(args.data_dir, test_size=args.test_size)
    print("Train files:", train_files)
    print("Test files:", test_files)