Support prediction of list of sources, in-memory dataset and other improvements (#685)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
a5410ed79e
commit
0609561549
9 changed files with 174 additions and 73 deletions
|
|
@ -2,11 +2,18 @@
|
|||
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.utils.data import DataLoader, dataloader, distributed
|
||||
|
||||
from ultralytics.yolo.data.dataloaders.stream_loaders import (LOADERS, LoadImages, LoadPilAndNumpy, LoadScreenshots,
|
||||
LoadStreams, SourceTypes, autocast_list)
|
||||
from ultralytics.yolo.data.utils import IMG_FORMATS, VID_FORMATS
|
||||
from ultralytics.yolo.utils.checks import check_file
|
||||
|
||||
from ..utils import LOGGER, colorstr
|
||||
from ..utils.torch_utils import torch_distributed_zero_first
|
||||
from .dataset import ClassificationDataset, YOLODataset
|
||||
|
|
@ -123,3 +130,63 @@ def build_classification_dataloader(path,
|
|||
pin_memory=PIN_MEMORY,
|
||||
worker_init_fn=seed_worker,
|
||||
generator=generator) # or DataLoader(persistent_workers=True)
|
||||
|
||||
|
||||
def check_source(source):
|
||||
webcam, screenshot, from_img, in_memory = False, False, False, False
|
||||
if isinstance(source, (str, int, Path)): # int for local usb carame
|
||||
source = str(source)
|
||||
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
|
||||
is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://'))
|
||||
webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
|
||||
screenshot = source.lower().startswith('screen')
|
||||
if is_url and is_file:
|
||||
source = check_file(source) # download
|
||||
elif isinstance(source, tuple(LOADERS)):
|
||||
in_memory = True
|
||||
elif isinstance(source, (list, tuple)):
|
||||
source = autocast_list(source) # convert all list elements to PIL or np arrays
|
||||
from_img = True
|
||||
elif isinstance(source, ((Image.Image, np.ndarray))):
|
||||
from_img = True
|
||||
else:
|
||||
raise Exception(
|
||||
"Unsupported type encountered! See docs for supported types https://docs.ultralytics.com/predict")
|
||||
|
||||
return source, webcam, screenshot, from_img, in_memory
|
||||
|
||||
|
||||
def load_inference_source(source=None, transforms=None, imgsz=640, vid_stride=1, stride=32, auto=True):
|
||||
"""
|
||||
TODO: docs
|
||||
"""
|
||||
# source
|
||||
source, webcam, screenshot, from_img, in_memory = check_source(source)
|
||||
source_type = SourceTypes(webcam, screenshot, from_img) if not in_memory else source.source_type
|
||||
|
||||
# Dataloader
|
||||
if in_memory:
|
||||
dataset = source
|
||||
elif webcam:
|
||||
dataset = LoadStreams(source,
|
||||
imgsz=imgsz,
|
||||
stride=stride,
|
||||
auto=auto,
|
||||
transforms=transforms,
|
||||
vid_stride=vid_stride)
|
||||
|
||||
elif screenshot:
|
||||
dataset = LoadScreenshots(source, imgsz=imgsz, stride=stride, auto=auto, transforms=transforms)
|
||||
elif from_img:
|
||||
dataset = LoadPilAndNumpy(source, imgsz=imgsz, stride=stride, auto=auto, transforms=transforms)
|
||||
else:
|
||||
dataset = LoadImages(source,
|
||||
imgsz=imgsz,
|
||||
stride=stride,
|
||||
auto=auto,
|
||||
transforms=transforms,
|
||||
vid_stride=vid_stride)
|
||||
|
||||
setattr(dataset, 'source_type', source_type) # attach source types
|
||||
|
||||
return dataset
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue