ultralytics 8.0.72 faster Windows trainings and corrupt cache fix (#1912)

Co-authored-by: andreaswimmer <53872150+andreaswimmer@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-04-10 00:21:03 +02:00 committed by GitHub
parent 48f1d269fb
commit 95f96dc5bc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 312 additions and 413 deletions

View file

@ -162,7 +162,18 @@ def check_source(source):
def load_inference_source(source=None, transforms=None, imgsz=640, vid_stride=1, stride=32, auto=True):
"""
TODO: docs
Loads an inference source for object detection and applies necessary transformations.
Args:
source (str, Path, Tensor, PIL.Image, np.ndarray): The input source for inference.
transforms (callable, optional): Custom transformations to be applied to the input source.
imgsz (int, optional): The size of the image for inference. Default is 640.
vid_stride (int, optional): The frame interval for video sources. Default is 1.
stride (int, optional): The model stride. Default is 32.
auto (bool, optional): Automatically apply pre-processing. Default is True.
Returns:
dataset: A dataset object for the specified input source.
"""
source, webcam, screenshot, from_img, in_memory, tensor = check_source(source)
source_type = source.source_type if in_memory else SourceTypes(webcam, screenshot, from_img, tensor)
@ -179,7 +190,6 @@ def load_inference_source(source=None, transforms=None, imgsz=640, vid_stride=1,
auto=auto,
transforms=transforms,
vid_stride=vid_stride)
elif screenshot:
dataset = LoadScreenshots(source, imgsz=imgsz, stride=stride, auto=auto, transforms=transforms)
elif from_img:
@ -192,6 +202,7 @@ def load_inference_source(source=None, transforms=None, imgsz=640, vid_stride=1,
transforms=transforms,
vid_stride=vid_stride)
setattr(dataset, 'source_type', source_type) # attach source types
# Attach source types to the dataset
setattr(dataset, 'source_type', source_type)
return dataset