comparison preprocessing.py @ 0:252fd085940d draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools commit 67e0e1d123bcfffb10bab8cc04ae67259caec557
author bgruening
date Fri, 13 Jun 2025 11:23:35 +0000
parents
children 97bc82ee2a61
comparison
equal deleted inserted replaced
-1:000000000000 0:252fd085940d
1 import argparse
2 import os
3 import shutil
4
5 from sklearn.model_selection import train_test_split
6
7
def get_basename(f):
    """Return the last path component of *f* with its final extension removed.

    Only the last extension is dropped, so ``archive.tar.gz`` yields
    ``archive.tar``.
    """
    file_name = os.path.basename(f)
    stem, _extension = os.path.splitext(file_name)
    return stem
10
11
def pair_files(images_dir, labels_dir):
    """Pair image and label files that share the same base name.

    Args:
        images_dir: Directory containing the image files.
        labels_dir: Directory containing the label files.

    Returns:
        A list of ``(image_file_name, label_file_name)`` tuples, one per base
        name present in both directories, ordered by base name. The entries
        are bare file names, not full paths.
    """

    def index_by_stem(directory):
        # os.listdir returns entries in arbitrary order; sorting makes the
        # winner of a stem collision (e.g. "a.jpg" vs "a.png") the same on
        # every run. Sub-directories are skipped because the pairs are later
        # handed to shutil.copy, which only copies regular files.
        return {
            get_basename(name): name
            for name in sorted(os.listdir(directory))
            if os.path.isfile(os.path.join(directory, name))
        }

    image_dict = index_by_stem(images_dir)
    label_dict = index_by_stem(labels_dir)

    # Keep only the base names that have both an image and a label.
    keys = sorted(image_dict.keys() & label_dict.keys())

    return [(image_dict[k], label_dict[k]) for k in keys]
23
24
def copy_pairs(pairs, image_src, label_src, image_dst, label_dst):
    """Copy each (image, label) file pair into the destination directories.

    Args:
        pairs: Iterable of ``(image_file_name, label_file_name)`` tuples.
        image_src: Directory the image files are read from.
        label_src: Directory the label files are read from.
        image_dst: Directory the image files are copied to (created if missing).
        label_dst: Directory the label files are copied to (created if missing).
    """
    # Both destinations are created up front, even when there is nothing to copy.
    for destination in (image_dst, label_dst):
        os.makedirs(destination, exist_ok=True)

    for image_name, label_name in pairs:
        transfers = (
            (image_src, image_dst, image_name),
            (label_src, label_dst, label_name),
        )
        for source_dir, target_dir, file_name in transfers:
            shutil.copy(
                os.path.join(source_dir, file_name),
                os.path.join(target_dir, file_name),
            )
31
32
def write_yolo_yaml(output_dir):
    """Write the ``yolo.yml`` dataset description into *output_dir*.

    The file records *output_dir* as ``path``, points ``train``, ``val`` and
    ``test`` at the ``train``, ``valid`` and ``test`` sub-folders, and lists a
    single entry, ``dataset``, under ``names``.
    """
    config_lines = [
        f"path: {output_dir}\n",
        "train: train\n",
        "val: valid\n",
        "test: test\n",
        "\n",
        "names: ['dataset']\n",
    ]

    target = os.path.join(output_dir, "yolo.yml")
    with open(target, 'w') as handle:
        handle.writelines(config_lines)
43
44
def main():
    """Split paired image/label files into train, valid and test folders.

    Command-line arguments:
        -i / --images: directory with the image files.
        -y / --labels: directory with the label files.
        -o / --output: directory that receives ``train``, ``valid`` and
            ``test`` sub-folders (each with ``images`` and ``labels``) plus
            the ``yolo.yml`` file.
        -p / --train_percent: integer percentage of pairs used for training
            (default 70); the remainder is split evenly into valid and test.

    Exits with an error message when the percentage is outside 1-99 or when
    there are too few matching pairs to fill all three splits.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--images", required=True)
    parser.add_argument("-y", "--labels", required=True)
    parser.add_argument("-o", "--output", required=True)
    parser.add_argument("-p", "--train_percent", type=int, default=70)
    args = parser.parse_args()

    # 0 and 100 would leave one side of the first split empty.
    if not 0 < args.train_percent < 100:
        parser.error("--train_percent must be an integer between 1 and 99")

    all_pairs = pair_files(args.images, args.labels)
    n_pairs = len(all_pairs)

    # Compute the held-out count with integer arithmetic (rounded up).
    # Passing the float 1.0 - train_percent / 100.0 is subject to rounding
    # error: 1.0 - 0.7 == 0.30000000000000004, which scikit-learn rounds up
    # to 4 held-out samples out of 10 instead of the intended 3.
    n_holdout = (n_pairs * (100 - args.train_percent) + 99) // 100

    # The held-out part is halved again below, so it needs at least two pairs.
    if n_pairs - n_holdout < 1 or n_holdout < 2:
        raise SystemExit(
            f"Found {n_pairs} matching image/label pair(s); a "
            f"{args.train_percent}% training split needs at least one "
            "training pair and two held-out pairs."
        )

    # Fixed random_state keeps the split reproducible across runs.
    train_pairs, val_test_pairs = train_test_split(all_pairs, test_size=n_holdout, random_state=42)
    val_pairs, test_pairs = train_test_split(val_test_pairs, test_size=0.5, random_state=42)

    splits = (("train", train_pairs), ("valid", val_pairs), ("test", test_pairs))
    for split_name, split_pairs in splits:
        copy_pairs(
            split_pairs,
            args.images,
            args.labels,
            os.path.join(args.output, split_name, "images"),
            os.path.join(args.output, split_name, "labels"),
        )

    write_yolo_yaml(args.output)
65
66
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()