comparison create_yaml.py @ 0:356d58ae85fa draft default tip

planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/main/tools/biapy commit 63860b5c6c21e0b76b1c55a5e71cafcb77d6cc84
author iuc
date Fri, 06 Feb 2026 17:50:32 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:356d58ae85fa
1 import argparse
2 import sys
3
4 import requests
5 import yaml
6
7
def download_yaml_template(workflow, dims, biapy_version=""):
    """Fetch and parse the BiaPy YAML template for a workflow/dimensionality.

    Builds the raw-GitHub URL for the tagged BiaPy release and returns the
    parsed template as a dict (an empty dict when the template file is empty).

    Raises ValueError for an unrecognized workflow name; prints an error and
    exits the process when the download fails.
    """
    template_dirs = {
        "SEMANTIC_SEG": "semantic_segmentation",
        "INSTANCE_SEG": "instance_segmentation",
        "DETECTION": "detection",
        "DENOISING": "denoising",
        "SUPER_RESOLUTION": "super-resolution",
        "CLASSIFICATION": "classification",
        "SELF_SUPERVISED": "self-supervised",
        "IMAGE_TO_IMAGE": "image-to-image",
    }

    # Reject unknown workflows up front instead of producing a bogus URL.
    folder = template_dirs.get(workflow)
    if not folder:
        raise ValueError(f"Unknown workflow: {workflow}")

    yaml_rel_path = f"{folder}/{dims.lower()}_{folder}.yaml"
    url = (
        "https://raw.githubusercontent.com/BiaPyX/BiaPy/"
        f"refs/tags/v{biapy_version}/templates/{yaml_rel_path}"
    )

    print(f"Downloading YAML template from {url}")
    try:
        resp = requests.get(url, timeout=10)  # timeout avoids hanging forever
        resp.raise_for_status()  # turn 4xx/5xx into RequestException
    except requests.exceptions.RequestException as err:
        print(f"Error: Could not download template. {err}")
        sys.exit(1)  # exit cleanly rather than dumping a stack trace
    return yaml.safe_load(resp.text) or {}
36
37
def tuple_to_list(obj):
    """Recursively convert tuples to lists inside nested containers.

    Needed so that yaml.dump emits plain YAML sequences instead of
    Python-specific !!python/tuple tags.
    """
    if isinstance(obj, tuple):
        # Fix: recurse into the elements too — a tuple may itself contain
        # tuples (the old code only did a shallow list(obj) conversion).
        return [tuple_to_list(v) for v in obj]
    if isinstance(obj, dict):
        return {k: tuple_to_list(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [tuple_to_list(v) for v in obj]
    return obj
47
48
def main():
    """Build a BiaPy YAML configuration from command-line arguments.

    Either downloads a fresh workflow template (--new_config) and fills it in,
    or loads an existing config (--input_config_path), then applies global
    overrides (CPU count, train/test paths, checkpoint format) and writes the
    result to --out_config_path.
    """
    parser = argparse.ArgumentParser(
        description="Generate a YAML configuration from given arguments."
    )
    parser.add_argument(
        '--input_config_path', default='', type=str,
        help="Input configuration file to reuse"
    )
    parser.add_argument(
        '--new_config', action='store_true',
        help="Whether to create a new config or reuse an existing one."
    )
    parser.add_argument(
        '--out_config_path', required=True, type=str,
        help="Path to save the generated YAML configuration."
    )
    parser.add_argument(
        '--workflow', default='semantic', type=str,
        choices=['semantic', 'instance', 'detection', 'denoising',
                 'sr', 'cls', 'sr2', 'i2i'],
    )
    parser.add_argument(
        '--dims', default='2d', type=str,
        choices=['2d_stack', '2d', '3d'],
        help="Number of dimensions for the problem"
    )
    parser.add_argument(
        '--obj_slices', default='', type=str,
        choices=['', '1-5', '5-10', '10-20', '20-60', '60+'],
        help="Number of slices for the objects in the images"
    )
    parser.add_argument(
        '--obj_size', default='0-25', type=str,
        choices=['0-25', '25-100', '100-200', '200-500', '500+'],
        help="Size of the objects in the images"
    )
    parser.add_argument(
        '--img_channel', default=1, type=int,
        help="Number of channels in the input images"
    )
    parser.add_argument(
        '--model_source', default='biapy',
        choices=['biapy', 'bmz', 'torchvision'],
        help="Source of the model."
    )
    parser.add_argument(
        '--model', default='', type=str,
        help=("Path to the model file if using a pre-trained model "
              "from BiaPy or name of the model within BioImage "
              "Model Zoo or TorchVision.")
    )
    parser.add_argument(
        '--raw_train', default='', type=str,
        help="Path to the training raw data."
    )
    parser.add_argument(
        '--gt_train', default='', type=str,
        help="Path to the training ground truth data."
    )
    parser.add_argument(
        '--test_raw_path', default='', type=str,
        help="Path to the testing raw data."
    )
    parser.add_argument(
        '--test_gt_path', default='', type=str,
        help="Path to the testing ground truth data."
    )
    parser.add_argument(
        '--biapy_version', default='', type=str,
        help="BiaPy version to use."
    )
    parser.add_argument(
        '--num_cpus', default="1", type=str,
        help="Number of CPUs to allocate."
    )
    args = parser.parse_args()

    if args.new_config:
        # Map the short CLI workflow names to BiaPy's PROBLEM.TYPE values.
        workflow_map = {
            "semantic": "SEMANTIC_SEG",
            "instance": "INSTANCE_SEG",
            "detection": "DETECTION",
            "denoising": "DENOISING",
            "sr": "SUPER_RESOLUTION",
            "cls": "CLASSIFICATION",
            "sr2": "SELF_SUPERVISED",
            "i2i": "IMAGE_TO_IMAGE",
        }
        workflow_type = workflow_map[args.workflow]

        ndim = "3D" if args.dims == "3d" else "2D"
        # '2d_stack' and '2d' are both handled as 2D images analyzed as a stack.
        as_stack = args.dims in ["2d_stack", "2d"]

        config = download_yaml_template(workflow_type, ndim, biapy_version=args.biapy_version)

        # Initialization using setdefault to prevent KeyErrors
        config.setdefault("PROBLEM", {})
        config["PROBLEM"].update({"TYPE": workflow_type, "NDIM": ndim})

        config.setdefault("TEST", {})["ANALIZE_2D_IMGS_AS_3D_STACK"] = as_stack

        # Handle MODEL and PATHS
        model_cfg = config.setdefault("MODEL", {})
        if args.model_source == "biapy":
            model_cfg["SOURCE"] = "biapy"
            # A non-empty --model means "resume from this checkpoint file".
            is_loading = bool(args.model)
            model_cfg["LOAD_CHECKPOINT"] = is_loading
            model_cfg["LOAD_MODEL_FROM_CHECKPOINT"] = is_loading
            if is_loading:
                config.setdefault("PATHS", {})["CHECKPOINT_FILE"] = args.model
        elif args.model_source == "bmz":
            model_cfg["SOURCE"] = "bmz"
            model_cfg.setdefault("BMZ", {})["SOURCE_MODEL_ID"] = args.model
        elif args.model_source == "torchvision":
            model_cfg["SOURCE"] = "torchvision"
            model_cfg["TORCHVISION_MODEL_NAME"] = args.model

        # PATCH_SIZE Logic: pick a 2D patch size from the expected object size.
        obj_size_map = {
            "0-25": (256, 256), "25-100": (256, 256),
            "100-200": (512, 512), "200-500": (512, 512), "500+": (1024, 1024),
        }
        obj_size = obj_size_map[args.obj_size]

        # -1 marks "not provided"; only valid for 2D problems.
        obj_slices_map = {"": -1, "1-5": 5, "5-10": 10, "10-20": 20, "20-60": 40, "60+": 80}
        obj_slices = obj_slices_map.get(args.obj_slices, -1)

        if ndim == "2D":
            patch_size = obj_size + (args.img_channel,)
        else:
            if obj_slices == -1:
                print("Error: For 3D problems, obj_slices must be specified.")
                sys.exit(1)
            patch_size = (obj_slices,) + obj_size + (args.img_channel,)

        # NOTE(review): stored as the tuple's string repr (e.g. "(256, 256, 1)"),
        # presumably what the BiaPy config parser expects — preserved as-is.
        config.setdefault("DATA", {})["PATCH_SIZE"] = str(patch_size)
        config["DATA"]["REFLECT_TO_COMPLETE_SHAPE"] = True

    else:
        # Reuse an existing configuration file instead of a template.
        if not args.input_config_path:
            print("Error: Input configuration path must be specified.")
            sys.exit(1)
        try:
            with open(args.input_config_path, 'r', encoding='utf-8') as f:
                config = yaml.safe_load(f) or {}
        except FileNotFoundError:
            print(f"Error: File {args.input_config_path} not found.")
            sys.exit(1)

    # Always set NUM_CPUS
    config.setdefault("SYSTEM", {})
    try:
        num_cpus = max(int(args.num_cpus), 1)
    except (ValueError, TypeError):
        # Fix: was `except BaseException`, which also swallowed
        # KeyboardInterrupt/SystemExit. Fall back to a single CPU on bad input.
        num_cpus = 1
    config["SYSTEM"].update({"NUM_CPUS": num_cpus})

    # Global overrides (Train/Test)
    config.setdefault("TRAIN", {})
    config.setdefault("DATA", {})

    if args.raw_train:
        config["TRAIN"]["ENABLE"] = True
        config["DATA"].setdefault("TRAIN", {}).update({
            "PATH": args.raw_train,
            "GT_PATH": args.gt_train
        })
    else:
        config["TRAIN"]["ENABLE"] = False

    test_cfg = config.setdefault("TEST", {})
    if args.test_raw_path:
        test_cfg["ENABLE"] = True
        data_test = config["DATA"].setdefault("TEST", {})
        data_test["PATH"] = args.test_raw_path
        # Ground truth is optional at test time.
        data_test["LOAD_GT"] = bool(args.test_gt_path)
        if args.test_gt_path:
            data_test["GT_PATH"] = args.test_gt_path
    else:
        test_cfg["ENABLE"] = False

    config.setdefault("MODEL", {})["OUT_CHECKPOINT_FORMAT"] = "safetensors"

    # Final cleanup and save: tuples would dump as !!python/tuple tags.
    config = tuple_to_list(config)
    with open(args.out_config_path, 'w', encoding='utf-8') as f:
        yaml.dump(config, f, default_flow_style=False)

    print(f"Success: YAML configuration written to {args.out_config_path}")
238
239
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()