comparison yolov8.py @ 4:f6990d85161c draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools commit c6c9d43a4ecdc88ebdeaf3451453a550f159c506
author bgruening
date Mon, 21 Jul 2025 15:51:13 +0000
parents 97bc82ee2a61
children
comparison
equal deleted inserted replaced
3:97bc82ee2a61 4:f6990d85161c
1 import argparse 1 import argparse
2 import csv
2 import os 3 import os
3 import pathlib 4 import pathlib
4 import time 5 import time
5 from argparse import RawTextHelpFormatter 6 from argparse import RawTextHelpFormatter
6 from collections import defaultdict 7 from collections import defaultdict
8 import cv2 9 import cv2
9 import numpy as np 10 import numpy as np
10 from termcolor import colored 11 from termcolor import colored
11 from tifffile import imwrite 12 from tifffile import imwrite
12 from ultralytics import YOLO 13 from ultralytics import YOLO
13
14 14
15 # 15 #
16 # Input arguments 16 # Input arguments
17 # 17 #
18 parser = argparse.ArgumentParser( 18 parser = argparse.ArgumentParser(
77 help="Format of the YOLO model i.e pt, yaml etc.", 77 help="Format of the YOLO model i.e pt, yaml etc.",
78 default='pt', type=str) 78 default='pt', type=str)
79 parser.add_argument("--class_names_file", 79 parser.add_argument("--class_names_file",
80 help="Path to the text file containing class names.", 80 help="Path to the text file containing class names.",
81 type=str) 81 type=str)
82
83 # For training the model and prediction 82 # For training the model and prediction
84 parser.add_argument("--mode", 83 parser.add_argument("--mode",
85 help=( 84 help=(
86 "detection, segmentation, classification, and pose \n. " 85 "detection, segmentation, classification, and pose \n. "
87 " Only detection mode available currently i.e. `detect`" 86 " Only detection mode available currently i.e. `detect`"
127 default='bytetrack.yaml', type=str) 126 default='bytetrack.yaml', type=str)
128 127
129 # For headless operation 128 # For headless operation
130 parser.add_argument('--headless', action='store_true') 129 parser.add_argument('--headless', action='store_true')
131 parser.add_argument('--nextflow', action='store_true') 130 parser.add_argument('--nextflow', action='store_true')
131
132 132
133 # For data augmentation 133 # For data augmentation
134 parser.add_argument("--hsv_h", 134 parser.add_argument("--hsv_h",
135 help="(float) image HSV-Hue augmentation (fraction)", 135 help="(float) image HSV-Hue augmentation (fraction)",
136 default=0.015, type=float) 136 default=0.015, type=float)
169 "emphasize central features and adapt to object scales, " 169 "emphasize central features and adapt to object scales, "
170 "reducing background distractions", 170 "reducing background distractions",
171 default=1.0, type=float) 171 default=1.0, type=float)
172 172
173 173
174 #
175 # Functions
176 #
177 # Train a new model on the dataset mentioned in yaml file 174 # Train a new model on the dataset mentioned in yaml file
178 def trainModel(model_path, model_name, yaml_filepath, **kwargs): 175 def trainModel(model_path, model_name, yaml_filepath, **kwargs):
179 if "imgsz" in kwargs: 176 if "imgsz" in kwargs:
180 image_size = kwargs['imgsz'] 177 image_size = kwargs['imgsz']
181 else: 178 else:
270 imgsz=image_size, verbose=True, hsv_h=aug_hsv_h, 267 imgsz=image_size, verbose=True, hsv_h=aug_hsv_h,
271 hsv_s=aug_hsv_s, hsv_v=aug_hsv_v, degrees=aug_degrees, 268 hsv_s=aug_hsv_s, hsv_v=aug_hsv_v, degrees=aug_degrees,
272 translate=aug_translate, shear=aug_shear, scale=aug_scale, 269 translate=aug_translate, shear=aug_shear, scale=aug_scale,
273 perspective=aug_perspective, fliplr=aug_fliplr, 270 perspective=aug_perspective, fliplr=aug_fliplr,
274 flipud=aug_flipud, mosaic=aug_mosaic, crop_fraction=aug_crop_fraction, 271 flipud=aug_flipud, mosaic=aug_mosaic, crop_fraction=aug_crop_fraction,
275 weight_decay=weight_decay, lr0=init_lr, seed=42) 272 weight_decay=weight_decay, lr0=init_lr)
276 return model 273 return model
277 274
278 275
279 # Validate the trained model 276 # Validate the trained model
280 def validateModel(model): 277 def validateModel(model):
281 # Validate the model
282 metrics = model.val() # no args needed, dataset & settings remembered 278 metrics = model.val() # no args needed, dataset & settings remembered
283 metrics.box.map # map50-95 279 metrics.box.map # map50-95
284 metrics.box.map50 # map50 280 metrics.box.map50 # map50
285 metrics.box.map75 # map75 281 metrics.box.map75 # map75
286 metrics.box.maps # a list contains map50-95 of each category 282 metrics.box.maps # a list contains map50-95 of each category
314 maximum_detections = 300 310 maximum_detections = 300
315 311
316 run_save_dir = kwargs['run_dir'] # For Galaxy, run_save_dir is always provided via xml wrapper 312 run_save_dir = kwargs['run_dir'] # For Galaxy, run_save_dir is always provided via xml wrapper
317 if "foldername" in kwargs: 313 if "foldername" in kwargs:
318 save_folder_name = kwargs['foldername'] 314 save_folder_name = kwargs['foldername']
315
319 # infer on a local image or directory containing images/videos 316 # infer on a local image or directory containing images/videos
320 prediction = model.predict(source=source_datapath, save=True, stream=True, 317 prediction = model.predict(source=source_datapath, save=True, stream=True,
321 conf=confidence, imgsz=image_size, 318 conf=confidence, imgsz=image_size,
322 save_conf=True, iou=iou_value, max_det=maximum_detections, 319 save_conf=True, iou=iou_value, max_det=maximum_detections,
323 classes=class_array, save_txt=False, 320 classes=class_array, save_txt=False,
327 324
328 # Save bounding boxes 325 # Save bounding boxes
329 def save_yolo_bounding_boxes_to_txt(predictions, save_dir): 326 def save_yolo_bounding_boxes_to_txt(predictions, save_dir):
330 """ 327 """
331 Function to save YOLO bounding boxes to text files. 328 Function to save YOLO bounding boxes to text files.
329
332 Parameters: 330 Parameters:
333 - predictions: List of results from YOLO model inference. 331 - predictions: List of results from YOLO model inference.
334 - save_dir: Directory where the text files will be saved. 332 - save_dir: Directory where the text files will be saved.
335 """ 333 """
336 for result in predictions: 334 for result in predictions:
337 result = result.to("cpu").numpy() 335 result = result.to("cpu").numpy()
338 # Using bounding_boxes, confidence_scores, and class_num which are defined in the list 336 # Using bounding_boxes, confidence_scores, and class_num which are defined in the list
339 bounding_boxes = result.boxes.xyxy # Bounding boxes in xyxy format 337 bounding_boxes = result.boxes.xyxy # Bounding boxes in xyxy format
340 confidence_scores = result.boxes.conf # Confidence scores 338 confidence_scores = result.boxes.conf # Confidence scores
341 class_nums = result.boxes.cls # Class numbers 339 class_nums = result.boxes.cls # Class numbers
340
342 # Create save directory if it doesn't exist 341 # Create save directory if it doesn't exist
343 save_path = pathlib.Path(save_dir).absolute() 342 save_path = pathlib.Path(save_dir).absolute()
344 save_path.mkdir(parents=True, exist_ok=True) 343 save_path.mkdir(parents=True, exist_ok=True)
344
345 # Construct filename for the text file 345 # Construct filename for the text file
346 image_filename = pathlib.Path(result.path).stem 346 image_filename = pathlib.Path(result.path).stem
347 text_filename = save_path / f"{image_filename}.txt" 347 text_filename = save_path / f"{image_filename}.txt"
348
348 # Write bounding boxes info into the text file 349 # Write bounding boxes info into the text file
349 with open(text_filename, 'w') as f: 350 with open(text_filename, 'w') as f:
350 for i in range(bounding_boxes.shape[0]): 351 for i in range(bounding_boxes.shape[0]):
351 x1, y1, x2, y2 = bounding_boxes[i] 352 x1, y1, x2, y2 = bounding_boxes[i]
352 confidence = confidence_scores[i] 353 confidence = confidence_scores[i]
353 class_num = int(class_nums[i]) 354 class_num = int(class_nums[i])
354 f.write(f'{class_num:01} {x1:06.2f} {y1:06.2f} {x2:06.2f} {y2:06.2f} {confidence:0.02} \n') 355 f.write(f'{class_num:01} {x1:06.2f} {y1:06.2f} {x2:06.2f} {y2:06.2f} {confidence:0.02} \n')
355 print(colored(f"Bounding boxes saved in: {text_filename}", 'green')) 356 print(colored(f"Bounding boxes saved in: {text_filename}", 'green'))
356 357
357 358
359 # Main code
358 if __name__ == '__main__': 360 if __name__ == '__main__':
359 args = parser.parse_args() 361 args = parser.parse_args()
360 os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" 362 os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
363
361 # Train/load model 364 # Train/load model
362 if (args.train): 365 if (args.train):
363 model = trainModel(args.model_path, args.model_name, args.yaml_path, 366 model = trainModel(args.model_path, args.model_name, args.yaml_path,
364 imgsz=args.image_size, epochs=args.epochs, 367 imgsz=args.image_size, epochs=args.epochs,
365 hsv_h=args.hsv_h, hsv_s=args.hsv_s, hsv_v=args.hsv_v, 368 hsv_h=args.hsv_h, hsv_s=args.hsv_s, hsv_v=args.hsv_v,
375 "train", "weights", "best.pt")) and (args.model_name == 'sam'): 378 "train", "weights", "best.pt")) and (args.model_name == 'sam'):
376 model = YOLO(os.path.join(train_save_path, 379 model = YOLO(os.path.join(train_save_path,
377 "train", "weights", "best.pt")) 380 "train", "weights", "best.pt"))
378 else: 381 else:
379 model = YOLO(os.path.join(args.model_path, 382 model = YOLO(os.path.join(args.model_path,
380 args.model_name + ".pt")) 383 args.model_name + ".pt"))
381 model.info(verbose=True) 384 model.info(verbose=True)
382 elapsed = time.time() - t 385 elapsed = time.time() - t
383 print(colored(f"\nYOLO model loaded in : '{elapsed}' sec \n", 'white', 'on_yellow')) 386 print(colored(f"\nYOLO model loaded in : '{elapsed}' sec \n", 'white', 'on_yellow'))
384 387
385 if (args.save_dir): 388 if (args.save_dir):
420 elif (args.mode == "track"): 423 elif (args.mode == "track"):
421 results = model.track(source=datapath_for_prediction, 424 results = model.track(source=datapath_for_prediction,
422 tracker=args.tracker_file, 425 tracker=args.tracker_file,
423 conf=args.confidence, 426 conf=args.confidence,
424 iou=args.iou, 427 iou=args.iou,
425 persist=False, 428 persist=True,
426 show=True, 429 show=False,
427 save=True, 430 save=True,
428 project=args.run_dir, 431 project=args.run_dir,
429 name=args.foldername) 432 name=args.foldername)
430 # Store the track history 433 # Store the track history
431 track_history = defaultdict(lambda: []) 434 track_history = defaultdict(lambda: [])
432 435
433 for result in results: 436 tsv_path = os.path.join(args.save_dir, "tracks.tsv")
434 # Get the boxes and track IDs 437 with open(tsv_path, "w", newline="") as tsvfile:
435 if result.boxes and result.boxes.is_track: 438 writer = csv.writer(tsvfile, delimiter='\t')
436 boxes = result.boxes.xywh.cpu() 439 writer.writerow(['track_id', 'frame', 'class', 'centroid_x', 'centroid_y'])
437 track_ids = result.boxes.id.int().cpu().tolist() 440 frame_idx = 0
438 # Visualize the result on the frame 441 for result in results:
439 frame = result.plot() 442 # Get the boxes and track IDs
440 # Plot the tracks 443 if result.boxes and result.boxes.is_track:
441 for box, track_id in zip(boxes, track_ids): 444 track_ids = result.boxes.id.int().cpu().tolist()
442 x, y, w, h = box 445 labels = result.boxes.cls.int().cpu().tolist() if hasattr(result.boxes, "cls") else [0] * len(track_ids)
443 track = track_history[track_id] 446 # Prepare mask image
444 track.append((float(x), float(y))) # x, y center point 447 img_shape = result.orig_shape if hasattr(result, "orig_shape") else result.orig_img.shape
445 if len(track) > 30: # retain 30 tracks for 30 frames 448 mask = np.zeros(img_shape[:2], dtype=np.uint16)
446 track.pop(0) 449 # Check if polygons (masks) are available
447 450 if hasattr(result, "masks") and result.masks is not None and hasattr(result.masks, "xy"):
448 # Draw the tracking lines 451 polygons = result.masks.xy
449 points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2)) 452 for i, (track_id, label) in enumerate(zip(track_ids, labels)):
450 cv2.polylines(frame, [points], isClosed=False, color=(230, 230, 230), thickness=2) 453 if i < len(polygons):
451 454 contour = polygons[i].astype(np.int32)
452 # Display the annotated frame 455 contour = contour.reshape(-1, 1, 2)
453 cv2.imshow("YOLO11 Tracking", frame) 456 cv2.drawContours(mask, [contour], -1, int(track_id), cv2.FILLED)
454 print(colored(f"Tracking results saved in : '{args.save_dir}' \n", 'green')) 457 # Calculate centroid of the polygon
458 M = cv2.moments(contour)
459 if M["m00"] != 0:
460 cx = float(M["m10"] / M["m00"])
461 cy = float(M["m01"] / M["m00"])
462 else:
463 cx, cy = 0.0, 0.0
464 writer.writerow([track_id, frame_idx, label, cx, cy])
465 else:
466 # Fallback to bounding boxes if polygons are not available
467 boxes = result.boxes.xywh.cpu()
468 xyxy_boxes = result.boxes.xyxy.cpu().numpy()
469 for i, (box, xyxy, track_id, label) in enumerate(zip(boxes, xyxy_boxes, track_ids, labels)):
470 x, y, w, h = box
471 writer.writerow([track_id, frame_idx, label, float(x), float(y)])
472 x1, y1, x2, y2 = map(int, xyxy)
473 cv2.rectangle(mask, (x1, y1), (x2, y2), int(track_id), thickness=-1)
474 # Collect masks for TYX stack
475 if frame_idx == 0:
476 mask_stack = []
477 mask_stack.append(mask)
478 frame_idx += 1
479 # Save TYX stack (T=frames, Y, X)
480 if 'mask_stack' in locals() and len(mask_stack) > 0:
481 tyx_array = np.stack(mask_stack, axis=0)
482 # Remove string from last underscore in filename
483 stem = pathlib.Path(result.path).stem
484 stem = stem.rsplit('_', 1)[0] if '_' in stem else stem
485 mask_save_as = str(pathlib.Path(os.path.join(args.save_dir, stem + "_mask.tiff")).absolute())
486 imwrite(mask_save_as, tyx_array)
487 print(colored(f"TYX mask stack saved as : '{mask_save_as}'", 'magenta'))
488 print(colored(f"Tracking results saved in : '{args.save_dir}' \n", 'green'))
455 elif (args.mode == "segment"): 489 elif (args.mode == "segment"):
456 # Read class names from the file 490 # Read class names from the file
457 with open(args.class_names_file, 'r') as f: 491 with open(args.class_names_file, 'r') as f:
458 class_names = [line.strip() for line in f.readlines()] 492 class_names = [line.strip() for line in f.readlines()]
493 # Create a mapping from class names to indices
459 class_to_index = {class_name: i for i, class_name in enumerate(class_names)} 494 class_to_index = {class_name: i for i, class_name in enumerate(class_names)}
460 495
461 # Save polygon coordinates 496 # Save polygon coordinates
462 for result in predictions: 497 for result in predictions:
498 # Create binary mask
463 img = np.copy(result.orig_img) 499 img = np.copy(result.orig_img)
464 filename = pathlib.Path(result.path).stem 500 filename = pathlib.Path(result.path).stem
465 b_mask = np.zeros(img.shape[:2], np.uint8) 501 b_mask = np.zeros(img.shape[:2], np.uint8)
466 mask_save_as = str(pathlib.Path(os.path.join(args.save_dir, filename + "_mask.tiff")).absolute()) 502 mask_save_as = str(pathlib.Path(os.path.join(args.save_dir, filename + "_mask.tiff")).absolute())
503 # Define output file path for text file
504 output_filename = os.path.splitext(filename)[0] + ".txt"
467 txt_save_as = str(pathlib.Path(os.path.join(args.save_dir, filename + ".txt")).absolute()) 505 txt_save_as = str(pathlib.Path(os.path.join(args.save_dir, filename + ".txt")).absolute())
468 506 instance_id = 1 # Start instance IDs from 1
469 for c, ci in enumerate(result): 507 for c, ci in enumerate(result):
470 if ci.masks is not None and ci.masks.xy: 508 # Extract contour result
471 # Extract contour 509 contour = ci.masks.xy.pop()
472 contour = ci.masks.xy.pop() 510 contour = contour.astype(np.int32)
473 contour = contour.astype(np.int32).reshape(-1, 1, 2) 511 contour = contour.reshape(-1, 1, 2)
474 _ = cv2.drawContours(b_mask, [contour], -1, (255, 255, 255), cv2.FILLED) 512 # Draw contour onto mask with unique instance id
475 513 _ = cv2.drawContours(b_mask, [contour], -1, instance_id, cv2.FILLED)
476 # Normalized polygon points 514
477 points = ci.masks.xyn.pop() 515 # Normalized polygon points
478 obj_class = int(ci.boxes.cls.to("cpu").numpy().item()) 516 points = ci.masks.xyn.pop()
479 confidence = result.boxes.conf.to("cpu").numpy()[c] 517 confidence = result.boxes.conf.to("cpu").numpy()[c]
480 518
481 with open(txt_save_as, 'a') as f: 519 with open(txt_save_as, 'a') as f:
482 segmentation_points = ['{} {}'.format(points[i][0], points[i][1]) for i in range(len(points))] 520 segmentation_points = ['{} {}'.format(points[i][0], points[i][1]) for i in range(len(points))]
483 segmentation_points_string = ' '.join(segmentation_points) 521 segmentation_points_string = ' '.join(segmentation_points)
484 line = '{} {} {}\n'.format(obj_class, segmentation_points_string, confidence) 522 line = '{} {} {}\n'.format(instance_id, segmentation_points_string, confidence)
485 f.write(line) 523 f.write(line)
486 else: 524
487 print(colored(f"⚠️ No mask found for object {c} in '{filename}'. Skipping.", "yellow")) 525 instance_id += 1 # Increment for next object
488 526
489 # Overlay mask onto original image 527 imwrite(mask_save_as, b_mask, imagej=True) # save label mask image
490 colored_mask = cv2.merge([b_mask, np.zeros_like(b_mask), np.zeros_like(b_mask)]) 528 print(colored(f"Saved label mask as : \n '{mask_save_as}' \n", 'magenta'))
491 blended = cv2.addWeighted(img, 1.0, colored_mask, 0.5, 0)
492 overlay_path = os.path.join(args.save_dir, filename + "_overlay.jpg")
493 cv2.imwrite(overlay_path, blended)
494
495 imwrite(mask_save_as, b_mask, imagej=True)
496 print(colored(f"Saved binary mask as : \n '{mask_save_as}' \n", 'magenta'))
497 print(colored(f"Polygon coordinates saved as : \n '{txt_save_as}' \n", 'cyan')) 529 print(colored(f"Polygon coordinates saved as : \n '{txt_save_as}' \n", 'cyan'))
530 else:
531 raise Exception(("Currently only 'detect', 'segment' and 'track' modes are available"))