ultralytics 8.0.128 FastSAM autodownload and super() init (#3552)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-07-06 01:11:37 +02:00 committed by GitHub
parent 400f3f72a1
commit ad99246ff1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 90 additions and 72 deletions

View file

@ -1,16 +1,20 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import torch
def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
'''Adjust bounding boxes to stick to image border if they are within a certain threshold.
"""
Adjust bounding boxes to stick to image border if they are within a certain threshold.
Args:
boxes: (n, 4)
image_shape: (height, width)
threshold: pixel threshold
boxes: (n, 4)
image_shape: (height, width)
threshold: pixel threshold
Returns:
adjusted_boxes: adjusted bounding boxes
'''
adjusted_boxes: adjusted bounding boxes
"""
# Image dimensions
h, w = image_shape
@ -25,14 +29,16 @@ def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False):
'''Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
"""
Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes.
Args:
box1: (4, )
boxes: (n, 4)
box1: (4, )
boxes: (n, 4)
Returns:
high_iou_indices: Indices of boxes with IoU > thres
'''
high_iou_indices: Indices of boxes with IoU > thres
"""
boxes = adjust_bboxes_to_image_border(boxes, image_shape)
# obtain coordinates for intersections
x1 = torch.max(box1[0], boxes[:, 0])
@ -53,11 +59,7 @@ def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=Fals
# compute the IoU
iou = intersection / union # Should be shape (n, )
if raw_output:
if iou.numel() == 0:
return 0
return iou
return 0 if iou.numel() == 0 else iou
# get indices of boxes with IoU > thres
high_iou_indices = torch.nonzero(iou > iou_thres).flatten()
return high_iou_indices
# return indices of boxes with IoU > thres
return torch.nonzero(iou > iou_thres).flatten()