ultralytics 8.3.50 Enhanced segment resample (#18171)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2024-12-16 19:21:49 +08:00 committed by GitHub
parent f87b447b2d
commit a3d807be13
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 22 additions and 5 deletions

View file

@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
__version__ = "8.3.49"
__version__ = "8.3.50"
import os

View file

@ -218,8 +218,10 @@ class YOLODataset(BaseDataset):
# NOTE: do NOT resample oriented boxes
segment_resamples = 100 if self.use_obb else 1000
if len(segments) > 0:
# list[np.array(1000, 2)] * num_samples
# (N, 1000, 2)
# make sure segments interpolate correctly if original length is greater than segment_resamples
max_len = max([len(s) for s in segments])
segment_resamples = (max_len + 1) if segment_resamples < max_len else segment_resamples
# list[np.array(segment_resamples, 2)] * num_samples
segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)
else:
segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)

View file

@ -7,7 +7,7 @@ from typing import List
import numpy as np
from .ops import ltwh2xywh, ltwh2xyxy, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh
from .ops import ltwh2xywh, ltwh2xyxy, resample_segments, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh
def _ntuple(n):
@ -406,7 +406,20 @@ class Instances:
normalized = instances_list[0].normalized
cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis)
cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis)
seg_len = [b.segments.shape[1] for b in instances_list]
if len(set(seg_len)) > 1: # resample segments if there's different length
max_len = max(seg_len)
cat_segments = np.concatenate(
[
resample_segments(list(b.segments), max_len)
if len(b.segments)
else np.zeros((0, max_len, 2), dtype=np.float32) # re-generating empty segments
for b in instances_list
],
axis=axis,
)
else:
cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis)
cat_keypoints = np.concatenate([b.keypoints for b in instances_list], axis=axis) if use_keypoint else None
return cls(cat_boxes, cat_segments, cat_keypoints, bbox_format, normalized)

View file

@ -624,6 +624,8 @@ def resample_segments(segments, n=1000):
segments (list): the resampled segments.
"""
for i, s in enumerate(segments):
if len(s) == n:
continue
s = np.concatenate((s, s[0:1, :]), axis=0)
x = np.linspace(0, len(s) - 1, n - len(s) if len(s) < n else n)
xp = np.arange(len(s))