json2yolosegment: comparison of yolov8.py @ 4:f6990d85161c (draft default tip)
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools commit c6c9d43a4ecdc88ebdeaf3451453a550f159c506
author   | bgruening |
date     | Mon, 21 Jul 2025 15:51:13 +0000 |
parents  | 97bc82ee2a61 |
children | (none) |
3:97bc82ee2a61 (parent) | 4:f6990d85161c (this revision) |
---|---|
1 import argparse | 1 import argparse |
| 2 import csv |
2 import os | 3 import os |
3 import pathlib | 4 import pathlib |
4 import time | 5 import time |
5 from argparse import RawTextHelpFormatter | 6 from argparse import RawTextHelpFormatter |
6 from collections import defaultdict | 7 from collections import defaultdict |
8 import cv2 | 9 import cv2 |
9 import numpy as np | 10 import numpy as np |
10 from termcolor import colored | 11 from termcolor import colored |
11 from tifffile import imwrite | 12 from tifffile import imwrite |
12 from ultralytics import YOLO | 13 from ultralytics import YOLO |
13 | |
14 | 14 |
15 # | 15 # |
16 # Input arguments | 16 # Input arguments |
17 # | 17 # |
18 parser = argparse.ArgumentParser( | 18 parser = argparse.ArgumentParser( |
77 help="Format of the YOLO model, e.g. pt, yaml.", | 77 help="Format of the YOLO model, e.g. pt, yaml.", |
78 default='pt', type=str) | 78 default='pt', type=str) |
79 parser.add_argument("--class_names_file", | 79 parser.add_argument("--class_names_file", |
80 help="Path to the text file containing class names.", | 80 help="Path to the text file containing class names.", |
81 type=str) | 81 type=str) |
82 | |
83 # For training the model and prediction | 82 # For training the model and prediction |
84 parser.add_argument("--mode", | 83 parser.add_argument("--mode", |
85 help=( | 84 help=( |
86 "detection, segmentation, classification, and pose \n. " | 85 "detection, segmentation, classification, and pose \n. " |
87 " Only detection mode available currently i.e. `detect`" | 86 " Only detection mode available currently i.e. `detect`" |
127 default='bytetrack.yaml', type=str) | 126 default='bytetrack.yaml', type=str) |
128 | 127 |
129 # For headless operation | 128 # For headless operation |
130 parser.add_argument('--headless', action='store_true') | 129 parser.add_argument('--headless', action='store_true') |
131 parser.add_argument('--nextflow', action='store_true') | 130 parser.add_argument('--nextflow', action='store_true') |
| 131 |
132 | 132 |
133 # For data augmentation | 133 # For data augmentation |
134 parser.add_argument("--hsv_h", | 134 parser.add_argument("--hsv_h", |
135 help="(float) image HSV-Hue augmentation (fraction)", | 135 help="(float) image HSV-Hue augmentation (fraction)", |
136 default=0.015, type=float) | 136 default=0.015, type=float) |
169 "emphasize central features and adapt to object scales, " | 169 "emphasize central features and adapt to object scales, " |
170 "reducing background distractions", | 170 "reducing background distractions", |
171 default=1.0, type=float) | 171 default=1.0, type=float) |
172 | 172 |
173 | 173 |
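These augmentation flags are forwarded unchanged to Ultralytics' trainer (see the train() call in trainModel below). As a minimal sketch of that mapping, assuming the ultralytics package, with "yolov8n-seg.pt" and "data.yaml" as placeholder names and only hsv_h's default taken from this script:

    from ultralytics import YOLO

    # Sketch: each CLI flag maps one-to-one onto a train() hyperparameter.
    model = YOLO("yolov8n-seg.pt")          # placeholder model
    model.train(data="data.yaml",           # placeholder dataset yaml
                epochs=1, imgsz=640,
                hsv_h=0.015,                # jitter hue by up to 1.5% of its range
                hsv_s=0.7, hsv_v=0.4,       # illustrative saturation/value jitter
                degrees=0.0, translate=0.1, scale=0.5,
                fliplr=0.5, flipud=0.0, mosaic=1.0,
                crop_fraction=1.0)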
174 # | |
175 # Functions | |
176 # | |
177 # Train a new model on the dataset mentioned in yaml file | 174 # Train a new model on the dataset mentioned in yaml file |
178 def trainModel(model_path, model_name, yaml_filepath, **kwargs): | 175 def trainModel(model_path, model_name, yaml_filepath, **kwargs): |
179 if "imgsz" in kwargs: | 176 if "imgsz" in kwargs: |
180 image_size = kwargs['imgsz'] | 177 image_size = kwargs['imgsz'] |
181 else: | 178 else: |
270 imgsz=image_size, verbose=True, hsv_h=aug_hsv_h, | 267 imgsz=image_size, verbose=True, hsv_h=aug_hsv_h, |
271 hsv_s=aug_hsv_s, hsv_v=aug_hsv_v, degrees=aug_degrees, | 268 hsv_s=aug_hsv_s, hsv_v=aug_hsv_v, degrees=aug_degrees, |
272 translate=aug_translate, shear=aug_shear, scale=aug_scale, | 269 translate=aug_translate, shear=aug_shear, scale=aug_scale, |
273 perspective=aug_perspective, fliplr=aug_fliplr, | 270 perspective=aug_perspective, fliplr=aug_fliplr, |
274 flipud=aug_flipud, mosaic=aug_mosaic, crop_fraction=aug_crop_fraction, | 271 flipud=aug_flipud, mosaic=aug_mosaic, crop_fraction=aug_crop_fraction, |
275 weight_decay=weight_decay, lr0=init_lr, seed=42) | 272 weight_decay=weight_decay, lr0=init_lr) |
276 return model | 273 return model |
277 | 274 |
278 | 275 |
279 # Validate the trained model | 276 # Validate the trained model |
280 def validateModel(model): | 277 def validateModel(model): |
281 # Validate the model | |
282 metrics = model.val() # no args needed, dataset & settings remembered | 278 metrics = model.val() # no args needed, dataset & settings remembered |
283 metrics.box.map # map50-95 | 279 metrics.box.map # map50-95 |
284 metrics.box.map50 # map50 | 280 metrics.box.map50 # map50 |
285 metrics.box.map75 # map75 | 281 metrics.box.map75 # map75 |
286 metrics.box.maps # a list containing mAP50-95 for each category | 282 metrics.box.maps # a list containing mAP50-95 for each category |
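Note that the bare metrics.box.* reads above compute the values and immediately discard them; nothing is printed. If the numbers are wanted, a sketch along these lines would report them (assuming the Ultralytics metrics API, with model being the trained model returned by trainModel):

    # Hedged sketch: report the metrics that validateModel() touches.
    metrics = model.val()                   # dataset & settings remembered
    print(f"mAP50-95: {metrics.box.map:.4f}")
    print(f"mAP50:    {metrics.box.map50:.4f}")
    print(f"mAP75:    {metrics.box.map75:.4f}")
    print("per-class mAP50-95:", metrics.box.maps)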
314 maximum_detections = 300 | 310 maximum_detections = 300 |
315 | 311 |
316 run_save_dir = kwargs['run_dir'] # For Galaxy, run_save_dir is always provided via the XML wrapper | 312 run_save_dir = kwargs['run_dir'] # For Galaxy, run_save_dir is always provided via the XML wrapper |
317 if "foldername" in kwargs: | 313 if "foldername" in kwargs: |
318 save_folder_name = kwargs['foldername'] | 314 save_folder_name = kwargs['foldername'] |
| 315 |
319 # infer on a local image or directory containing images/videos | 316 # infer on a local image or directory containing images/videos |
320 prediction = model.predict(source=source_datapath, save=True, stream=True, | 317 prediction = model.predict(source=source_datapath, save=True, stream=True, |
321 conf=confidence, imgsz=image_size, | 318 conf=confidence, imgsz=image_size, |
322 save_conf=True, iou=iou_value, max_det=maximum_detections, | 319 save_conf=True, iou=iou_value, max_det=maximum_detections, |
323 classes=class_array, save_txt=False, | 320 classes=class_array, save_txt=False, |
327 | 324 |
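Because predict() is called with stream=True, prediction is a lazy generator: no inference runs until the results are iterated, which the mode-specific branches below (and save_yolo_bounding_boxes_to_txt) do. A minimal consumption sketch:

    # Iterating the generator triggers inference one image/frame at a time.
    for result in prediction:
        print(result.path, len(result.boxes))   # boxes may be empty for a frame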
328 # Save bounding boxes | 325 # Save bounding boxes |
329 def save_yolo_bounding_boxes_to_txt(predictions, save_dir): | 326 def save_yolo_bounding_boxes_to_txt(predictions, save_dir): |
330 """ | 327 """ |
331 Function to save YOLO bounding boxes to text files. | 328 Function to save YOLO bounding boxes to text files. |
| 329 |
332 Parameters: | 330 Parameters: |
333 - predictions: List of results from YOLO model inference. | 331 - predictions: List of results from YOLO model inference. |
334 - save_dir: Directory where the text files will be saved. | 332 - save_dir: Directory where the text files will be saved. |
335 """ | 333 """ |
336 for result in predictions: | 334 for result in predictions: |
337 result = result.to("cpu").numpy() | 335 result = result.to("cpu").numpy() |
338 # Extract bounding boxes, confidence scores, and class indices from the result | 336 # Extract bounding boxes, confidence scores, and class indices from the result |
339 bounding_boxes = result.boxes.xyxy # Bounding boxes in xyxy format | 337 bounding_boxes = result.boxes.xyxy # Bounding boxes in xyxy format |
340 confidence_scores = result.boxes.conf # Confidence scores | 338 confidence_scores = result.boxes.conf # Confidence scores |
341 class_nums = result.boxes.cls # Class numbers | 339 class_nums = result.boxes.cls # Class numbers |
| 340 |
342 # Create save directory if it doesn't exist | 341 # Create save directory if it doesn't exist |
343 save_path = pathlib.Path(save_dir).absolute() | 342 save_path = pathlib.Path(save_dir).absolute() |
344 save_path.mkdir(parents=True, exist_ok=True) | 343 save_path.mkdir(parents=True, exist_ok=True) |
| 344 |
345 # Construct filename for the text file | 345 # Construct filename for the text file |
346 image_filename = pathlib.Path(result.path).stem | 346 image_filename = pathlib.Path(result.path).stem |
347 text_filename = save_path / f"{image_filename}.txt" | 347 text_filename = save_path / f"{image_filename}.txt" |
| 348 |
348 # Write bounding boxes info into the text file | 349 # Write bounding boxes info into the text file |
349 with open(text_filename, 'w') as f: | 350 with open(text_filename, 'w') as f: |
350 for i in range(bounding_boxes.shape[0]): | 351 for i in range(bounding_boxes.shape[0]): |
351 x1, y1, x2, y2 = bounding_boxes[i] | 352 x1, y1, x2, y2 = bounding_boxes[i] |
352 confidence = confidence_scores[i] | 353 confidence = confidence_scores[i] |
353 class_num = int(class_nums[i]) | 354 class_num = int(class_nums[i]) |
354 f.write(f'{class_num:01} {x1:06.2f} {y1:06.2f} {x2:06.2f} {y2:06.2f} {confidence:0.02} \n') | 355 f.write(f'{class_num:01} {x1:06.2f} {y1:06.2f} {x2:06.2f} {y2:06.2f} {confidence:0.02} \n') |
355 print(colored(f"Bounding boxes saved in: {text_filename}", 'green')) | 356 print(colored(f"Bounding boxes saved in: {text_filename}", 'green')) |
356 | 357 |
357 | 358 |
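Each emitted text file holds one detection per line in the order written above: class, x1, y1, x2, y2, confidence. A sketch of parsing such a file back, with "image1.txt" as a placeholder name:

    # Parse a bounding-box file written by save_yolo_bounding_boxes_to_txt().
    with open("image1.txt") as f:
        for line in f:
            cls, x1, y1, x2, y2, conf = line.split()
            print(int(cls), [float(v) for v in (x1, y1, x2, y2)], float(conf))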
| 359 # Main code |
358 if __name__ == '__main__': | 360 if __name__ == '__main__': |
359 args = parser.parse_args() | 361 args = parser.parse_args() |
360 os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" | 362 os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" |
| 363 |
361 # Train/load model | 364 # Train/load model |
362 if (args.train): | 365 if (args.train): |
363 model = trainModel(args.model_path, args.model_name, args.yaml_path, | 366 model = trainModel(args.model_path, args.model_name, args.yaml_path, |
364 imgsz=args.image_size, epochs=args.epochs, | 367 imgsz=args.image_size, epochs=args.epochs, |
365 hsv_h=args.hsv_h, hsv_s=args.hsv_s, hsv_v=args.hsv_v, | 368 hsv_h=args.hsv_h, hsv_s=args.hsv_s, hsv_v=args.hsv_v, |
375 "train", "weights", "best.pt")) and (args.model_name == 'sam'): | 378 "train", "weights", "best.pt")) and (args.model_name == 'sam'): |
376 model = YOLO(os.path.join(train_save_path, | 379 model = YOLO(os.path.join(train_save_path, |
377 "train", "weights", "best.pt")) | 380 "train", "weights", "best.pt")) |
378 else: | 381 else: |
379 model = YOLO(os.path.join(args.model_path, | 382 model = YOLO(os.path.join(args.model_path, |
380 args.model_name + ".pt")) | 383 args.model_name + ".pt")) |
381 model.info(verbose=True) | 384 model.info(verbose=True) |
382 elapsed = time.time() - t | 385 elapsed = time.time() - t |
383 print(colored(f"\nYOLO model loaded in : '{elapsed}' sec \n", 'white', 'on_yellow')) | 386 print(colored(f"\nYOLO model loaded in : '{elapsed}' sec \n", 'white', 'on_yellow')) |
384 | 387 |
385 if (args.save_dir): | 388 if (args.save_dir): |
420 elif (args.mode == "track"): | 423 elif (args.mode == "track"): |
421 results = model.track(source=datapath_for_prediction, | 424 results = model.track(source=datapath_for_prediction, |
422 tracker=args.tracker_file, | 425 tracker=args.tracker_file, |
423 conf=args.confidence, | 426 conf=args.confidence, |
424 iou=args.iou, | 427 iou=args.iou, |
425 persist=False, | 428 persist=True, |
426 show=True, | 429 show=False, |
427 save=True, | 430 save=True, |
428 project=args.run_dir, | 431 project=args.run_dir, |
429 name=args.foldername) | 432 name=args.foldername) |
430 # Store the track history | 433 # Store the track history |
431 track_history = defaultdict(lambda: []) | 434 track_history = defaultdict(lambda: []) |
432 | 435 |
433 for result in results: | 436 tsv_path = os.path.join(args.save_dir, "tracks.tsv") |
434 # Get the boxes and track IDs | 437 with open(tsv_path, "w", newline="") as tsvfile: |
435 if result.boxes and result.boxes.is_track: | 438 writer = csv.writer(tsvfile, delimiter='\t') |
436 boxes = result.boxes.xywh.cpu() | 439 writer.writerow(['track_id', 'frame', 'class', 'centroid_x', 'centroid_y']) |
437 track_ids = result.boxes.id.int().cpu().tolist() | 440 frame_idx = 0 |
438 # Visualize the result on the frame | 441 for result in results: |
439 frame = result.plot() | 442 # Get the boxes and track IDs |
440 # Plot the tracks | 443 if result.boxes and result.boxes.is_track: |
441 for box, track_id in zip(boxes, track_ids): | 444 track_ids = result.boxes.id.int().cpu().tolist() |
442 x, y, w, h = box | 445 labels = result.boxes.cls.int().cpu().tolist() if hasattr(result.boxes, "cls") else [0] * len(track_ids) |
443 track = track_history[track_id] | 446 # Prepare mask image |
444 track.append((float(x), float(y))) # x, y center point | 447 img_shape = result.orig_shape if hasattr(result, "orig_shape") else result.orig_img.shape |
445 if len(track) > 30: # retain 30 tracks for 30 frames | 448 mask = np.zeros(img_shape[:2], dtype=np.uint16) |
446 track.pop(0) | 449 # Check if polygons (masks) are available |
447 | 450 if hasattr(result, "masks") and result.masks is not None and hasattr(result.masks, "xy"): |
448 # Draw the tracking lines | 451 polygons = result.masks.xy |
449 points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2)) | 452 for i, (track_id, label) in enumerate(zip(track_ids, labels)): |
450 cv2.polylines(frame, [points], isClosed=False, color=(230, 230, 230), thickness=2) | 453 if i < len(polygons): |
451 | 454 contour = polygons[i].astype(np.int32) |
452 # Display the annotated frame | 455 contour = contour.reshape(-1, 1, 2) |
453 cv2.imshow("YOLO11 Tracking", frame) | 456 cv2.drawContours(mask, [contour], -1, int(track_id), cv2.FILLED) |
454 print(colored(f"Tracking results saved in : '{args.save_dir}' \n", 'green')) | 457 # Calculate centroid of the polygon |
| 458 M = cv2.moments(contour) |
| 459 if M["m00"] != 0: |
| 460 cx = float(M["m10"] / M["m00"]) |
| 461 cy = float(M["m01"] / M["m00"]) |
| 462 else: |
| 463 cx, cy = 0.0, 0.0 |
| 464 writer.writerow([track_id, frame_idx, label, cx, cy]) |
| 465 else: |
| 466 # Fallback to bounding boxes if polygons are not available |
| 467 boxes = result.boxes.xywh.cpu() |
| 468 xyxy_boxes = result.boxes.xyxy.cpu().numpy() |
| 469 for i, (box, xyxy, track_id, label) in enumerate(zip(boxes, xyxy_boxes, track_ids, labels)): |
| 470 x, y, w, h = box |
| 471 writer.writerow([track_id, frame_idx, label, float(x), float(y)]) |
| 472 x1, y1, x2, y2 = map(int, xyxy) |
| 473 cv2.rectangle(mask, (x1, y1), (x2, y2), int(track_id), thickness=-1) |
| 474 # Collect masks for TYX stack |
| 475 if frame_idx == 0: |
| 476 mask_stack = [] |
| 477 mask_stack.append(mask) |
| 478 frame_idx += 1 |
| 479 # Save TYX stack (T=frames, Y, X) |
| 480 if 'mask_stack' in locals() and len(mask_stack) > 0: |
| 481 tyx_array = np.stack(mask_stack, axis=0) |
| 482 # Keep only the filename part before the last underscore |
| 483 stem = pathlib.Path(result.path).stem |
| 484 stem = stem.rsplit('_', 1)[0] if '_' in stem else stem |
| 485 mask_save_as = str(pathlib.Path(os.path.join(args.save_dir, stem + "_mask.tiff")).absolute()) |
| 486 imwrite(mask_save_as, tyx_array) |
| 487 print(colored(f"TYX mask stack saved as : '{mask_save_as}'", 'magenta')) |
| 488 print(colored(f"Tracking results saved in : '{args.save_dir}' \n", 'green')) |
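Each centroid written to tracks.tsv comes from the contour's image moments (cx = m10/m00, cy = m01/m00), falling back to the box centre when no polygon is available. A sketch of regrouping the TSV into one trajectory per track id, with the path as a placeholder:

    import csv
    from collections import defaultdict

    # Rebuild per-track trajectories from the tracks.tsv written above.
    trajectories = defaultdict(list)
    with open("tracks.tsv") as f:
        for row in csv.DictReader(f, delimiter='\t'):
            trajectories[int(row["track_id"])].append(
                (int(row["frame"]), float(row["centroid_x"]), float(row["centroid_y"])))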
455 elif (args.mode == "segment"): | 489 elif (args.mode == "segment"): |
456 # Read class names from the file | 490 # Read class names from the file |
457 with open(args.class_names_file, 'r') as f: | 491 with open(args.class_names_file, 'r') as f: |
458 class_names = [line.strip() for line in f.readlines()] | 492 class_names = [line.strip() for line in f.readlines()] |
| 493 # Create a mapping from class names to indices |
459 class_to_index = {class_name: i for i, class_name in enumerate(class_names)} | 494 class_to_index = {class_name: i for i, class_name in enumerate(class_names)} |
460 | 495 |
461 # Save polygon coordinates | 496 # Save polygon coordinates |
462 for result in predictions: | 497 for result in predictions: |
| 498 # Create the instance label mask |
463 img = np.copy(result.orig_img) | 499 img = np.copy(result.orig_img) |
464 filename = pathlib.Path(result.path).stem | 500 filename = pathlib.Path(result.path).stem |
465 b_mask = np.zeros(img.shape[:2], np.uint8) | 501 b_mask = np.zeros(img.shape[:2], np.uint8) |
466 mask_save_as = str(pathlib.Path(os.path.join(args.save_dir, filename + "_mask.tiff")).absolute()) | 502 mask_save_as = str(pathlib.Path(os.path.join(args.save_dir, filename + "_mask.tiff")).absolute()) |
| 503 # Define output file path for text file |
| 504 output_filename = os.path.splitext(filename)[0] + ".txt" |
467 txt_save_as = str(pathlib.Path(os.path.join(args.save_dir, filename + ".txt")).absolute()) | 505 txt_save_as = str(pathlib.Path(os.path.join(args.save_dir, filename + ".txt")).absolute()) |
468 | 506 instance_id = 1 # Start instance IDs from 1 |
469 for c, ci in enumerate(result): | 507 for c, ci in enumerate(result): |
470 if ci.masks is not None and ci.masks.xy: | 508 # Extract contour result |
471 # Extract contour | 509 contour = ci.masks.xy.pop() |
472 contour = ci.masks.xy.pop() | 510 contour = contour.astype(np.int32) |
473 contour = contour.astype(np.int32).reshape(-1, 1, 2) | 511 contour = contour.reshape(-1, 1, 2) |
474 _ = cv2.drawContours(b_mask, [contour], -1, (255, 255, 255), cv2.FILLED) | 512 # Draw contour onto mask with unique instance id |
475 | 513 _ = cv2.drawContours(b_mask, [contour], -1, instance_id, cv2.FILLED) |
476 # Normalized polygon points | 514 |
477 points = ci.masks.xyn.pop() | 515 # Normalized polygon points |
478 obj_class = int(ci.boxes.cls.to("cpu").numpy().item()) | 516 points = ci.masks.xyn.pop() |
479 confidence = result.boxes.conf.to("cpu").numpy()[c] | 517 confidence = result.boxes.conf.to("cpu").numpy()[c] |
480 | 518 |
481 with open(txt_save_as, 'a') as f: | 519 with open(txt_save_as, 'a') as f: |
482 segmentation_points = ['{} {}'.format(points[i][0], points[i][1]) for i in range(len(points))] | 520 segmentation_points = ['{} {}'.format(points[i][0], points[i][1]) for i in range(len(points))] |
483 segmentation_points_string = ' '.join(segmentation_points) | 521 segmentation_points_string = ' '.join(segmentation_points) |
484 line = '{} {} {}\n'.format(obj_class, segmentation_points_string, confidence) | 522 line = '{} {} {}\n'.format(instance_id, segmentation_points_string, confidence) |
485 f.write(line) | 523 f.write(line) |
486 else: | 524 |
487 print(colored(f"⚠️ No mask found for object {c} in '{filename}'. Skipping.", "yellow")) | 525 instance_id += 1 # Increment for next object |
488 | 526 |
489 # Overlay mask onto original image | 527 imwrite(mask_save_as, b_mask, imagej=True) # save label mask image |
490 colored_mask = cv2.merge([b_mask, np.zeros_like(b_mask), np.zeros_like(b_mask)]) | 528 print(colored(f"Saved label mask as : \n '{mask_save_as}' \n", 'magenta')) |
491 blended = cv2.addWeighted(img, 1.0, colored_mask, 0.5, 0) | |
492 overlay_path = os.path.join(args.save_dir, filename + "_overlay.jpg") | |
493 cv2.imwrite(overlay_path, blended) | |
494 | |
495 imwrite(mask_save_as, b_mask, imagej=True) | |
496 print(colored(f"Saved binary mask as : \n '{mask_save_as}' \n", 'magenta')) | |
497 print(colored(f"Polygon coordinates saved as : \n '{txt_save_as}' \n", 'cyan')) | 529 print(colored(f"Polygon coordinates saved as : \n '{txt_save_as}' \n", 'cyan')) |
| 530 else: |
| 531 raise Exception("Currently only 'detect', 'segment' and 'track' modes are available") |
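In the segment branch each object is drawn into the mask with its own instance id, so the saved TIFF is a label image rather than a strictly binary one; since b_mask is uint8, at most 255 instances per image are distinguishable. A sketch of inspecting such a mask, assuming tifffile and a placeholder filename:

    import numpy as np
    from tifffile import imread

    # Count the instances in a label mask written by the segment branch.
    mask = imread("image1_mask.tiff")
    ids = np.unique(mask)
    print("instance ids:", ids[ids != 0])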