ultralytics 8.2.63 refactor FastSAMPredictor (#14582)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Laughing-q <1185102784@qq.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Nguyễn Anh Bình 2024-07-23 00:59:37 +07:00 committed by GitHub
parent db82d1c6ae
commit 3637516412
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 22 additions and 118 deletions

View file

@ -1,84 +1,31 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
from ultralytics.engine.results import Results
from ultralytics.models.fastsam.utils import bbox_iou
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, ops
from ultralytics.models.yolo.segment import SegmentationPredictor
from ultralytics.utils.metrics import box_iou
from .utils import adjust_bboxes_to_image_border
class FastSAMPredictor(DetectionPredictor):
class FastSAMPredictor(SegmentationPredictor):
"""
FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks in Ultralytics
YOLO framework.
This class extends the DetectionPredictor, customizing the prediction pipeline specifically for fast SAM.
It adjusts post-processing steps to incorporate mask prediction and non-max suppression while optimizing
for single-class segmentation.
Attributes:
cfg (dict): Configuration parameters for prediction.
overrides (dict, optional): Optional parameter overrides for custom behavior.
_callbacks (dict, optional): Optional list of callback functions to be invoked during prediction.
This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It
adjusts post-processing steps to incorporate mask prediction and non-max suppression while optimizing for single-
class segmentation.
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initializes the FastSAMPredictor class, inheriting from DetectionPredictor and setting the task to 'segment'.
Args:
cfg (dict): Configuration parameters for prediction.
overrides (dict, optional): Optional parameter overrides for custom behavior.
_callbacks (dict, optional): Optional list of callback functions to be invoked during prediction.
"""
super().__init__(cfg, overrides, _callbacks)
self.args.task = "segment"
def postprocess(self, preds, img, orig_imgs):
"""
Perform post-processing steps on predictions, including non-max suppression and scaling boxes to original image
size, and returns the final results.
Args:
preds (list): The raw output predictions from the model.
img (torch.Tensor): The processed image tensor.
orig_imgs (list | torch.Tensor): The original image or list of images.
Returns:
(list): A list of Results objects, each containing processed boxes, masks, and other metadata.
"""
p = ops.non_max_suppression(
preds[0],
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
nc=1, # set to 1 class since SAM has no class predictions
classes=self.args.classes,
)
full_box = torch.zeros(p[0].shape[1], device=p[0].device)
full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0
full_box = full_box.view(1, -1)
critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:])
if critical_iou_index.numel() != 0:
full_box[0][4] = p[0][critical_iou_index][:, 4]
full_box[0][6:] = p[0][critical_iou_index][:, 6:]
p[0][critical_iou_index] = full_box
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
results = []
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
for i, (pred, orig_img, img_path) in enumerate(zip(p, orig_imgs, self.batch[0])):
if not len(pred): # save empty boxes
masks = None
elif self.args.retina_masks:
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
else:
masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks))
"""Applies box postprocess for FastSAM predictions."""
results = super().postprocess(preds, img, orig_imgs)
for result in results:
full_box = torch.tensor(
[0, 0, result.orig_shape[1], result.orig_shape[0]], device=preds[0].device, dtype=torch.float32
)
boxes = adjust_bboxes_to_image_border(result.boxes.xyxy, result.orig_shape)
idx = torch.nonzero(box_iou(full_box[None], boxes) > 0.9).flatten()
if idx.numel() != 0:
result.boxes.xyxy[idx] = full_box
return results