yolov8.py @ 0:252fd085940d draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools commit 67e0e1d123bcfffb10bab8cc04ae67259caec557
author: bgruening
date: Fri, 13 Jun 2025 11:23:35 +0000
parents: (none)
children: dfda27273ead
import argparse
import os
import pathlib
import shutil
import time
from argparse import RawTextHelpFormatter
from collections import defaultdict

import cv2
import numpy as np
from termcolor import colored
from tifffile import imwrite
from ultralytics import YOLO


#
# Input arguments
#
parser = argparse.ArgumentParser(
    description='train/predict dataset with YOLOv8',
    epilog="""USAGE EXAMPLE:\n\n~~~~Prediction~~~~\n\
python yolov8.py --test_path=/g/group/user/data --model_path=/g/cba/models --model_name=yolov8n --save_dir=/g/group/user/results --iou=0.7 --confidence=0.5 --image_size=320 --run_dir=/g/group/user/runs --foldername=batch --headless --num_classes=1 --max_det=1 --class_names_file=/g/group/user/class_names.txt\n\
\n~~~~Training~~~~ \n\
python yolov8.py --train --yaml_path=/g/group/user/example.yaml --model_path=/g/cba/models --model_name=yolov8n --run_dir=/g/group/user/runs/ --image_size=320 --epochs=150 --scale=0.3 --hsv_v=0.5 --model_format=pt --degrees=180 --class_names_file=/g/group/user/class_names.txt""", formatter_class=RawTextHelpFormatter)
25 parser.add_argument("--dir_path", | |
26 help=( | |
27 "Path to the training data directory." | |
28 ), | |
29 type=str) | |
30 parser.add_argument("--yaml_path", | |
31 help=( | |
32 "YAML file with all the data paths" | |
33 " i.e. for train, test, valid data." | |
34 ), | |
35 type=str) | |
36 parser.add_argument("--test_path", | |
37 help=( | |
38 "Path to the prediction folder." | |
39 ), | |
40 type=str) | |
41 parser.add_argument("--save_dir", | |
42 help=( | |
43 "Path to the directory where bounding boxes text files" | |
44 " would be saved." | |
45 ), | |
46 type=str) | |
47 parser.add_argument("--run_dir", | |
48 help=( | |
49 "Path where overlaid images would be saved." | |
50 "For example: `RUN_DIR=projectName/results`." | |
51 "This should exist." | |
52 ), | |
53 type=str) | |
54 parser.add_argument("--foldername", | |
55 help=("Folder to save overlaid images.\n" | |
56 "For example: FOLDERNAME=batch.\n" | |
57 "This should not exist as a new folder named `batch`\n" | |
58 " will be created in RUN_DIR.\n" | |
59 " If it exists already then, a new folder named `batch1`\n" | |
60 " will be created automatically as it does not overwrite\n" | |
61 ), | |
62 type=str) | |

# For selecting and loading model
parser.add_argument("--model_name",
                    help=("Models for the `detect` task are listed here:\n"
                          "https://docs.ultralytics.com/tasks/detect/#models \n\n"
                          "Models for the `segment` task are listed here:\n"
                          "https://docs.ultralytics.com/tasks/segment/#models \n\n"
                          "Use `yolov8n` for `detect` tasks. "
                          "For a custom model, use `best`."
                          ),
                    default='yolov8n', type=str)
parser.add_argument("--model_path",
                    help="Full absolute path to the model directory",
                    type=str)
parser.add_argument("--model_format",
                    help="Format of the YOLO model, i.e. pt, yaml, etc.",
                    default='pt', type=str)
parser.add_argument("--class_names_file",
                    help="Path to the text file containing class names.",
                    type=str)
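
# The class names file is plain text with one class name per line, in
# class-index order (this is how the `segment` branch below reads it).
# Hypothetical example contents:
#   nucleus
#   cytoplasm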

# For training the model and prediction
parser.add_argument("--mode",
                    help=(
                        "Task mode: detect, segment, track, classify, or pose.\n"
                        "This script currently handles `detect`, `segment`, and `track`."
                    ), default='detect', type=str)
parser.add_argument('--train',
                    help="Do training",
                    action='store_true')
parser.add_argument("--confidence",
                    help="Confidence value (0-1) for each detected bounding box",
                    default=0.5, type=float)
parser.add_argument("--epochs",
                    help="Number of epochs for training. Default: 100",
                    default=100, type=int)
parser.add_argument("--init_lr",
                    help="Initial learning rate for training. Default: 0.01",
                    default=0.01, type=float)
parser.add_argument("--weight_decay",
                    help="Optimizer weight decay. Default: 0.0005",
                    default=0.0005, type=float)

parser.add_argument("--num_classes",
                    help="Number of classes to be predicted. Default: 2",
                    default=2, type=int)
parser.add_argument("--iou",
                    help="Intersection over union (IoU) threshold for NMS",
                    default=0.7, type=float)
parser.add_argument("--image_size",
                    help=("Input image size as a single integer (used for both w and h). \n"
                          "For training choose <= 1000. \n\n"
                          "Prediction will be done on the original image size"
                          ),
                    default=320, type=int)
parser.add_argument("--max_det",
                    help=("Maximum number of detections allowed per image. \n"
                          "Limits the total number of objects the model can detect in a single inference, \n"
                          "preventing excessive outputs in dense scenes.\n\n"
                          ),
                    default=300, type=int)

# For tracking
parser.add_argument("--tracker_file",
                    help="Path to the configuration file of the tracker used. \n",
                    default='bytetrack.yaml', type=str)

# For headless operation
parser.add_argument('--headless', action='store_true')
parser.add_argument('--nextflow', action='store_true')

# For data augmentation
parser.add_argument("--hsv_h",
                    help="(float) image HSV-Hue augmentation (fraction)",
                    default=0.015, type=float)
parser.add_argument("--hsv_s",
                    help="(float) image HSV-Saturation augmentation (fraction)",
                    default=0.7, type=float)
parser.add_argument("--hsv_v",
                    help="(float) image HSV-Value augmentation (fraction)",
                    default=0.4, type=float)
parser.add_argument("--degrees",
                    help="(float) image rotation (+/- deg)",
                    default=0.0, type=float)
parser.add_argument("--translate",
                    help="(float) image translation (+/- fraction)",
                    default=0.1, type=float)
parser.add_argument("--scale",
                    help="(float) image scale (+/- gain)",
                    default=0.5, type=float)
parser.add_argument("--shear",
                    help="(float) image shear (+/- deg)",
                    default=0.0, type=float)
parser.add_argument("--perspective",
                    help="(float) image perspective (+/- fraction), range 0-0.001",
                    default=0.0, type=float)
parser.add_argument("--flipud",
                    help="(float) image flip up-down (probability)",
                    default=0.0, type=float)
parser.add_argument("--fliplr",
                    help="(float) image flip left-right (probability)",
                    default=0.5, type=float)
parser.add_argument("--mosaic",
                    help="(float) image mosaic (probability)",
                    default=1.0, type=float)
parser.add_argument("--crop_fraction",
                    help="(float) crops image to a fraction of its size to "
                         "emphasize central features and adapt to object scales, "
                         "reducing background distractions",
                    default=1.0, type=float)
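
# A minimal sketch of the dataset YAML expected by --yaml_path, following the
# standard ultralytics data format (paths and class names are hypothetical):
#   path: /g/group/user/dataset
#   train: images/train
#   val: images/val
#   test: images/test
#   names:
#     0: nucleus
#     1: cytoplasm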


#
# Functions
#
# Train a new model on the dataset mentioned in the yaml file
def trainModel(model_path, model_name, yaml_filepath, **kwargs):
    # Fall back to the argparse defaults above when a keyword is not supplied
    image_size = kwargs.get('imgsz', 320)
    n_epochs = kwargs.get('epochs', 100)
    aug_hsv_h = kwargs.get('hsv_h', 0.015)
    aug_hsv_s = kwargs.get('hsv_s', 0.7)
    aug_hsv_v = kwargs.get('hsv_v', 0.4)
    aug_degrees = kwargs.get('degrees', 0.0)
    aug_translate = kwargs.get('translate', 0.1)
    aug_scale = kwargs.get('scale', 0.5)
    aug_shear = kwargs.get('shear', 0.0)
    aug_perspective = kwargs.get('perspective', 0.0)
    aug_fliplr = kwargs.get('fliplr', 0.5)
    aug_flipud = kwargs.get('flipud', 0.0)
    aug_mosaic = kwargs.get('mosaic', 1.0)
    aug_crop_fraction = kwargs.get('crop_fraction', 1.0)
    weight_decay = kwargs.get('weight_decay', 0.0005)
    init_lr = kwargs.get('init_lr', 0.01)

    # Remove the training save path if it already exists
    train_save_path = os.path.expanduser('~/runs/' + args.mode + '/train/')
    if os.path.isdir(train_save_path):
        shutil.rmtree(train_save_path)
    # Load a pretrained YOLO model (recommended for training)
    if args.model_format == 'pt':
        model = YOLO(os.path.join(model_path, model_name + "." + args.model_format))
    else:
        model = YOLO(model_name + "." + args.model_format)
    model.train(data=yaml_filepath, epochs=n_epochs, project=args.run_dir,
                imgsz=image_size, verbose=True, hsv_h=aug_hsv_h,
                hsv_s=aug_hsv_s, hsv_v=aug_hsv_v, degrees=aug_degrees,
                translate=aug_translate, shear=aug_shear, scale=aug_scale,
                perspective=aug_perspective, fliplr=aug_fliplr,
                flipud=aug_flipud, mosaic=aug_mosaic, crop_fraction=aug_crop_fraction,
                weight_decay=weight_decay, lr0=init_lr, seed=42)
    return model
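
# A minimal usage sketch for trainModel (hypothetical paths, mirroring the
# training example in the epilog above; keyword defaults match argparse):
#   model = trainModel('/g/cba/models', 'yolov8n', '/g/group/user/example.yaml',
#                      imgsz=320, epochs=150, degrees=180, scale=0.3)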


# Validate the trained model
def validateModel(model):
    # Remove the validation save path if it already exists
    val_save_path = os.path.expanduser('~/runs/' + args.mode + '/val/')
    if os.path.isdir(val_save_path):
        shutil.rmtree(val_save_path)
    # Validate the model; no args needed, dataset & settings are remembered
    metrics = model.val()
    print(colored(f"mAP50-95: {metrics.box.map}", 'green'))
    print(colored(f"mAP50: {metrics.box.map50}", 'green'))
    print(colored(f"mAP75: {metrics.box.map75}", 'green'))
    print(colored(f"Per-category mAP50-95: {metrics.box.maps}", 'green'))


# Do predictions on images/videos using the trained/loaded model
def predict(model, source_datapath, **kwargs):
    # Fall back to sensible defaults when a keyword is not supplied
    image_size = kwargs.get('imgsz', 320)
    confidence = kwargs.get('conf', 0.5)
    iou_value = kwargs.get('iou', 0.5)
    class_array = list(range(kwargs.get('num_classes', 2)))
    maximum_detections = kwargs.get('max_det', 300)
    run_save_dir = kwargs.get('run_dir')
    save_folder_name = kwargs.get('foldername')
    if run_save_dir is None:
        # Remove the default prediction save path if it already exists
        pred_save_path = os.path.expanduser('~/runs/' + args.mode + '/predict/')
        if os.path.isdir(pred_save_path):
            shutil.rmtree(pred_save_path)
        run_save_dir = pred_save_path
    # Infer on a local image or a directory containing images/videos
    prediction = model.predict(source=source_datapath, save=True, stream=True,
                               conf=confidence, imgsz=image_size,
                               save_conf=True, iou=iou_value, max_det=maximum_detections,
                               classes=class_array, save_txt=False,
                               project=run_save_dir, name=save_folder_name, verbose=True)
    return prediction
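
# A minimal usage sketch for predict (hypothetical paths, mirroring the
# prediction example in the epilog above):
#   preds = predict(model, '/g/group/user/data', imgsz=320, conf=0.5, iou=0.7,
#                   run_dir='/g/group/user/runs', foldername='batch',
#                   num_classes=1, max_det=300)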


# Save bounding boxes
def save_yolo_bounding_boxes_to_txt(predictions, save_dir):
    """
    Save YOLO bounding boxes to text files, one file per image.

    Parameters:
    - predictions: List of results from YOLO model inference.
    - save_dir: Directory where the text files will be saved.
    """
    for result in predictions:
        result = result.to("cpu").numpy()
        bounding_boxes = result.boxes.xyxy  # Bounding boxes in xyxy format
        confidence_scores = result.boxes.conf  # Confidence scores
        class_nums = result.boxes.cls  # Class numbers
        # Create the save directory if it doesn't exist
        save_path = pathlib.Path(save_dir).absolute()
        save_path.mkdir(parents=True, exist_ok=True)
        # Construct the filename for the text file
        image_filename = pathlib.Path(result.path).stem
        text_filename = save_path / f"{image_filename}.txt"
        # Write one line per bounding box into the text file
        with open(text_filename, 'w') as f:
            for i in range(bounding_boxes.shape[0]):
                x1, y1, x2, y2 = bounding_boxes[i]
                confidence = confidence_scores[i]
                class_num = int(class_nums[i])
                f.write(f'{class_num:01} {x1:06.2f} {y1:06.2f} {x2:06.2f} {y2:06.2f} {confidence:0.02} \n')
        print(colored(f"Bounding boxes saved in: {text_filename}", 'green'))
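
# Each line written above has the form (values illustrative):
#   <class> <x1> <y1> <x2> <y2> <confidence>
# e.g. "0 012.34 056.78 090.12 034.56 0.87", with pixel xyxy coordinates.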


if __name__ == '__main__':
    args = parser.parse_args()
    # Allow duplicate OpenMP runtimes (works around libiomp conflicts)
    os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
    # Train/load model
    if args.train:
        model = trainModel(args.model_path, args.model_name, args.yaml_path,
                           imgsz=args.image_size, epochs=args.epochs,
                           hsv_h=args.hsv_h, hsv_s=args.hsv_s, hsv_v=args.hsv_v,
                           degrees=args.degrees, translate=args.translate,
                           shear=args.shear, scale=args.scale,
                           perspective=args.perspective, fliplr=args.fliplr,
                           flipud=args.flipud, mosaic=args.mosaic,
                           crop_fraction=args.crop_fraction,
                           weight_decay=args.weight_decay, init_lr=args.init_lr)
        validateModel(model)
    else:
        t = time.time()
        train_save_path = os.path.expanduser('~/runs/' + args.mode + '/')
        if os.path.isfile(os.path.join(train_save_path,
                                       "train", "weights", "best.pt")) and (args.model_name == 'sam'):
            model = YOLO(os.path.join(train_save_path,
                                      "train", "weights", "best.pt"))
        else:
            model = YOLO(os.path.join(args.model_path,
                                      args.model_name + ".pt"))
        model.info(verbose=True)
        elapsed = time.time() - t
        print(colored(f"\nYOLO model loaded in : '{elapsed}' sec \n", 'white', 'on_yellow'))

    if args.save_dir:
        # Do predictions (optionally show image results with bounding boxes)
        t = time.time()
        datapath_for_prediction = args.test_path
        # Extract class names from the model
        class_names = model.names
        predictions = predict(model, datapath_for_prediction,
                              imgsz=args.image_size, conf=args.confidence,
                              iou=args.iou, run_dir=args.run_dir,
                              foldername=args.foldername, num_classes=args.num_classes,
                              max_det=args.max_det)
        elapsed = time.time() - t
        print(colored(f"\nYOLO prediction done in : '{elapsed}' sec \n", 'white', 'on_cyan'))

        if args.mode == "detect":
            # Save bounding boxes
            save_yolo_bounding_boxes_to_txt(predictions, args.save_dir)
415 elif (args.mode == "track"): | |
416 results = model.track(source=datapath_for_prediction, | |
417 tracker=args.tracker_file, | |
418 conf=args.confidence, | |
419 iou=args.iou, | |
420 persist=False, | |
421 show=True, | |
422 save=True, | |
423 project=args.run_dir, | |
424 name=args.foldername) | |
425 # Store the track history | |
426 track_history = defaultdict(lambda: []) | |
427 | |
428 for result in results: | |
429 # Get the boxes and track IDs | |
430 if result.boxes and result.boxes.is_track: | |
431 boxes = result.boxes.xywh.cpu() | |
432 track_ids = result.boxes.id.int().cpu().tolist() | |
433 # Visualize the result on the frame | |
434 frame = result.plot() | |
435 # Plot the tracks | |
436 for box, track_id in zip(boxes, track_ids): | |
437 x, y, w, h = box | |
438 track = track_history[track_id] | |
439 track.append((float(x), float(y))) # x, y center point | |
440 if len(track) > 30: # retain 30 tracks for 30 frames | |
441 track.pop(0) | |
442 | |
443 # Draw the tracking lines | |
444 points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2)) | |
445 cv2.polylines(frame, [points], isClosed=False, color=(230, 230, 230), thickness=2) | |
446 | |
447 # Display the annotated frame | |
448 cv2.imshow("YOLO11 Tracking", frame) | |
449 print(colored(f"Tracking results saved in : '{args.save_dir}' \n", 'green')) | |
450 elif (args.mode == "segment"): | |
451 # Read class names from the file | |
452 with open(args.class_names_file, 'r') as f: | |
453 class_names = [line.strip() for line in f.readlines()] | |
454 # Create a mapping from class names to indices | |
455 class_to_index = {class_name: i for i, class_name in enumerate(class_names)} | |
456 | |
457 # Save polygon coordinates | |
458 for result in predictions: | |
459 # Create binary mask | |
460 img = np.copy(result.orig_img) | |
461 filename = pathlib.Path(result.path).stem | |
462 b_mask = np.zeros(img.shape[:2], np.uint8) | |
463 mask_save_as = str(pathlib.Path(os.path.join(args.save_dir, filename + "_mask.tiff")).absolute()) | |
464 # Define output file path for text file | |
465 output_filename = os.path.splitext(filename)[0] + ".txt" | |
466 txt_save_as = str(pathlib.Path(os.path.join(args.save_dir, filename + ".txt")).absolute()) | |
467 | |
468 for c, ci in enumerate(result): | |
469 # Extract contour result | |
470 contour = ci.masks.xy.pop() | |
471 # Changing the type | |
472 contour = contour.astype(np.int32) | |
473 # Reshaping | |
474 contour = contour.reshape(-1, 1, 2) | |
475 # Draw contour onto mask | |
476 _ = cv2.drawContours(b_mask, [contour], -1, (255, 255, 255), cv2.FILLED) | |
477 | |
478 # Normalized polygon points | |
479 points = ci.masks.xyn.pop() | |
480 obj_class = int(ci.boxes.cls.to("cpu").numpy().item()) | |
481 confidence = result.boxes.conf.to("cpu").numpy()[c] | |
482 | |
483 with open(txt_save_as, 'a') as f: | |
484 segmentation_points = ['{} {}'.format(points[i][0], points[i][1]) for i in range(len(points))] | |
485 segmentation_points_string = ' '.join(segmentation_points) | |
486 line = '{} {} {}\n'.format(obj_class, segmentation_points_string, confidence) | |
487 f.write(line) | |
488 | |
489 imwrite(mask_save_as, b_mask, imagej=True) # save image | |
490 print(colored(f"Saved cropped image as : \n '{mask_save_as}' \n", 'magenta')) | |
491 print(colored(f"Polygon coordinates saved as : \n '{txt_save_as}' \n", 'cyan')) | |
492 | |
493 else: | |
494 raise Exception(("Currently only 'detect' and 'segment' modes are available")) |
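
# Each line of the per-image text file written in `segment` mode has the form
# (values illustrative):
#   <class> <x1> <y1> <x2> <y2> ... <confidence>
# where the polygon coordinates come from masks.xyn and are normalized to [0, 1].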