ultralytics 8.1.26 LoadImagesAndVideos batched inference (#8817)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-03-10 19:15:44 +01:00 committed by GitHub
parent 1f9667fff2
commit 7451ca1f54
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 186 additions and 171 deletions

View file

@ -11,7 +11,7 @@ from torch.utils.data import dataloader, distributed
from ultralytics.data.loaders import (
LOADERS,
LoadImages,
LoadImagesAndVideos,
LoadPilAndNumpy,
LoadScreenshots,
LoadStreams,
@ -150,34 +150,35 @@ def check_source(source):
return source, webcam, screenshot, from_img, in_memory, tensor
def load_inference_source(source=None, vid_stride=1, buffer=False):
def load_inference_source(source=None, batch=1, vid_stride=1, buffer=False):
"""
Loads an inference source for object detection and applies necessary transformations.
Args:
source (str, Path, Tensor, PIL.Image, np.ndarray): The input source for inference.
batch (int, optional): Batch size for dataloaders. Default is 1.
vid_stride (int, optional): The frame interval for video sources. Default is 1.
buffer (bool, optional): Determined whether stream frames will be buffered. Default is False.
Returns:
dataset (Dataset): A dataset object for the specified input source.
"""
source, webcam, screenshot, from_img, in_memory, tensor = check_source(source)
source_type = source.source_type if in_memory else SourceTypes(webcam, screenshot, from_img, tensor)
source, stream, screenshot, from_img, in_memory, tensor = check_source(source)
source_type = source.source_type if in_memory else SourceTypes(stream, screenshot, from_img, tensor)
# Dataloader
if tensor:
dataset = LoadTensor(source)
elif in_memory:
dataset = source
elif webcam:
elif stream:
dataset = LoadStreams(source, vid_stride=vid_stride, buffer=buffer)
elif screenshot:
dataset = LoadScreenshots(source)
elif from_img:
dataset = LoadPilAndNumpy(source)
else:
dataset = LoadImages(source, vid_stride=vid_stride)
dataset = LoadImagesAndVideos(source, batch=batch, vid_stride=vid_stride)
# Attach source types to the dataset
setattr(dataset, "source_type", source_type)