Optimized SAHI video inference (#15183)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
d96ea5b493
commit
3e4a581c35
1 changed files with 78 additions and 87 deletions
|
|
@ -9,103 +9,94 @@ from sahi.predict import get_sliced_prediction
|
|||
from sahi.utils.yolov8 import download_yolov8s_model
|
||||
|
||||
from ultralytics.utils.files import increment_path
|
||||
from ultralytics.utils.plotting import Annotator, colors
|
||||
|
||||
|
||||
def run(weights="yolov8n.pt", source="test.mp4", view_img=False, save_img=False, exist_ok=False):
|
||||
"""
|
||||
Run object detection on a video using YOLOv8 and SAHI.
|
||||
class SahiInference:
|
||||
def __init__(self):
|
||||
self.detection_model = None
|
||||
|
||||
Args:
|
||||
weights (str): Model weights path.
|
||||
source (str): Video file path.
|
||||
view_img (bool): Show results.
|
||||
save_img (bool): Save results.
|
||||
exist_ok (bool): Overwrite existing files.
|
||||
"""
|
||||
|
||||
# Check source path
|
||||
if not Path(source).exists():
|
||||
raise FileNotFoundError(f"Source path '{source}' does not exist.")
|
||||
|
||||
yolov8_model_path = f"models/{weights}"
|
||||
download_yolov8s_model(yolov8_model_path)
|
||||
detection_model = AutoDetectionModel.from_pretrained(
|
||||
model_type="yolov8", model_path=yolov8_model_path, confidence_threshold=0.3, device="cpu"
|
||||
)
|
||||
|
||||
# Video setup
|
||||
videocapture = cv2.VideoCapture(source)
|
||||
frame_width, frame_height = int(videocapture.get(3)), int(videocapture.get(4))
|
||||
fps, fourcc = int(videocapture.get(5)), cv2.VideoWriter_fourcc(*"mp4v")
|
||||
|
||||
# Output setup
|
||||
save_dir = increment_path(Path("ultralytics_results_with_sahi") / "exp", exist_ok)
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
video_writer = cv2.VideoWriter(str(save_dir / f"{Path(source).stem}.mp4"), fourcc, fps, (frame_width, frame_height))
|
||||
|
||||
while videocapture.isOpened():
|
||||
success, frame = videocapture.read()
|
||||
if not success:
|
||||
break
|
||||
|
||||
results = get_sliced_prediction(
|
||||
frame, detection_model, slice_height=512, slice_width=512, overlap_height_ratio=0.2, overlap_width_ratio=0.2
|
||||
def load_model(self, weights):
|
||||
yolov8_model_path = f"models/{weights}"
|
||||
download_yolov8s_model(yolov8_model_path)
|
||||
self.detection_model = AutoDetectionModel.from_pretrained(
|
||||
model_type="yolov8", model_path=yolov8_model_path, confidence_threshold=0.3, device="cpu"
|
||||
)
|
||||
object_prediction_list = results.object_prediction_list
|
||||
|
||||
boxes_list = []
|
||||
clss_list = []
|
||||
for ind, _ in enumerate(object_prediction_list):
|
||||
boxes = (
|
||||
object_prediction_list[ind].bbox.minx,
|
||||
object_prediction_list[ind].bbox.miny,
|
||||
object_prediction_list[ind].bbox.maxx,
|
||||
object_prediction_list[ind].bbox.maxy,
|
||||
def inference(
|
||||
self, weights="yolov8n.pt", source="test.mp4", view_img=False, save_img=False, exist_ok=False, track=False
|
||||
):
|
||||
"""
|
||||
Run object detection on a video using YOLOv8 and SAHI.
|
||||
|
||||
Args:
|
||||
weights (str): Model weights path.
|
||||
source (str): Video file path.
|
||||
view_img (bool): Show results.
|
||||
save_img (bool): Save results.
|
||||
exist_ok (bool): Overwrite existing files.
|
||||
track (bool): Enable object tracking with SAHI
|
||||
"""
|
||||
# Video setup
|
||||
cap = cv2.VideoCapture(source)
|
||||
assert cap.isOpened(), "Error reading video file"
|
||||
frame_width, frame_height = int(cap.get(3)), int(cap.get(4))
|
||||
|
||||
# Output setup
|
||||
save_dir = increment_path(Path("ultralytics_results_with_sahi") / "exp", exist_ok)
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
video_writer = cv2.VideoWriter(
|
||||
str(save_dir / f"{Path(source).stem}.mp4"),
|
||||
cv2.VideoWriter_fourcc(*"mp4v"),
|
||||
int(cap.get(5)),
|
||||
(frame_width, frame_height),
|
||||
)
|
||||
|
||||
# Load model
|
||||
self.load_model(weights)
|
||||
while cap.isOpened():
|
||||
success, frame = cap.read()
|
||||
if not success:
|
||||
break
|
||||
annotator = Annotator(frame) # Initialize annotator for plotting detection and tracking results
|
||||
results = get_sliced_prediction(
|
||||
frame,
|
||||
self.detection_model,
|
||||
slice_height=512,
|
||||
slice_width=512,
|
||||
overlap_height_ratio=0.2,
|
||||
overlap_width_ratio=0.2,
|
||||
)
|
||||
clss = object_prediction_list[ind].category.name
|
||||
boxes_list.append(boxes)
|
||||
clss_list.append(clss)
|
||||
detection_data = [
|
||||
(det.category.name, det.category.id, (det.bbox.minx, det.bbox.miny, det.bbox.maxx, det.bbox.maxy))
|
||||
for det in results.object_prediction_list
|
||||
]
|
||||
|
||||
for box, cls in zip(boxes_list, clss_list):
|
||||
x1, y1, x2, y2 = box
|
||||
cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (56, 56, 255), 2)
|
||||
label = str(cls)
|
||||
t_size = cv2.getTextSize(label, 0, fontScale=0.6, thickness=1)[0]
|
||||
cv2.rectangle(
|
||||
frame, (int(x1), int(y1) - t_size[1] - 3), (int(x1) + t_size[0], int(y1) + 3), (56, 56, 255), -1
|
||||
)
|
||||
cv2.putText(
|
||||
frame, label, (int(x1), int(y1) - 2), 0, 0.6, [255, 255, 255], thickness=1, lineType=cv2.LINE_AA
|
||||
)
|
||||
for det in detection_data:
|
||||
annotator.box_label(det[2], label=str(det[0]), color=colors(int(det[1]), True))
|
||||
|
||||
if view_img:
|
||||
cv2.imshow(Path(source).stem, frame)
|
||||
if save_img:
|
||||
video_writer.write(frame)
|
||||
if view_img:
|
||||
cv2.imshow(Path(source).stem, frame)
|
||||
if save_img:
|
||||
video_writer.write(frame)
|
||||
|
||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||
break
|
||||
video_writer.release()
|
||||
videocapture.release()
|
||||
cv2.destroyAllWindows()
|
||||
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||
break
|
||||
video_writer.release()
|
||||
cap.release()
|
||||
cv2.destroyAllWindows()
|
||||
|
||||
|
||||
def parse_opt():
|
||||
"""Parse command line arguments."""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--weights", type=str, default="yolov8n.pt", help="initial weights path")
|
||||
parser.add_argument("--source", type=str, required=True, help="video file path")
|
||||
parser.add_argument("--view-img", action="store_true", help="show results")
|
||||
parser.add_argument("--save-img", action="store_true", help="save results")
|
||||
parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main(opt):
|
||||
"""Main function."""
|
||||
run(**vars(opt))
|
||||
def parse_opt(self):
|
||||
"""Parse command line arguments."""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--weights", type=str, default="yolov8n.pt", help="initial weights path")
|
||||
parser.add_argument("--source", type=str, required=True, help="video file path")
|
||||
parser.add_argument("--view-img", action="store_true", help="show results")
|
||||
parser.add_argument("--save-img", action="store_true", help="save results")
|
||||
parser.add_argument("--exist-ok", action="store_true", help="existing project/name ok, do not increment")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
opt = parse_opt()
|
||||
main(opt)
|
||||
inference = SahiInference()
|
||||
inference.inference(**vars(inference.parse_opt()))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue