ultralytics 8.0.175 StreamLoader wait for missing frames (#4814)

This commit is contained in:
Glenn Jocher 2023-09-10 23:59:43 +02:00 committed by GitHub
parent dd0782bd8d
commit 3c88bebc95
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 197 additions and 49 deletions

View file

@ -135,7 +135,7 @@ def check_source(source):
return source, webcam, screenshot, from_img, in_memory, tensor
def load_inference_source(source=None, imgsz=640, vid_stride=1, stream_buffer=False):
def load_inference_source(source=None, imgsz=640, vid_stride=1, buffer=False):
"""
Loads an inference source for object detection and applies necessary transformations.
@ -143,7 +143,7 @@ def load_inference_source(source=None, imgsz=640, vid_stride=1, stream_buffer=Fa
source (str, Path, Tensor, PIL.Image, np.ndarray): The input source for inference.
imgsz (int, optional): The size of the image for inference. Default is 640.
vid_stride (int, optional): The frame interval for video sources. Default is 1.
stream_buffer (bool, optional): Determined whether stream frames will be buffered. Default is False.
buffer (bool, optional): Determines whether stream frames will be buffered. Default is False.
Returns:
dataset (Dataset): A dataset object for the specified input source.
@ -157,7 +157,7 @@ def load_inference_source(source=None, imgsz=640, vid_stride=1, stream_buffer=Fa
elif in_memory:
dataset = source
elif webcam:
dataset = LoadStreams(source, imgsz=imgsz, vid_stride=vid_stride, stream_buffer=stream_buffer)
dataset = LoadStreams(source, imgsz=imgsz, vid_stride=vid_stride, buffer=buffer)
elif screenshot:
dataset = LoadScreenshots(source, imgsz=imgsz)
elif from_img:

View file

@ -31,10 +31,10 @@ class SourceTypes:
class LoadStreams:
"""YOLOv8 streamloader, i.e. `yolo predict source='rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`."""
def __init__(self, sources='file.streams', imgsz=640, vid_stride=1, stream_buffer=False):
def __init__(self, sources='file.streams', imgsz=640, vid_stride=1, buffer=False):
"""Initialize instance variables and check for consistent input stream shapes."""
torch.backends.cudnn.benchmark = True # faster for fixed-size inference
self.stream_buffer = stream_buffer # buffer input streams
self.buffer = buffer # buffer input streams
self.running = True # running flag for Thread
self.mode = 'stream'
self.imgsz = imgsz
@ -42,7 +42,7 @@ class LoadStreams:
sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
n = len(sources)
self.sources = [ops.clean_str(x) for x in sources] # clean source names for later
self.imgs, self.fps, self.frames, self.threads, self.shape = [[]] * n, [0] * n, [0] * n, [None] * n, [None] * n
self.imgs, self.fps, self.frames, self.threads, self.shape = [[]] * n, [0] * n, [0] * n, [None] * n, [[]] * n
self.caps = [None] * n # video capture objects
for i, s in enumerate(sources): # index, source
# Start thread to read frames from video stream
@ -81,8 +81,7 @@ class LoadStreams:
"""Read stream `i` frames in daemon thread."""
n, f = 0, self.frames[i] # frame number, frame array
while self.running and cap.isOpened() and n < (f - 1):
# Only read a new frame if the buffer is empty
if not self.imgs[i] or not self.stream_buffer:
if len(self.imgs[i]) < 30: # keep a <=30-image buffer
n += 1
cap.grab() # .read() = .grab() followed by .retrieve()
if n % self.vid_stride == 0:
@ -91,7 +90,10 @@ class LoadStreams:
im = np.zeros(self.shape[i], dtype=np.uint8)
LOGGER.warning('WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.')
cap.open(stream) # re-open stream if signal was lost
self.imgs[i].append(im) # add image to buffer
if self.buffer:
self.imgs[i].append(im)
else:
self.imgs[i] = [im]
else:
time.sleep(0.01) # wait until the buffer has room again (it is capped at 30 frames)
@ -117,21 +119,24 @@ class LoadStreams:
"""Returns source paths, transformed and original images for processing."""
self.count += 1
# Wait until a frame is available in each buffer
while not all(self.imgs):
if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
self.close()
raise StopIteration
time.sleep(1 / min(self.fps))
images = []
for i, x in enumerate(self.imgs):
# Get and remove the next frame from imgs buffer
if self.stream_buffer:
images = [x.pop(0) for x in self.imgs]
else:
# Get the latest frame, and clear the rest from the imgs buffer
images = []
for x in self.imgs:
images.append(x.pop(-1) if x else None)
# Wait until a frame is available in each buffer
while not x:
if not self.threads[i].is_alive() or cv2.waitKey(1) == ord('q'): # q to quit
self.close()
raise StopIteration
LOGGER.warning(f'WARNING ⚠️ Waiting for stream {i}')
time.sleep(1 / min(self.fps))
# Get and remove the first frame from imgs buffer
if self.buffer:
images.append(x.pop(0))
# Get the last frame, and clear the rest from the imgs buffer
else:
images.append(x.pop(-1) if x else np.zeros(self.shape[i], dtype=np.uint8))
x.clear()
return self.sources, images, None, ''