Mercurial > repos > bgruening > json2yolosegment
comparison preprocessing.py @ 0:252fd085940d draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools commit 67e0e1d123bcfffb10bab8cc04ae67259caec557
| author | bgruening |
|---|---|
| date | Fri, 13 Jun 2025 11:23:35 +0000 |
| parents | |
| children | 97bc82ee2a61 |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:252fd085940d |
|---|---|
| 1 import argparse | |
| 2 import os | |
| 3 import shutil | |
| 4 | |
| 5 from sklearn.model_selection import train_test_split | |
| 6 | |
| 7 | |
def get_basename(f):
    """Return the file name of *f* stripped of its directory and last extension.

    Only the final extension is removed, so ``"imgs/cat.01.png"`` gives
    ``"cat.01"``; a leading dot (``".hidden"``) is not treated as an extension.
    """
    file_name = os.path.basename(f)
    stem, _extension = os.path.splitext(file_name)
    return stem
| 10 | |
| 11 | |
def pair_files(images_dir, labels_dir):
    """Match image files to label files that share the same base name.

    Only regular files directly inside each directory are considered;
    sub-directories are skipped so they are never handed to ``shutil.copy``
    later. Each directory listing is sorted before indexing, so when two
    files in the same directory share a base name (e.g. ``a.jpg`` and
    ``a.png``) the one whose name sorts last is kept deterministically,
    instead of depending on the arbitrary order of ``os.listdir``.

    Args:
        images_dir: Directory containing the image files.
        labels_dir: Directory containing the label files.

    Returns:
        A list of ``(image_filename, label_filename)`` tuples (bare file
        names, not paths), ordered by their shared base name. Files without
        a counterpart in the other directory are dropped.
    """

    def _index(directory):
        # Map base name -> file name for every regular file in *directory*;
        # later (sorted) names overwrite earlier ones on a base-name clash.
        return {
            get_basename(name): name
            for name in sorted(os.listdir(directory))
            if os.path.isfile(os.path.join(directory, name))
        }

    image_dict = _index(images_dir)
    label_dict = _index(labels_dir)

    keys = sorted(image_dict.keys() & label_dict.keys())
    return [(image_dict[k], label_dict[k]) for k in keys]
| 23 | |
| 24 | |
def copy_pairs(pairs, image_src, label_src, image_dst, label_dst):
    """Copy matched image/label files from their source to destination dirs.

    Args:
        pairs: Iterable of ``(image_name, label_name)`` bare file names.
        image_src, label_src: Directories the files are read from.
        image_dst, label_dst: Directories the files are copied into; both
            are created up front (even when *pairs* is empty) if missing.

    File names are preserved; for each pair the image is copied first,
    then the label.
    """
    for target_dir in (image_dst, label_dst):
        os.makedirs(target_dir, exist_ok=True)

    for image_name, label_name in pairs:
        jobs = (
            (image_src, image_dst, image_name),
            (label_src, label_dst, label_name),
        )
        for src_dir, dst_dir, file_name in jobs:
            shutil.copy(
                os.path.join(src_dir, file_name),
                os.path.join(dst_dir, file_name),
            )
| 31 | |
| 32 | |
def _yaml_single_quote(text):
    """Return *text* rendered as a YAML single-quoted scalar.

    Inside single quotes YAML takes every character literally except the
    quote itself, which is escaped by doubling it (``'`` -> ``''``).
    """
    return "'" + str(text).replace("'", "''") + "'"


def write_yolo_yaml(output_dir, class_names=None):
    """Write the ``yolo.yml`` dataset description into *output_dir*.

    The file declares ``output_dir`` as ``path`` and the ``train``,
    ``valid`` and ``test`` sub-directories as the train/val/test splits,
    followed by the ``names`` list.

    Args:
        output_dir: Dataset root; written as the ``path`` entry and used as
            the directory ``yolo.yml`` is created in.
        class_names: Optional sequence of class names for the ``names``
            entry. Defaults to the single name ``'dataset'`` (the value this
            function has always written).

    The ``path`` value and each class name are emitted as single-quoted YAML
    scalars so that characters such as ``: ``, `` #``, ``[`` or a leading
    ``*`` in a directory or class name cannot break parsing of the file.
    """
    if class_names is None:
        class_names = ["dataset"]
    names = ", ".join(_yaml_single_quote(name) for name in class_names)
    content = (
        f"path: {_yaml_single_quote(output_dir)}\n"
        "train: train\n"
        "val: valid\n"
        "test: test\n"
        "\n"
        f"names: [{names}]\n"
    )

    yolo_yaml_path = os.path.join(output_dir, "yolo.yml")
    # Explicit UTF-8 so non-ASCII paths/names do not depend on the locale.
    with open(yolo_yaml_path, "w", encoding="utf-8") as f:
        f.write(content)
| 43 | |
| 44 | |
def _split_off(items, n_second, seed):
    """Split *items* into ``(first, second)`` with exactly *n_second* in *second*.

    Degenerate requests (nothing or everything in *second*) are answered
    directly because ``train_test_split`` rejects them with ``ValueError``.
    Otherwise the work is delegated to ``train_test_split`` with an integer
    ``test_size``; for the same count and seed this shuffles exactly like a
    float ``test_size`` that resolves to that count.
    """
    if n_second <= 0:
        return list(items), []
    if n_second >= len(items):
        return [], list(items)
    first, second = train_test_split(items, test_size=n_second, random_state=seed)
    return first, second


def _split_pairs(pairs, train_percent, seed=42):
    """Split *pairs* into ``(train, val, test)`` lists.

    ``train_percent`` percent of the pairs (rounded down) go to training; the
    remainder is halved between test (rounded up) and validation, as the
    original ``test_size=0.5`` second split did. Counts are computed with
    integers: a float ``test_size = 1.0 - train_percent / 100`` is inexact
    (``1.0 - 0.7 == 0.30000000000000004``), and sklearn's ``ceil`` on it
    turned e.g. 100 pairs at 70 % into 69 training pairs instead of 70.
    """
    n_train = len(pairs) * train_percent // 100
    train, rest = _split_off(pairs, len(pairs) - n_train, seed)
    val, test = _split_off(rest, (len(rest) + 1) // 2, seed)
    return train, val, test


def main():
    """Command-line entry point: pair, split and copy a YOLO dataset.

    Image and label files are matched by base name, split into
    train/valid/test subsets with a fixed seed (42) so runs are
    reproducible, copied to ``<output>/<subset>/images`` and
    ``<output>/<subset>/labels``, and ``<output>/yolo.yml`` is written.
    Exits with a usage error when ``--train_percent`` is outside 0-100 or
    no image/label pairs are found.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--images", required=True)
    parser.add_argument("-y", "--labels", required=True)
    parser.add_argument("-o", "--output", required=True)
    parser.add_argument("-p", "--train_percent", type=int, default=70)
    args = parser.parse_args()

    if not 0 <= args.train_percent <= 100:
        parser.error("--train_percent must be between 0 and 100")

    all_pairs = pair_files(args.images, args.labels)
    if not all_pairs:
        parser.error(
            f"no image/label pairs with matching base names found in "
            f"{args.images!r} and {args.labels!r}"
        )

    train_pairs, val_pairs, test_pairs = _split_pairs(all_pairs, args.train_percent)

    subsets = (("train", train_pairs), ("valid", val_pairs), ("test", test_pairs))
    for subset, pairs in subsets:
        copy_pairs(
            pairs,
            args.images,
            args.labels,
            os.path.join(args.output, subset, "images"),
            os.path.join(args.output, subset, "labels"),
        )

    write_yolo_yaml(args.output)
| 65 | |
| 66 | |
if __name__ == "__main__":
    # Run the command-line tool only when executed as a script, not on import;
    # main() returns None, so the process exits with status 0 on success.
    raise SystemExit(main())
