Mercurial > repos > bgruening > json2yolosegment
comparison yolov8.py @ 3:97bc82ee2a61 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools commit 743c8acf1ea4e4b1e718743d3772b7e592646611
author | bgruening |
---|---|
date | Mon, 14 Jul 2025 18:28:46 +0000 |
parents | 158e6ce48345 |
children | f6990d85161c |
comparison
equal
deleted
inserted
replaced
2:158e6ce48345 | 3:97bc82ee2a61 |
---|---|
1 import argparse | 1 import argparse |
2 import os | 2 import os |
3 import pathlib | 3 import pathlib |
4 import shutil | |
5 import time | 4 import time |
6 from argparse import RawTextHelpFormatter | 5 from argparse import RawTextHelpFormatter |
7 from collections import defaultdict | 6 from collections import defaultdict |
8 | 7 |
9 import cv2 | 8 import cv2 |
173 | 172 |
174 | 173 |
175 # | 174 # |
176 # Functions | 175 # Functions |
177 # | 176 # |
178 | |
179 def safe_rmtree(path): | |
180 try: | |
181 shutil.rmtree(path) | |
182 except OSError: | |
183 time.sleep(1) | |
184 shutil.rmtree(path, ignore_errors=True) | |
185 | |
186 | |
187 # Train a new model on the dataset mentioned in yaml file | 177 # Train a new model on the dataset mentioned in yaml file |
188 def trainModel(model_path, model_name, yaml_filepath, **kwargs): | 178 def trainModel(model_path, model_name, yaml_filepath, **kwargs): |
189 if "imgsz" in kwargs: | 179 if "imgsz" in kwargs: |
190 image_size = kwargs['imgsz'] | 180 image_size = kwargs['imgsz'] |
191 else: | 181 else: |
269 if "init_lr" in kwargs: | 259 if "init_lr" in kwargs: |
270 init_lr = kwargs['init_lr'] | 260 init_lr = kwargs['init_lr'] |
271 else: | 261 else: |
272 init_lr = 1.0 | 262 init_lr = 1.0 |
273 | 263 |
274 train_save_path = os.path.expanduser('~/runs/' + args.mode + '/train/') | |
275 if os.path.isdir(train_save_path): | |
276 safe_rmtree(train_save_path) | |
277 # Load a pretrained YOLO model (recommended for training) | 264 # Load a pretrained YOLO model (recommended for training) |
278 if args.model_format == 'pt': | 265 if args.model_format == 'pt': |
279 model = YOLO(os.path.join(model_path, model_name + "." + args.model_format)) | 266 model = YOLO(os.path.join(model_path, model_name + "." + args.model_format)) |
280 else: | 267 else: |
281 model = YOLO(model_name + "." + args.model_format) | 268 model = YOLO(model_name + "." + args.model_format) |
289 return model | 276 return model |
290 | 277 |
291 | 278 |
292 # Validate the trained model | 279 # Validate the trained model |
293 def validateModel(model): | 280 def validateModel(model): |
294 # Remove prediction save path if already exists | |
295 val_save_path = os.path.expanduser('~/runs/' + args.mode + '/val/') | |
296 if os.path.isdir(val_save_path): | |
297 safe_rmtree(val_save_path) | |
298 # Validate the model | 281 # Validate the model |
299 metrics = model.val() # no args needed, dataset & settings remembered | 282 metrics = model.val() # no args needed, dataset & settings remembered |
300 metrics.box.map # map50-95 | 283 metrics.box.map # map50-95 |
301 metrics.box.map50 # map50 | 284 metrics.box.map50 # map50 |
302 metrics.box.map75 # map75 | 285 metrics.box.map75 # map75 |
328 if "max_det" in kwargs: | 311 if "max_det" in kwargs: |
329 maximum_detections = args.max_det | 312 maximum_detections = args.max_det |
330 else: | 313 else: |
331 maximum_detections = 300 | 314 maximum_detections = 300 |
332 | 315 |
333 if "run_dir" in kwargs: | 316 run_save_dir = kwargs['run_dir'] # For Galaxy, run_save_dir is always provided via xml wrapper |
334 run_save_dir = kwargs['run_dir'] | |
335 else: | |
336 # Remove prediction save path if already exists | |
337 pred_save_path = os.path.expanduser('~/runs/' + args.mode + '/predict/') | |
338 if os.path.isdir(pred_save_path): | |
339 safe_rmtree(pred_save_path) | |
340 if "foldername" in kwargs: | 317 if "foldername" in kwargs: |
341 save_folder_name = kwargs['foldername'] | 318 save_folder_name = kwargs['foldername'] |
342 # infer on a local image or directory containing images/videos | 319 # infer on a local image or directory containing images/videos |
343 prediction = model.predict(source=source_datapath, save=True, stream=True, | 320 prediction = model.predict(source=source_datapath, save=True, stream=True, |
344 conf=confidence, imgsz=image_size, | 321 conf=confidence, imgsz=image_size, |