comparison yolov8.py @ 0:252fd085940d draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools commit 67e0e1d123bcfffb10bab8cc04ae67259caec557
author bgruening
date Fri, 13 Jun 2025 11:23:35 +0000
parents
children dfda27273ead
comparison
equal deleted inserted replaced
-1:000000000000 0:252fd085940d
1 import argparse
2 import os
3 import pathlib
4 import shutil
5 import time
6 from argparse import RawTextHelpFormatter
7 from collections import defaultdict
8
9 import cv2
10 import numpy as np
11 from termcolor import colored
12 from tifffile import imwrite
13 from ultralytics import YOLO
14
15
#
# Input arguments
#
# Command-line interface shared by the training and prediction paths.
# RawTextHelpFormatter keeps the explicit "\n" line breaks used in the
# epilog and help strings below.
parser = argparse.ArgumentParser(
    description='train/predict dataset with YOLOv8',
    epilog=(
        "USAGE EXAMPLE:\n\n"
        "~~~~Prediction~~~~\n"
        "python yolov8.py --test_path=/g/group/user/data --model_path=/g/cba/models "
        "--model_name=yolov8n --save_dir=/g/group/user/results --iou=0.7 "
        "--confidence=0.5 --image_size=320 --run_dir=/g/group/user/runs "
        "--foldername=batch --headless --num_classes=1 --max_det=1 "
        "--class_names_file=/g/group/user/class_names.txt\n"
        "\n~~~~Training~~~~ \n"
        "python yolov8.py --train --yaml_path=/g/group/user/example.yaml "
        "--model_path=/g/cba/models --model_name=yolov8n "
        "--run_dir=/g/group/user/runs/ --image_size=320 --epochs=150 --scale=0.3 "
        "--hsv_v=0.5 --model_format=pt --degrees=180 "
        "--class_names_file=/g/group/user/class_names.txt"
    ),
    formatter_class=RawTextHelpFormatter)
parser.add_argument("--dir_path",
                    help=(
                        "Path to the training data directory."
                    ),
                    type=str)
parser.add_argument("--yaml_path",
                    help=(
                        "YAML file with all the data paths"
                        " i.e. for train, test, valid data."
                    ),
                    type=str)
parser.add_argument("--test_path",
                    help=(
                        "Path to the prediction folder."
                    ),
                    type=str)
parser.add_argument("--save_dir",
                    help=(
                        "Path to the directory where bounding boxes text files"
                        " would be saved."
                    ),
                    type=str)
parser.add_argument("--run_dir",
                    help=(
                        "Path where overlaid images would be saved. "
                        "For example: `RUN_DIR=projectName/results`. "
                        "This should exist."
                    ),
                    type=str)
parser.add_argument("--foldername",
                    help=("Folder to save overlaid images.\n"
                          "For example: FOLDERNAME=batch.\n"
                          "This should not exist as a new folder named `batch`\n"
                          " will be created in RUN_DIR.\n"
                          " If it exists already then, a new folder named `batch1`\n"
                          " will be created automatically as it does not overwrite\n"
                          ),
                    type=str)

# For selecting and loading model
parser.add_argument("--model_name",
                    help=("Models for task `detect` can be seen here:\n"
                          "https://docs.ultralytics.com/tasks/detect/#models \n\n"
                          "Models for task `segment` can be seen here:\n"
                          "https://docs.ultralytics.com/tasks/segment/#models \n\n"
                          "Use `yolov8n` for `detect` tasks. "
                          "For custom model, use `best`"
                          ),
                    default='yolov8n', type=str)
parser.add_argument("--model_path",
                    help="Full absolute path to the model directory",
                    type=str)
parser.add_argument("--model_format",
                    help="Format of the YOLO model i.e pt, yaml etc.",
                    default='pt', type=str)
parser.add_argument("--class_names_file",
                    help="Path to the text file containing class names.",
                    type=str)

# For training the model and prediction
parser.add_argument("--mode",
                    help=(
                        "Task to run on the prediction data: \n"
                        "`detect`, `segment` or `track`. Default: `detect`"
                    ), default='detect', type=str)
parser.add_argument('--train',
                    help="Do training",
                    action='store_true')
parser.add_argument("--confidence",
                    help="Confidence value (0-1) for each detected bounding box",
                    default=0.5, type=float)
parser.add_argument("--epochs",
                    help="Number of epochs for training. Default: 100",
                    default=100, type=int)
parser.add_argument("--init_lr",
                    help="Initial learning rate for training. Default: 0.01",
                    default=0.01, type=float)
parser.add_argument("--weight_decay",
                    help="Weight decay used during training. Default: 0.0005",
                    default=0.0005, type=float)

parser.add_argument("--num_classes",
                    help="Number of classes to be predicted. Default: 2",
                    default=2, type=int)
parser.add_argument("--iou",
                    help="Intersection over union (IoU) threshold for NMS",
                    default=0.7, type=float)
parser.add_argument("--image_size",
                    help=("Size of input image to be used only as integer of w,h. \n"
                          "For training choose <= 1000. \n\n"
                          "Prediction will be done on original image size"
                          ),
                    default=320, type=int)
parser.add_argument("--max_det",
                    help=("Maximum number of detections allowed per image. \n"
                          "Limits the total number of objects the model can detect in a single inference, \n"
                          "preventing excessive outputs in dense scenes.\n\n"
                          ),
                    default=300, type=int)

# For tracking
parser.add_argument("--tracker_file",
                    help=("Path to the configuration file of the tracker used. \n"),
                    default='bytetrack.yaml', type=str)

# For headless operation
parser.add_argument('--headless', action='store_true')
parser.add_argument('--nextflow', action='store_true')

# For data augmentation
parser.add_argument("--hsv_h",
                    help="(float) image HSV-Hue augmentation (fraction)",
                    default=0.015, type=float)
parser.add_argument("--hsv_s",
                    help="(float) image HSV-Saturation augmentation (fraction)",
                    default=0.7, type=float)
parser.add_argument("--hsv_v",
                    help="(float) image HSV-Value augmentation (fraction)",
                    default=0.4, type=float)
parser.add_argument("--degrees",
                    help="(float) image rotation (+/- deg)",
                    default=0.0, type=float)
parser.add_argument("--translate",
                    help="(float) image translation (+/- fraction)",
                    default=0.1, type=float)
parser.add_argument("--scale",
                    help="(float) image scale (+/- gain)",
                    default=0.5, type=float)
parser.add_argument("--shear",
                    help="(float) image shear (+/- deg)",
                    default=0.0, type=float)
parser.add_argument("--perspective",
                    help="(float) image perspective (+/- fraction), range 0-0.001",
                    default=0.0, type=float)
parser.add_argument("--flipud",
                    help="(float) image flip up-down (probability)",
                    default=0.0, type=float)
parser.add_argument("--fliplr",
                    help="(float) image flip left-right (probability)",
                    default=0.5, type=float)
parser.add_argument("--mosaic",
                    help="(float) image mosaic (probability)",
                    default=1.0, type=float)
parser.add_argument("--crop_fraction",
                    help="(float) crops image to a fraction of its size to "
                         "emphasize central features and adapt to object scales, "
                         "reducing background distractions",
                    default=1.0, type=float)
173
174
175 #
176 # Functions
177 #
178 # Train a new model on the dataset mentioned in yaml file
def trainModel(model_path, model_name, yaml_filepath, **kwargs):
    """Train a YOLO model on the dataset described by a YAML file.

    Besides its parameters, this function reads the module-level ``args``
    namespace (``args.mode``, ``args.model_format``, ``args.run_dir``).
    Any existing ``~/runs/<mode>/train/`` directory is deleted before
    training starts.

    Parameters:
    - model_path: Directory holding the model file; only used when
      ``args.model_format`` is ``'pt'``.
    - model_name: Model file name without extension.
    - yaml_filepath: Dataset YAML passed to ``model.train`` as ``data``.
    - **kwargs: Optional training settings. Recognised keys are ``imgsz``,
      ``epochs``, ``hsv_h``, ``hsv_s``, ``hsv_v``, ``degrees``,
      ``translate``, ``scale``, ``shear``, ``perspective``, ``fliplr``,
      ``flipud``, ``mosaic``, ``crop_fraction``, ``weight_decay`` and
      ``init_lr`` (forwarded as ``lr0``). Other keys are ignored.

    Returns:
    - The ``YOLO`` model object after ``model.train`` has run.
    """
    # Fallback used for each setting the caller does not supply.
    # weight_decay / init_lr match the defaults of the corresponding
    # command-line options (0.0005 and 0.01).
    fallbacks = {
        'imgsz': 320,
        'epochs': 100,
        'hsv_h': 0.015,
        'hsv_s': 0.7,
        'hsv_v': 0.4,
        'degrees': 10.0,
        'translate': 0.1,
        'scale': 0.2,
        'shear': 0.0,
        'perspective': 0.0,
        'fliplr': 0.5,
        'flipud': 0.0,
        'mosaic': 1.0,
        'crop_fraction': 1.0,
        'weight_decay': 0.0005,
        'init_lr': 0.01,
    }
    train_settings = {key: kwargs.get(key, fallback)
                      for key, fallback in fallbacks.items()}
    # model.train expects the initial learning rate under the name `lr0`.
    train_settings['lr0'] = train_settings.pop('init_lr')

    # Remove the previous training output for this mode, if any.
    train_save_path = os.path.expanduser('~/runs/' + args.mode + '/train/')
    if os.path.isdir(train_save_path):
        shutil.rmtree(train_save_path)

    # Load a pretrained YOLO model (recommended for training)
    if args.model_format == 'pt':
        model = YOLO(os.path.join(model_path, model_name + "." + args.model_format))
    else:
        model = YOLO(model_name + "." + args.model_format)

    model.train(data=yaml_filepath, project=args.run_dir, verbose=True,
                seed=42, **train_settings)
    return model
281
282
283 # Validate the trained model
def validateModel(model):
    """Validate a trained model and return the resulting metrics.

    Reads the module-level ``args.mode`` and deletes any existing
    ``~/runs/<mode>/val/`` directory before validating.

    Parameters:
    - model: Trained model object exposing ``val()``.

    Returns:
    - The object returned by ``model.val()``. Its ``box`` attribute carries
      ``map`` (map50-95), ``map50``, ``map75`` and ``maps`` (a list with the
      map50-95 of each category).
    """
    # Remove validation save path if already exists
    val_save_path = os.path.expanduser('~/runs/' + args.mode + '/val/')
    if os.path.isdir(val_save_path):
        shutil.rmtree(val_save_path)
    # No arguments needed: dataset and settings are remembered by the model.
    metrics = model.val()
    return metrics
295
296
297 # Do predictions on images/videos using trained/loaded model
def predict(model, source_datapath, **kwargs):
    """Run prediction on images/videos with a trained or loaded model.

    Parameters:
    - model: Model object exposing ``predict(...)``.
    - source_datapath: Local image, or directory containing images/videos.
    - **kwargs: Optional settings:
        ``imgsz`` (default 320), ``conf`` (default 0.5), ``iou``
        (default 0.5), ``num_classes`` (classes ``0..num_classes-1`` are
        kept; default classes are ``[0, 1]``), ``max_det`` (default 300),
        ``run_dir`` (passed as ``project``) and ``foldername`` (passed as
        ``name``).

    When ``run_dir`` is not supplied, the module-level ``args.mode`` is read
    and any existing ``~/runs/<mode>/predict/`` directory is deleted; the
    model is then called with ``project=None``. When ``foldername`` is not
    supplied the model is called with ``name=None``.

    Returns:
    - Whatever ``model.predict`` returns (it is called with ``stream=True``).
    """
    image_size = kwargs.get('imgsz', 320)
    confidence = kwargs.get('conf', 0.5)
    iou_value = kwargs.get('iou', 0.5)
    # Use the value handed to this function, not the global CLI namespace.
    maximum_detections = kwargs.get('max_det', 300)
    save_folder_name = kwargs.get('foldername')

    if "num_classes" in kwargs:
        class_array = list(range(kwargs['num_classes']))
    else:
        class_array = [0, 1]

    if "run_dir" in kwargs:
        run_save_dir = kwargs['run_dir']
    else:
        run_save_dir = None
        # Remove prediction save path if already exists
        pred_save_path = os.path.expanduser('~/runs/' + args.mode + '/predict/')
        if os.path.isdir(pred_save_path):
            shutil.rmtree(pred_save_path)

    # infer on a local image or directory containing images/videos
    prediction = model.predict(source=source_datapath, save=True, stream=True,
                               conf=confidence, imgsz=image_size,
                               save_conf=True, iou=iou_value, max_det=maximum_detections,
                               classes=class_array, save_txt=False,
                               project=run_save_dir, name=save_folder_name, verbose=True)
    return prediction
330 shutil.rmtree(pred_save_path)
331 if "foldername" in kwargs:
332 save_folder_name = kwargs['foldername']
333 # infer on a local image or directory containing images/videos
334 prediction = model.predict(source=source_datapath, save=True, stream=True,
335 conf=confidence, imgsz=image_size,
336 save_conf=True, iou=iou_value, max_det=maximum_detections,
337 classes=class_array, save_txt=False,
338 project=run_save_dir, name=save_folder_name, verbose=True)
339 return prediction
340
341
342 # Save bounding boxes
def save_yolo_bounding_boxes_to_txt(predictions, save_dir):
    """
    Write one text file of bounding boxes per prediction result.

    Each file is named after the stem of the result's source path and holds
    one line per box: class number, x1, y1, x2, y2 (xyxy format) and the
    confidence score.

    Parameters:
    - predictions: Iterable of results from YOLO model inference.
    - save_dir: Directory where the text files will be saved; it is created
      (including parents) if it doesn't exist.
    """
    for prediction in predictions:
        prediction = prediction.to("cpu").numpy()
        detections = prediction.boxes
        # xyxy corner coordinates, confidence scores and class numbers
        corners, scores, labels = detections.xyxy, detections.conf, detections.cls

        # Make sure the target directory exists before writing
        target_dir = pathlib.Path(save_dir).absolute()
        target_dir.mkdir(parents=True, exist_ok=True)

        # The text file shares the stem of the image it describes
        stem = pathlib.Path(prediction.path).stem
        text_filename = target_dir / f"{stem}.txt"

        with open(text_filename, 'w') as handle:
            for (x1, y1, x2, y2), confidence, label in zip(corners, scores, labels):
                class_num = int(label)
                handle.write(f'{class_num:01} {x1:06.2f} {y1:06.2f} {x2:06.2f} {y2:06.2f} {confidence:0.02} \n')
        print(colored(f"Bounding boxes saved in: {text_filename}", 'green'))
360 image_filename = pathlib.Path(result.path).stem
361 text_filename = save_path / f"{image_filename}.txt"
362 # Write bounding boxes info into the text file
363 with open(text_filename, 'w') as f:
364 for i in range(bounding_boxes.shape[0]):
365 x1, y1, x2, y2 = bounding_boxes[i]
366 confidence = confidence_scores[i]
367 class_num = int(class_nums[i])
368 f.write(f'{class_num:01} {x1:06.2f} {y1:06.2f} {x2:06.2f} {y2:06.2f} {confidence:0.02} \n')
369 print(colored(f"Bounding boxes saved in: {text_filename}", 'green'))
370
371
if __name__ == '__main__':
    # `args` must stay a module-level name: trainModel, validateModel and
    # predict read it as a global.
    args = parser.parse_args()
    os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

    # Train/load model
    if args.train:
        # Forward every training option exposed on the command line,
        # including the optimizer settings and crop fraction.
        model = trainModel(args.model_path, args.model_name, args.yaml_path,
                           imgsz=args.image_size, epochs=args.epochs,
                           hsv_h=args.hsv_h, hsv_s=args.hsv_s, hsv_v=args.hsv_v,
                           degrees=args.degrees, translate=args.translate,
                           shear=args.shear, scale=args.scale,
                           perspective=args.perspective, fliplr=args.fliplr,
                           flipud=args.flipud, mosaic=args.mosaic,
                           crop_fraction=args.crop_fraction,
                           weight_decay=args.weight_decay,
                           init_lr=args.init_lr)
        validateModel(model)
    else:
        t = time.time()
        train_save_path = os.path.expanduser('~/runs/' + args.mode + '/')
        best_weights = os.path.join(train_save_path, "train", "weights", "best.pt")
        if os.path.isfile(best_weights) and (args.model_name == 'sam'):
            model = YOLO(best_weights)
        else:
            model = YOLO(os.path.join(args.model_path,
                                      args.model_name + ".pt"))
        model.info(verbose=True)
        elapsed = time.time() - t
        print(colored(f"\nYOLO model loaded in : '{elapsed}' sec \n", 'white', 'on_yellow'))

    if args.save_dir:
        # Do predictions (optionally show image results with bounding boxes)
        t = time.time()
        datapath_for_prediction = args.test_path
        # Extracting class names from the model
        class_names = model.names
        predictions = predict(model, datapath_for_prediction,
                              imgsz=args.image_size, conf=args.confidence,
                              iou=args.iou, run_dir=args.run_dir,
                              foldername=args.foldername, num_classes=args.num_classes, max_det=args.max_det)
        elapsed = time.time() - t
        print(colored(f"\nYOLO prediction done in : '{elapsed}' sec \n", 'white', 'on_cyan'))

        if args.mode == "detect":
            # Save bounding boxes
            save_yolo_bounding_boxes_to_txt(predictions, args.save_dir)
        elif args.mode == "track":
            # Only open display windows when not running headless.
            show_windows = not args.headless
            results = model.track(source=datapath_for_prediction,
                                  tracker=args.tracker_file,
                                  conf=args.confidence,
                                  iou=args.iou,
                                  persist=False,
                                  show=show_windows,
                                  save=True,
                                  project=args.run_dir,
                                  name=args.foldername)
            # Store the track history (track id -> list of centre points)
            track_history = defaultdict(list)

            for result in results:
                # Get the boxes and track IDs
                if result.boxes and result.boxes.is_track:
                    boxes = result.boxes.xywh.cpu()
                    track_ids = result.boxes.id.int().cpu().tolist()
                    # Visualize the result on the frame
                    frame = result.plot()
                    # Plot the tracks
                    for box, track_id in zip(boxes, track_ids):
                        x, y, _w, _h = box
                        track = track_history[track_id]
                        track.append((float(x), float(y)))  # x, y center point
                        if len(track) > 30:  # keep only the last 30 points
                            track.pop(0)

                        # Draw the tracking lines
                        points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
                        cv2.polylines(frame, [points], isClosed=False, color=(230, 230, 230), thickness=2)

                    if show_windows:
                        # Display the annotated frame; waitKey lets the GUI
                        # event loop actually draw the window.
                        cv2.imshow("YOLO11 Tracking", frame)
                        cv2.waitKey(1)
            print(colored(f"Tracking results saved in : '{args.save_dir}' \n", 'green'))
        elif args.mode == "segment":
            # The class names file is optional here: the outputs below are
            # written with numeric class ids, the names are not consumed.
            if args.class_names_file:
                with open(args.class_names_file, 'r') as f:
                    class_names = [line.strip() for line in f]

            # Make sure the output directory exists before writing into it.
            save_root = pathlib.Path(args.save_dir).absolute()
            save_root.mkdir(parents=True, exist_ok=True)

            # Save polygon coordinates
            for result in predictions:
                filename = pathlib.Path(result.path).stem
                # Binary mask with the height/width of the original image
                b_mask = np.zeros(result.orig_img.shape[:2], np.uint8)
                mask_save_as = str(save_root / (filename + "_mask.tiff"))
                # Output path of the polygon text file
                txt_save_as = str(save_root / (filename + ".txt"))

                polygon_lines = []
                for c, ci in enumerate(result):
                    # Extract contour result as int32 points shaped (-1, 1, 2)
                    contour = ci.masks.xy.pop()
                    contour = contour.astype(np.int32).reshape(-1, 1, 2)
                    # Draw filled contour onto mask
                    _ = cv2.drawContours(b_mask, [contour], -1, (255, 255, 255), cv2.FILLED)

                    # Normalized polygon points
                    points = ci.masks.xyn.pop()
                    obj_class = int(ci.boxes.cls.to("cpu").numpy().item())
                    confidence = result.boxes.conf.to("cpu").numpy()[c]

                    segmentation_points_string = ' '.join(
                        '{} {}'.format(px, py) for px, py in points)
                    polygon_lines.append(
                        '{} {} {}\n'.format(obj_class, segmentation_points_string, confidence))

                # Append all polygons of this image in one go; nothing is
                # written when the image has no instances.
                if polygon_lines:
                    with open(txt_save_as, 'a') as f:
                        f.writelines(polygon_lines)

                imwrite(mask_save_as, b_mask, imagej=True)  # save image
                print(colored(f"Saved cropped image as : \n '{mask_save_as}' \n", 'magenta'))
                print(colored(f"Polygon coordinates saved as : \n '{txt_save_as}' \n", 'cyan'))

        else:
            raise Exception(("Currently only 'detect', 'track' and 'segment' modes are available"))
485 segmentation_points_string = ' '.join(segmentation_points)
486 line = '{} {} {}\n'.format(obj_class, segmentation_points_string, confidence)
487 f.write(line)
488
489 imwrite(mask_save_as, b_mask, imagej=True) # save image
490 print(colored(f"Saved cropped image as : \n '{mask_save_as}' \n", 'magenta'))
491 print(colored(f"Polygon coordinates saved as : \n '{txt_save_as}' \n", 'cyan'))
492
493 else:
494 raise Exception(("Currently only 'detect' and 'segment' modes are available"))