ultralytics 8.1.26 LoadImagesAndVideos batched inference (#8817)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
1f9667fff2
commit
7451ca1f54
11 changed files with 186 additions and 171 deletions
|
|
@ -11,7 +11,7 @@ from torch.utils.data import dataloader, distributed
|
|||
|
||||
from ultralytics.data.loaders import (
|
||||
LOADERS,
|
||||
LoadImages,
|
||||
LoadImagesAndVideos,
|
||||
LoadPilAndNumpy,
|
||||
LoadScreenshots,
|
||||
LoadStreams,
|
||||
|
|
@ -150,34 +150,35 @@ def check_source(source):
|
|||
return source, webcam, screenshot, from_img, in_memory, tensor
|
||||
|
||||
|
||||
def load_inference_source(source=None, vid_stride=1, buffer=False):
|
||||
def load_inference_source(source=None, batch=1, vid_stride=1, buffer=False):
|
||||
"""
|
||||
Loads an inference source for object detection and applies necessary transformations.
|
||||
|
||||
Args:
|
||||
source (str, Path, Tensor, PIL.Image, np.ndarray): The input source for inference.
|
||||
batch (int, optional): Batch size for dataloaders. Default is 1.
|
||||
vid_stride (int, optional): The frame interval for video sources. Default is 1.
|
||||
buffer (bool, optional): Determined whether stream frames will be buffered. Default is False.
|
||||
|
||||
Returns:
|
||||
dataset (Dataset): A dataset object for the specified input source.
|
||||
"""
|
||||
source, webcam, screenshot, from_img, in_memory, tensor = check_source(source)
|
||||
source_type = source.source_type if in_memory else SourceTypes(webcam, screenshot, from_img, tensor)
|
||||
source, stream, screenshot, from_img, in_memory, tensor = check_source(source)
|
||||
source_type = source.source_type if in_memory else SourceTypes(stream, screenshot, from_img, tensor)
|
||||
|
||||
# Dataloader
|
||||
if tensor:
|
||||
dataset = LoadTensor(source)
|
||||
elif in_memory:
|
||||
dataset = source
|
||||
elif webcam:
|
||||
elif stream:
|
||||
dataset = LoadStreams(source, vid_stride=vid_stride, buffer=buffer)
|
||||
elif screenshot:
|
||||
dataset = LoadScreenshots(source)
|
||||
elif from_img:
|
||||
dataset = LoadPilAndNumpy(source)
|
||||
else:
|
||||
dataset = LoadImages(source, vid_stride=vid_stride)
|
||||
dataset = LoadImagesAndVideos(source, batch=batch, vid_stride=vid_stride)
|
||||
|
||||
# Attach source types to the dataset
|
||||
setattr(dataset, "source_type", source_type)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue