ultralytics 8.0.164 new StreamLoader stream_buffer argument (#4596)
Co-authored-by: jgoo9410 <jjoohhnnggooddwwiinn@gmail.com> Co-authored-by: John Goodwin <johnf4g@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
bd96c0846b
commit
1121ef2409
8 changed files with 89 additions and 70 deletions
|
|
@ -1,6 +1,6 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = '8.0.163'
|
||||
__version__ = '8.0.164'
|
||||
|
||||
from ultralytics.models import RTDETR, SAM, YOLO
|
||||
from ultralytics.models.fastsam import FastSAM
|
||||
|
|
|
|||
|
|
@ -60,6 +60,7 @@ save_crop: False # (bool) save cropped images with results
|
|||
show_labels: True # (bool) show object labels in plots
|
||||
show_conf: True # (bool) show object confidence scores in plots
|
||||
vid_stride: 1 # (int) video frame-rate stride
|
||||
stream_buffer: False # (bool) buffer all streaming frames (True) or return the most recent frame (False)
|
||||
line_width: # (int, optional) line width of the bounding boxes, auto if missing
|
||||
visualize: False # (bool) visualize model features
|
||||
augment: False # (bool) apply image augmentation to prediction sources
|
||||
|
|
|
|||
|
|
@ -135,7 +135,7 @@ def check_source(source):
|
|||
return source, webcam, screenshot, from_img, in_memory, tensor
|
||||
|
||||
|
||||
def load_inference_source(source=None, imgsz=640, vid_stride=1):
|
||||
def load_inference_source(source=None, imgsz=640, vid_stride=1, stream_buffer=False):
|
||||
"""
|
||||
Loads an inference source for object detection and applies necessary transformations.
|
||||
|
||||
|
|
@ -143,6 +143,7 @@ def load_inference_source(source=None, imgsz=640, vid_stride=1):
|
|||
source (str, Path, Tensor, PIL.Image, np.ndarray): The input source for inference.
|
||||
imgsz (int, optional): The size of the image for inference. Default is 640.
|
||||
vid_stride (int, optional): The frame interval for video sources. Default is 1.
|
||||
stream_buffer (bool, optional): Determines whether stream frames will be buffered. Default is False.
|
||||
|
||||
Returns:
|
||||
dataset (Dataset): A dataset object for the specified input source.
|
||||
|
|
@ -156,7 +157,7 @@ def load_inference_source(source=None, imgsz=640, vid_stride=1):
|
|||
elif in_memory:
|
||||
dataset = source
|
||||
elif webcam:
|
||||
dataset = LoadStreams(source, imgsz=imgsz, vid_stride=vid_stride)
|
||||
dataset = LoadStreams(source, imgsz=imgsz, vid_stride=vid_stride, stream_buffer=stream_buffer)
|
||||
elif screenshot:
|
||||
dataset = LoadScreenshots(source, imgsz=imgsz)
|
||||
elif from_img:
|
||||
|
|
|
|||
|
|
@ -31,9 +31,10 @@ class SourceTypes:
|
|||
class LoadStreams:
|
||||
"""YOLOv8 streamloader, i.e. `yolo predict source='rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`."""
|
||||
|
||||
def __init__(self, sources='file.streams', imgsz=640, vid_stride=1):
|
||||
def __init__(self, sources='file.streams', imgsz=640, vid_stride=1, stream_buffer=False):
|
||||
"""Initialize instance variables and check for consistent input stream shapes."""
|
||||
torch.backends.cudnn.benchmark = True # faster for fixed-size inference
|
||||
self.stream_buffer = stream_buffer # buffer input streams
|
||||
self.running = True # running flag for Thread
|
||||
self.mode = 'stream'
|
||||
self.imgsz = imgsz
|
||||
|
|
@ -81,7 +82,7 @@ class LoadStreams:
|
|||
n, f = 0, self.frames[i] # frame number, frame array
|
||||
while self.running and cap.isOpened() and n < (f - 1):
|
||||
# Only read a new frame if the buffer is empty or stream buffering is disabled
|
||||
if not self.imgs[i]:
|
||||
if not self.imgs[i] or not self.stream_buffer:
|
||||
n += 1
|
||||
cap.grab() # .read() = .grab() followed by .retrieve()
|
||||
if n % self.vid_stride == 0:
|
||||
|
|
@ -124,7 +125,16 @@ class LoadStreams:
|
|||
time.sleep(1 / min(self.fps))
|
||||
|
||||
# Get and remove the next frame from imgs buffer
|
||||
return self.sources, [x.pop(0) for x in self.imgs], None, ''
|
||||
if self.stream_buffer:
|
||||
images = [x.pop(0) for x in self.imgs]
|
||||
else:
|
||||
# Get the latest frame, and clear the rest from the imgs buffer
|
||||
images = []
|
||||
for x in self.imgs:
|
||||
images.append(x.pop(-1) if x else None)
|
||||
x.clear()
|
||||
|
||||
return self.sources, images, None, ''
|
||||
|
||||
def __len__(self):
|
||||
"""Return the length of the sources object."""
|
||||
|
|
|
|||
|
|
@ -209,7 +209,10 @@ class BasePredictor:
|
|||
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
|
||||
self.transforms = getattr(self.model.model, 'transforms', classify_transforms(
|
||||
self.imgsz[0])) if self.args.task == 'classify' else None
|
||||
self.dataset = load_inference_source(source=source, imgsz=self.imgsz, vid_stride=self.args.vid_stride)
|
||||
self.dataset = load_inference_source(source=source,
|
||||
imgsz=self.imgsz,
|
||||
vid_stride=self.args.vid_stride,
|
||||
stream_buffer=self.args.stream_buffer)
|
||||
self.source_type = self.dataset.source_type
|
||||
if not getattr(self, 'stream', True) and (self.dataset.mode == 'stream' or # streams
|
||||
len(self.dataset) > 1000 or # images
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue