ultralytics 8.1.40 search in Python sets {} for speed (#9450)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
30484d5925
commit
ea527507fe
41 changed files with 97 additions and 93 deletions
|
|
@ -351,7 +351,7 @@ def test_labels_and_crops():
|
||||||
crop_dirs = [p for p in (save_path / "crops").iterdir()]
|
crop_dirs = [p for p in (save_path / "crops").iterdir()]
|
||||||
crop_files = [f for p in crop_dirs for f in p.glob("*")]
|
crop_files = [f for p in crop_dirs for f in p.glob("*")]
|
||||||
# Crop directories match detections
|
# Crop directories match detections
|
||||||
assert all([r.names.get(c) in [d.name for d in crop_dirs] for c in cls_idxs])
|
assert all([r.names.get(c) in {d.name for d in crop_dirs} for c in cls_idxs])
|
||||||
# Same number of crops as detections
|
# Same number of crops as detections
|
||||||
assert len([f for f in crop_files if im_name in f.name]) == len(r.boxes.data)
|
assert len([f for f in crop_files if im_name in f.name]) == len(r.boxes.data)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||||
|
|
||||||
__version__ = "8.1.39"
|
__version__ = "8.1.40"
|
||||||
|
|
||||||
from ultralytics.data.explorer.explorer import Explorer
|
from ultralytics.data.explorer.explorer import Explorer
|
||||||
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
|
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
|
||||||
|
|
|
||||||
|
|
@ -272,7 +272,7 @@ def get_save_dir(args, name=None):
|
||||||
|
|
||||||
project = args.project or (ROOT.parent / "tests/tmp/runs" if TESTS_RUNNING else RUNS_DIR) / args.task
|
project = args.project or (ROOT.parent / "tests/tmp/runs" if TESTS_RUNNING else RUNS_DIR) / args.task
|
||||||
name = name or args.name or f"{args.mode}"
|
name = name or args.name or f"{args.mode}"
|
||||||
save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in (-1, 0) else True)
|
save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in {-1, 0} else True)
|
||||||
|
|
||||||
return Path(save_dir)
|
return Path(save_dir)
|
||||||
|
|
||||||
|
|
@ -566,10 +566,10 @@ def entrypoint(debug=""):
|
||||||
task = model.task
|
task = model.task
|
||||||
|
|
||||||
# Mode
|
# Mode
|
||||||
if mode in ("predict", "track") and "source" not in overrides:
|
if mode in {"predict", "track"} and "source" not in overrides:
|
||||||
overrides["source"] = DEFAULT_CFG.source or ASSETS
|
overrides["source"] = DEFAULT_CFG.source or ASSETS
|
||||||
LOGGER.warning(f"WARNING ⚠️ 'source' argument is missing. Using default 'source={overrides['source']}'.")
|
LOGGER.warning(f"WARNING ⚠️ 'source' argument is missing. Using default 'source={overrides['source']}'.")
|
||||||
elif mode in ("train", "val"):
|
elif mode in {"train", "val"}:
|
||||||
if "data" not in overrides and "resume" not in overrides:
|
if "data" not in overrides and "resume" not in overrides:
|
||||||
overrides["data"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
|
overrides["data"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
|
||||||
LOGGER.warning(f"WARNING ⚠️ 'data' argument is missing. Using default 'data={overrides['data']}'.")
|
LOGGER.warning(f"WARNING ⚠️ 'data' argument is missing. Using default 'data={overrides['data']}'.")
|
||||||
|
|
|
||||||
|
|
@ -191,7 +191,7 @@ class Mosaic(BaseMixTransform):
|
||||||
def __init__(self, dataset, imgsz=640, p=1.0, n=4):
|
def __init__(self, dataset, imgsz=640, p=1.0, n=4):
|
||||||
"""Initializes the object with a dataset, image size, probability, and border."""
|
"""Initializes the object with a dataset, image size, probability, and border."""
|
||||||
assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}."
|
assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}."
|
||||||
assert n in (4, 9), "grid must be equal to 4 or 9."
|
assert n in {4, 9}, "grid must be equal to 4 or 9."
|
||||||
super().__init__(dataset=dataset, p=p)
|
super().__init__(dataset=dataset, p=p)
|
||||||
self.dataset = dataset
|
self.dataset = dataset
|
||||||
self.imgsz = imgsz
|
self.imgsz = imgsz
|
||||||
|
|
@ -685,7 +685,7 @@ class RandomFlip:
|
||||||
Default is 'horizontal'.
|
Default is 'horizontal'.
|
||||||
flip_idx (array-like, optional): Index mapping for flipping keypoints, if any.
|
flip_idx (array-like, optional): Index mapping for flipping keypoints, if any.
|
||||||
"""
|
"""
|
||||||
assert direction in ["horizontal", "vertical"], f"Support direction `horizontal` or `vertical`, got {direction}"
|
assert direction in {"horizontal", "vertical"}, f"Support direction `horizontal` or `vertical`, got {direction}"
|
||||||
assert 0 <= p <= 1.0
|
assert 0 <= p <= 1.0
|
||||||
|
|
||||||
self.p = p
|
self.p = p
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ import psutil
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM
|
from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM
|
||||||
from .utils import HELP_URL, IMG_FORMATS
|
from .utils import HELP_URL, FORMATS_HELP_MSG, IMG_FORMATS
|
||||||
|
|
||||||
|
|
||||||
class BaseDataset(Dataset):
|
class BaseDataset(Dataset):
|
||||||
|
|
@ -118,7 +118,7 @@ class BaseDataset(Dataset):
|
||||||
raise FileNotFoundError(f"{self.prefix}{p} does not exist")
|
raise FileNotFoundError(f"{self.prefix}{p} does not exist")
|
||||||
im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
|
im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
|
||||||
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
|
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
|
||||||
assert im_files, f"{self.prefix}No images found in {img_path}"
|
assert im_files, f"{self.prefix}No images found in {img_path}. {FORMATS_HELP_MSG}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
|
raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
|
||||||
if self.fraction < 1:
|
if self.fraction < 1:
|
||||||
|
|
|
||||||
|
|
@ -481,7 +481,7 @@ def merge_multi_segment(segments):
|
||||||
segments[i] = np.roll(segments[i], -idx[0], axis=0)
|
segments[i] = np.roll(segments[i], -idx[0], axis=0)
|
||||||
segments[i] = np.concatenate([segments[i], segments[i][:1]])
|
segments[i] = np.concatenate([segments[i], segments[i][:1]])
|
||||||
# Deal with the first segment and the last one
|
# Deal with the first segment and the last one
|
||||||
if i in [0, len(idx_list) - 1]:
|
if i in {0, len(idx_list) - 1}:
|
||||||
s.append(segments[i])
|
s.append(segments[i])
|
||||||
else:
|
else:
|
||||||
idx = [0, idx[1] - idx[0]]
|
idx = [0, idx[1] - idx[0]]
|
||||||
|
|
@ -489,7 +489,7 @@ def merge_multi_segment(segments):
|
||||||
|
|
||||||
else:
|
else:
|
||||||
for i in range(len(idx_list) - 1, -1, -1):
|
for i in range(len(idx_list) - 1, -1, -1):
|
||||||
if i not in [0, len(idx_list) - 1]:
|
if i not in {0, len(idx_list) - 1}:
|
||||||
idx = idx_list[i]
|
idx = idx_list[i]
|
||||||
nidx = abs(idx[1] - idx[0])
|
nidx = abs(idx[1] - idx[0])
|
||||||
s.append(segments[i][nidx:])
|
s.append(segments[i][nidx:])
|
||||||
|
|
|
||||||
|
|
@ -77,7 +77,7 @@ class YOLODataset(BaseDataset):
|
||||||
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
|
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
|
||||||
total = len(self.im_files)
|
total = len(self.im_files)
|
||||||
nkpt, ndim = self.data.get("kpt_shape", (0, 0))
|
nkpt, ndim = self.data.get("kpt_shape", (0, 0))
|
||||||
if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
|
if self.use_keypoints and (nkpt <= 0 or ndim not in {2, 3}):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
|
"'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
|
||||||
"keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
|
"keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
|
||||||
|
|
@ -142,7 +142,7 @@ class YOLODataset(BaseDataset):
|
||||||
|
|
||||||
# Display cache
|
# Display cache
|
||||||
nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
|
nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
|
||||||
if exists and LOCAL_RANK in (-1, 0):
|
if exists and LOCAL_RANK in {-1, 0}:
|
||||||
d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
||||||
TQDM(None, desc=self.prefix + d, total=n, initial=n) # display results
|
TQDM(None, desc=self.prefix + d, total=n, initial=n) # display results
|
||||||
if cache["msgs"]:
|
if cache["msgs"]:
|
||||||
|
|
@ -235,7 +235,7 @@ class YOLODataset(BaseDataset):
|
||||||
value = values[i]
|
value = values[i]
|
||||||
if k == "img":
|
if k == "img":
|
||||||
value = torch.stack(value, 0)
|
value = torch.stack(value, 0)
|
||||||
if k in ["masks", "keypoints", "bboxes", "cls", "segments", "obb"]:
|
if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:
|
||||||
value = torch.cat(value, 0)
|
value = torch.cat(value, 0)
|
||||||
new_batch[k] = value
|
new_batch[k] = value
|
||||||
new_batch["batch_idx"] = list(new_batch["batch_idx"])
|
new_batch["batch_idx"] = list(new_batch["batch_idx"])
|
||||||
|
|
@ -334,7 +334,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
||||||
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
|
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
|
||||||
assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash
|
assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash
|
||||||
nf, nc, n, samples = cache.pop("results") # found, missing, empty, corrupt, total
|
nf, nc, n, samples = cache.pop("results") # found, missing, empty, corrupt, total
|
||||||
if LOCAL_RANK in (-1, 0):
|
if LOCAL_RANK in {-1, 0}:
|
||||||
d = f"{desc} {nf} images, {nc} corrupt"
|
d = f"{desc} {nf} images, {nc} corrupt"
|
||||||
TQDM(None, desc=d, total=n, initial=n)
|
TQDM(None, desc=d, total=n, initial=n)
|
||||||
if cache["msgs"]:
|
if cache["msgs"]:
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ import requests
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
|
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS, FORMATS_HELP_MSG
|
||||||
from ultralytics.utils import LOGGER, is_colab, is_kaggle, ops
|
from ultralytics.utils import LOGGER, is_colab, is_kaggle, ops
|
||||||
from ultralytics.utils.checks import check_requirements
|
from ultralytics.utils.checks import check_requirements
|
||||||
|
|
||||||
|
|
@ -83,7 +83,7 @@ class LoadStreams:
|
||||||
for i, s in enumerate(sources): # index, source
|
for i, s in enumerate(sources): # index, source
|
||||||
# Start thread to read frames from video stream
|
# Start thread to read frames from video stream
|
||||||
st = f"{i + 1}/{n}: {s}... "
|
st = f"{i + 1}/{n}: {s}... "
|
||||||
if urlparse(s).hostname in ("www.youtube.com", "youtube.com", "youtu.be"): # if source is YouTube video
|
if urlparse(s).hostname in {"www.youtube.com", "youtube.com", "youtu.be"}: # if source is YouTube video
|
||||||
# YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/LNwODJXcvt4'
|
# YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/LNwODJXcvt4'
|
||||||
s = get_best_youtube_url(s)
|
s = get_best_youtube_url(s)
|
||||||
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
|
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
|
||||||
|
|
@ -291,8 +291,14 @@ class LoadImagesAndVideos:
|
||||||
else:
|
else:
|
||||||
raise FileNotFoundError(f"{p} does not exist")
|
raise FileNotFoundError(f"{p} does not exist")
|
||||||
|
|
||||||
images = [x for x in files if x.split(".")[-1].lower() in IMG_FORMATS]
|
# Define files as images or videos
|
||||||
videos = [x for x in files if x.split(".")[-1].lower() in VID_FORMATS]
|
images, videos = [], []
|
||||||
|
for f in files:
|
||||||
|
suffix = f.split(".")[-1].lower() # Get file extension without the dot and lowercase
|
||||||
|
if suffix in IMG_FORMATS:
|
||||||
|
images.append(f)
|
||||||
|
elif suffix in VID_FORMATS:
|
||||||
|
videos.append(f)
|
||||||
ni, nv = len(images), len(videos)
|
ni, nv = len(images), len(videos)
|
||||||
|
|
||||||
self.files = images + videos
|
self.files = images + videos
|
||||||
|
|
@ -307,10 +313,7 @@ class LoadImagesAndVideos:
|
||||||
else:
|
else:
|
||||||
self.cap = None
|
self.cap = None
|
||||||
if self.nf == 0:
|
if self.nf == 0:
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(f"No images or videos found in {p}. {FORMATS_HELP_MSG}")
|
||||||
f"No images or videos found in {p}. "
|
|
||||||
f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
"""Returns an iterator object for VideoStream or ImageFolder."""
|
"""Returns an iterator object for VideoStream or ImageFolder."""
|
||||||
|
|
|
||||||
|
|
@ -71,7 +71,7 @@ def load_yolo_dota(data_root, split="train"):
|
||||||
- train
|
- train
|
||||||
- val
|
- val
|
||||||
"""
|
"""
|
||||||
assert split in ["train", "val"]
|
assert split in {"train", "val"}, f"Split must be 'train' or 'val', not {split}."
|
||||||
im_dir = Path(data_root) / "images" / split
|
im_dir = Path(data_root) / "images" / split
|
||||||
assert im_dir.exists(), f"Can't find {im_dir}, please check your data root."
|
assert im_dir.exists(), f"Can't find {im_dir}, please check your data root."
|
||||||
im_files = glob(str(Path(data_root) / "images" / split / "*"))
|
im_files = glob(str(Path(data_root) / "images" / split / "*"))
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,7 @@ HELP_URL = "See https://docs.ultralytics.com/datasets/detect for dataset formatt
|
||||||
IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"} # image suffixes
|
IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"} # image suffixes
|
||||||
VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"} # video suffixes
|
VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"} # video suffixes
|
||||||
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders
|
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders
|
||||||
|
FORMATS_HELP_MSG = f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
|
||||||
|
|
||||||
|
|
||||||
def img2label_paths(img_paths):
|
def img2label_paths(img_paths):
|
||||||
|
|
@ -63,7 +64,7 @@ def exif_size(img: Image.Image):
|
||||||
exif = img.getexif()
|
exif = img.getexif()
|
||||||
if exif:
|
if exif:
|
||||||
rotation = exif.get(274, None) # the EXIF key for the orientation tag is 274
|
rotation = exif.get(274, None) # the EXIF key for the orientation tag is 274
|
||||||
if rotation in [6, 8]: # rotation 270 or 90
|
if rotation in {6, 8}: # rotation 270 or 90
|
||||||
s = s[1], s[0]
|
s = s[1], s[0]
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
@ -79,8 +80,8 @@ def verify_image(args):
|
||||||
shape = exif_size(im) # image size
|
shape = exif_size(im) # image size
|
||||||
shape = (shape[1], shape[0]) # hw
|
shape = (shape[1], shape[0]) # hw
|
||||||
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
|
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
|
||||||
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
|
assert im.format.lower() in IMG_FORMATS, f"Invalid image format {im.format}. {FORMATS_HELP_MSG}"
|
||||||
if im.format.lower() in ("jpg", "jpeg"):
|
if im.format.lower() in {"jpg", "jpeg"}:
|
||||||
with open(im_file, "rb") as f:
|
with open(im_file, "rb") as f:
|
||||||
f.seek(-2, 2)
|
f.seek(-2, 2)
|
||||||
if f.read() != b"\xff\xd9": # corrupt JPEG
|
if f.read() != b"\xff\xd9": # corrupt JPEG
|
||||||
|
|
@ -105,8 +106,8 @@ def verify_image_label(args):
|
||||||
shape = exif_size(im) # image size
|
shape = exif_size(im) # image size
|
||||||
shape = (shape[1], shape[0]) # hw
|
shape = (shape[1], shape[0]) # hw
|
||||||
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
|
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
|
||||||
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
|
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}. {FORMATS_HELP_MSG}"
|
||||||
if im.format.lower() in ("jpg", "jpeg"):
|
if im.format.lower() in {"jpg", "jpeg"}:
|
||||||
with open(im_file, "rb") as f:
|
with open(im_file, "rb") as f:
|
||||||
f.seek(-2, 2)
|
f.seek(-2, 2)
|
||||||
if f.read() != b"\xff\xd9": # corrupt JPEG
|
if f.read() != b"\xff\xd9": # corrupt JPEG
|
||||||
|
|
@ -336,7 +337,7 @@ def check_det_dataset(dataset, autodownload=True):
|
||||||
else: # python script
|
else: # python script
|
||||||
exec(s, {"yaml": data})
|
exec(s, {"yaml": data})
|
||||||
dt = f"({round(time.time() - t, 1)}s)"
|
dt = f"({round(time.time() - t, 1)}s)"
|
||||||
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
|
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in {0, None} else f"failure {dt} ❌"
|
||||||
LOGGER.info(f"Dataset download {s}\n")
|
LOGGER.info(f"Dataset download {s}\n")
|
||||||
check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf") # download fonts
|
check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf") # download fonts
|
||||||
|
|
||||||
|
|
@ -366,7 +367,7 @@ def check_cls_dataset(dataset, split=""):
|
||||||
# Download (optional if dataset=https://file.zip is passed directly)
|
# Download (optional if dataset=https://file.zip is passed directly)
|
||||||
if str(dataset).startswith(("http:/", "https:/")):
|
if str(dataset).startswith(("http:/", "https:/")):
|
||||||
dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)
|
dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)
|
||||||
elif Path(dataset).suffix in (".zip", ".tar", ".gz"):
|
elif Path(dataset).suffix in {".zip", ".tar", ".gz"}:
|
||||||
file = check_file(dataset)
|
file = check_file(dataset)
|
||||||
dataset = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)
|
dataset = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -159,7 +159,7 @@ class Exporter:
|
||||||
_callbacks (dict, optional): Dictionary of callback functions. Defaults to None.
|
_callbacks (dict, optional): Dictionary of callback functions. Defaults to None.
|
||||||
"""
|
"""
|
||||||
self.args = get_cfg(cfg, overrides)
|
self.args = get_cfg(cfg, overrides)
|
||||||
if self.args.format.lower() in ("coreml", "mlmodel"): # fix attempt for protobuf<3.20.x errors
|
if self.args.format.lower() in {"coreml", "mlmodel"}: # fix attempt for protobuf<3.20.x errors
|
||||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" # must run before TensorBoard callback
|
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" # must run before TensorBoard callback
|
||||||
|
|
||||||
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
||||||
|
|
@ -171,9 +171,9 @@ class Exporter:
|
||||||
self.run_callbacks("on_export_start")
|
self.run_callbacks("on_export_start")
|
||||||
t = time.time()
|
t = time.time()
|
||||||
fmt = self.args.format.lower() # to lowercase
|
fmt = self.args.format.lower() # to lowercase
|
||||||
if fmt in ("tensorrt", "trt"): # 'engine' aliases
|
if fmt in {"tensorrt", "trt"}: # 'engine' aliases
|
||||||
fmt = "engine"
|
fmt = "engine"
|
||||||
if fmt in ("mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"): # 'coreml' aliases
|
if fmt in {"mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"}: # 'coreml' aliases
|
||||||
fmt = "coreml"
|
fmt = "coreml"
|
||||||
fmts = tuple(export_formats()["Argument"][1:]) # available export formats
|
fmts = tuple(export_formats()["Argument"][1:]) # available export formats
|
||||||
flags = [x == fmt for x in fmts]
|
flags = [x == fmt for x in fmts]
|
||||||
|
|
|
||||||
|
|
@ -145,7 +145,7 @@ class Model(nn.Module):
|
||||||
return
|
return
|
||||||
|
|
||||||
# Load or create new YOLO model
|
# Load or create new YOLO model
|
||||||
if Path(model).suffix in (".yaml", ".yml"):
|
if Path(model).suffix in {".yaml", ".yml"}:
|
||||||
self._new(model, task=task, verbose=verbose)
|
self._new(model, task=task, verbose=verbose)
|
||||||
else:
|
else:
|
||||||
self._load(model, task=task)
|
self._load(model, task=task)
|
||||||
|
|
@ -666,7 +666,7 @@ class Model(nn.Module):
|
||||||
self.trainer.hub_session = self.session # attach optional HUB session
|
self.trainer.hub_session = self.session # attach optional HUB session
|
||||||
self.trainer.train()
|
self.trainer.train()
|
||||||
# Update model and cfg after training
|
# Update model and cfg after training
|
||||||
if RANK in (-1, 0):
|
if RANK in {-1, 0}:
|
||||||
ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
|
ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
|
||||||
self.model, _ = attempt_load_one_weight(ckpt)
|
self.model, _ = attempt_load_one_weight(ckpt)
|
||||||
self.overrides = self.model.args
|
self.overrides = self.model.args
|
||||||
|
|
|
||||||
|
|
@ -470,7 +470,7 @@ class Boxes(BaseTensor):
|
||||||
if boxes.ndim == 1:
|
if boxes.ndim == 1:
|
||||||
boxes = boxes[None, :]
|
boxes = boxes[None, :]
|
||||||
n = boxes.shape[-1]
|
n = boxes.shape[-1]
|
||||||
assert n in (6, 7), f"expected 6 or 7 values but got {n}" # xyxy, track_id, conf, cls
|
assert n in {6, 7}, f"expected 6 or 7 values but got {n}" # xyxy, track_id, conf, cls
|
||||||
super().__init__(boxes, orig_shape)
|
super().__init__(boxes, orig_shape)
|
||||||
self.is_track = n == 7
|
self.is_track = n == 7
|
||||||
self.orig_shape = orig_shape
|
self.orig_shape = orig_shape
|
||||||
|
|
@ -687,7 +687,7 @@ class OBB(BaseTensor):
|
||||||
if boxes.ndim == 1:
|
if boxes.ndim == 1:
|
||||||
boxes = boxes[None, :]
|
boxes = boxes[None, :]
|
||||||
n = boxes.shape[-1]
|
n = boxes.shape[-1]
|
||||||
assert n in (7, 8), f"expected 7 or 8 values but got {n}" # xywh, rotation, track_id, conf, cls
|
assert n in {7, 8}, f"expected 7 or 8 values but got {n}" # xywh, rotation, track_id, conf, cls
|
||||||
super().__init__(boxes, orig_shape)
|
super().__init__(boxes, orig_shape)
|
||||||
self.is_track = n == 8
|
self.is_track = n == 8
|
||||||
self.orig_shape = orig_shape
|
self.orig_shape = orig_shape
|
||||||
|
|
|
||||||
|
|
@ -107,7 +107,7 @@ class BaseTrainer:
|
||||||
self.save_dir = get_save_dir(self.args)
|
self.save_dir = get_save_dir(self.args)
|
||||||
self.args.name = self.save_dir.name # update name for loggers
|
self.args.name = self.save_dir.name # update name for loggers
|
||||||
self.wdir = self.save_dir / "weights" # weights dir
|
self.wdir = self.save_dir / "weights" # weights dir
|
||||||
if RANK in (-1, 0):
|
if RANK in {-1, 0}:
|
||||||
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
|
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
|
||||||
self.args.save_dir = str(self.save_dir)
|
self.args.save_dir = str(self.save_dir)
|
||||||
yaml_save(self.save_dir / "args.yaml", vars(self.args)) # save run args
|
yaml_save(self.save_dir / "args.yaml", vars(self.args)) # save run args
|
||||||
|
|
@ -121,7 +121,7 @@ class BaseTrainer:
|
||||||
print_args(vars(self.args))
|
print_args(vars(self.args))
|
||||||
|
|
||||||
# Device
|
# Device
|
||||||
if self.device.type in ("cpu", "mps"):
|
if self.device.type in {"cpu", "mps"}:
|
||||||
self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading
|
self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading
|
||||||
|
|
||||||
# Model and Dataset
|
# Model and Dataset
|
||||||
|
|
@ -144,7 +144,7 @@ class BaseTrainer:
|
||||||
|
|
||||||
# Callbacks
|
# Callbacks
|
||||||
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
||||||
if RANK in (-1, 0):
|
if RANK in {-1, 0}:
|
||||||
callbacks.add_integration_callbacks(self)
|
callbacks.add_integration_callbacks(self)
|
||||||
|
|
||||||
def add_callback(self, event: str, callback):
|
def add_callback(self, event: str, callback):
|
||||||
|
|
@ -251,7 +251,7 @@ class BaseTrainer:
|
||||||
|
|
||||||
# Check AMP
|
# Check AMP
|
||||||
self.amp = torch.tensor(self.args.amp).to(self.device) # True or False
|
self.amp = torch.tensor(self.args.amp).to(self.device) # True or False
|
||||||
if self.amp and RANK in (-1, 0): # Single-GPU and DDP
|
if self.amp and RANK in {-1, 0}: # Single-GPU and DDP
|
||||||
callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them
|
callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them
|
||||||
self.amp = torch.tensor(check_amp(self.model), device=self.device)
|
self.amp = torch.tensor(check_amp(self.model), device=self.device)
|
||||||
callbacks.default_callbacks = callbacks_backup # restore callbacks
|
callbacks.default_callbacks = callbacks_backup # restore callbacks
|
||||||
|
|
@ -274,7 +274,7 @@ class BaseTrainer:
|
||||||
# Dataloaders
|
# Dataloaders
|
||||||
batch_size = self.batch_size // max(world_size, 1)
|
batch_size = self.batch_size // max(world_size, 1)
|
||||||
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train")
|
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train")
|
||||||
if RANK in (-1, 0):
|
if RANK in {-1, 0}:
|
||||||
# Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
|
# Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
|
||||||
self.test_loader = self.get_dataloader(
|
self.test_loader = self.get_dataloader(
|
||||||
self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
|
self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
|
||||||
|
|
@ -340,7 +340,7 @@ class BaseTrainer:
|
||||||
self._close_dataloader_mosaic()
|
self._close_dataloader_mosaic()
|
||||||
self.train_loader.reset()
|
self.train_loader.reset()
|
||||||
|
|
||||||
if RANK in (-1, 0):
|
if RANK in {-1, 0}:
|
||||||
LOGGER.info(self.progress_string())
|
LOGGER.info(self.progress_string())
|
||||||
pbar = TQDM(enumerate(self.train_loader), total=nb)
|
pbar = TQDM(enumerate(self.train_loader), total=nb)
|
||||||
self.tloss = None
|
self.tloss = None
|
||||||
|
|
@ -392,7 +392,7 @@ class BaseTrainer:
|
||||||
mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G" # (GB)
|
mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G" # (GB)
|
||||||
loss_len = self.tloss.shape[0] if len(self.tloss.shape) else 1
|
loss_len = self.tloss.shape[0] if len(self.tloss.shape) else 1
|
||||||
losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
|
losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
|
||||||
if RANK in (-1, 0):
|
if RANK in {-1, 0}:
|
||||||
pbar.set_description(
|
pbar.set_description(
|
||||||
("%11s" * 2 + "%11.4g" * (2 + loss_len))
|
("%11s" * 2 + "%11.4g" * (2 + loss_len))
|
||||||
% (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
|
% (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
|
||||||
|
|
@ -405,7 +405,7 @@ class BaseTrainer:
|
||||||
|
|
||||||
self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
|
self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
|
||||||
self.run_callbacks("on_train_epoch_end")
|
self.run_callbacks("on_train_epoch_end")
|
||||||
if RANK in (-1, 0):
|
if RANK in {-1, 0}:
|
||||||
final_epoch = epoch + 1 >= self.epochs
|
final_epoch = epoch + 1 >= self.epochs
|
||||||
self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
|
self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
|
||||||
|
|
||||||
|
|
@ -447,7 +447,7 @@ class BaseTrainer:
|
||||||
break # must break all DDP ranks
|
break # must break all DDP ranks
|
||||||
epoch += 1
|
epoch += 1
|
||||||
|
|
||||||
if RANK in (-1, 0):
|
if RANK in {-1, 0}:
|
||||||
# Do final val with best.pt
|
# Do final val with best.pt
|
||||||
LOGGER.info(
|
LOGGER.info(
|
||||||
f"\n{epoch - self.start_epoch + 1} epochs completed in "
|
f"\n{epoch - self.start_epoch + 1} epochs completed in "
|
||||||
|
|
@ -503,12 +503,12 @@ class BaseTrainer:
|
||||||
try:
|
try:
|
||||||
if self.args.task == "classify":
|
if self.args.task == "classify":
|
||||||
data = check_cls_dataset(self.args.data)
|
data = check_cls_dataset(self.args.data)
|
||||||
elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
|
elif self.args.data.split(".")[-1] in {"yaml", "yml"} or self.args.task in {
|
||||||
"detect",
|
"detect",
|
||||||
"segment",
|
"segment",
|
||||||
"pose",
|
"pose",
|
||||||
"obb",
|
"obb",
|
||||||
):
|
}:
|
||||||
data = check_det_dataset(self.args.data)
|
data = check_det_dataset(self.args.data)
|
||||||
if "yaml_file" in data:
|
if "yaml_file" in data:
|
||||||
self.args.data = data["yaml_file"] # for validating 'yolo train data=url.zip' usage
|
self.args.data = data["yaml_file"] # for validating 'yolo train data=url.zip' usage
|
||||||
|
|
@ -740,7 +740,7 @@ class BaseTrainer:
|
||||||
else: # weight (with decay)
|
else: # weight (with decay)
|
||||||
g[0].append(param)
|
g[0].append(param)
|
||||||
|
|
||||||
if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"):
|
if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}:
|
||||||
optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
|
optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
|
||||||
elif name == "RMSProp":
|
elif name == "RMSProp":
|
||||||
optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
|
optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
|
||||||
|
|
|
||||||
|
|
@ -139,14 +139,14 @@ class BaseValidator:
|
||||||
self.args.batch = 1 # export.py models default to batch-size 1
|
self.args.batch = 1 # export.py models default to batch-size 1
|
||||||
LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models")
|
LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models")
|
||||||
|
|
||||||
if str(self.args.data).split(".")[-1] in ("yaml", "yml"):
|
if str(self.args.data).split(".")[-1] in {"yaml", "yml"}:
|
||||||
self.data = check_det_dataset(self.args.data)
|
self.data = check_det_dataset(self.args.data)
|
||||||
elif self.args.task == "classify":
|
elif self.args.task == "classify":
|
||||||
self.data = check_cls_dataset(self.args.data, split=self.args.split)
|
self.data = check_cls_dataset(self.args.data, split=self.args.split)
|
||||||
else:
|
else:
|
||||||
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))
|
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))
|
||||||
|
|
||||||
if self.device.type in ("cpu", "mps"):
|
if self.device.type in {"cpu", "mps"}:
|
||||||
self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
|
self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
|
||||||
if not pt:
|
if not pt:
|
||||||
self.args.rect = False
|
self.args.rect = False
|
||||||
|
|
|
||||||
|
|
@ -198,7 +198,7 @@ class Events:
|
||||||
}
|
}
|
||||||
self.enabled = (
|
self.enabled = (
|
||||||
SETTINGS["sync"]
|
SETTINGS["sync"]
|
||||||
and RANK in (-1, 0)
|
and RANK in {-1, 0}
|
||||||
and not TESTS_RUNNING
|
and not TESTS_RUNNING
|
||||||
and ONLINE
|
and ONLINE
|
||||||
and (is_pip_package() or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git")
|
and (is_pip_package() or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git")
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ class FastSAM(Model):
|
||||||
"""Call the __init__ method of the parent class (YOLO) with the updated default model."""
|
"""Call the __init__ method of the parent class (YOLO) with the updated default model."""
|
||||||
if str(model) == "FastSAM.pt":
|
if str(model) == "FastSAM.pt":
|
||||||
model = "FastSAM-x.pt"
|
model = "FastSAM-x.pt"
|
||||||
assert Path(model).suffix not in (".yaml", ".yml"), "FastSAM models only support pre-trained models."
|
assert Path(model).suffix not in {".yaml", ".yml"}, "FastSAM models only support pre-trained models."
|
||||||
super().__init__(model=model, task="segment")
|
super().__init__(model=model, task="segment")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,7 @@ class NAS(Model):
|
||||||
|
|
||||||
def __init__(self, model="yolo_nas_s.pt") -> None:
|
def __init__(self, model="yolo_nas_s.pt") -> None:
|
||||||
"""Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model."""
|
"""Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model."""
|
||||||
assert Path(model).suffix not in (".yaml", ".yml"), "YOLO-NAS models only support pre-trained models."
|
assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models."
|
||||||
super().__init__(model, task="detect")
|
super().__init__(model, task="detect")
|
||||||
|
|
||||||
@smart_inference_mode()
|
@smart_inference_mode()
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ class SAM(Model):
|
||||||
Raises:
|
Raises:
|
||||||
NotImplementedError: If the model file extension is not .pt or .pth.
|
NotImplementedError: If the model file extension is not .pt or .pth.
|
||||||
"""
|
"""
|
||||||
if model and Path(model).suffix not in (".pt", ".pth"):
|
if model and Path(model).suffix not in {".pt", ".pth"}:
|
||||||
raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
|
raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
|
||||||
super().__init__(model=model, task="segment")
|
super().__init__(model=model, task="segment")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -112,7 +112,7 @@ class PatchMerging(nn.Module):
|
||||||
self.out_dim = out_dim
|
self.out_dim = out_dim
|
||||||
self.act = activation()
|
self.act = activation()
|
||||||
self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
|
self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
|
||||||
stride_c = 1 if out_dim in [320, 448, 576] else 2
|
stride_c = 1 if out_dim in {320, 448, 576} else 2
|
||||||
self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
|
self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
|
||||||
self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
|
self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -68,7 +68,7 @@ class ClassificationTrainer(BaseTrainer):
|
||||||
self.model, ckpt = attempt_load_one_weight(model, device="cpu")
|
self.model, ckpt = attempt_load_one_weight(model, device="cpu")
|
||||||
for p in self.model.parameters():
|
for p in self.model.parameters():
|
||||||
p.requires_grad = True # for training
|
p.requires_grad = True # for training
|
||||||
elif model.split(".")[-1] in ("yaml", "yml"):
|
elif model.split(".")[-1] in {"yaml", "yml"}:
|
||||||
self.model = self.get_model(cfg=model)
|
self.model = self.get_model(cfg=model)
|
||||||
elif model in torchvision.models.__dict__:
|
elif model in torchvision.models.__dict__:
|
||||||
self.model = torchvision.models.__dict__[model](weights="IMAGENET1K_V1" if self.args.pretrained else None)
|
self.model = torchvision.models.__dict__[model](weights="IMAGENET1K_V1" if self.args.pretrained else None)
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,7 @@ class DetectionTrainer(BaseTrainer):
|
||||||
|
|
||||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
|
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
|
||||||
"""Construct and return dataloader."""
|
"""Construct and return dataloader."""
|
||||||
assert mode in ["train", "val"]
|
assert mode in {"train", "val"}, f"Mode must be 'train' or 'val', not {mode}."
|
||||||
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
||||||
dataset = self.build_dataset(dataset_path, mode, batch_size)
|
dataset = self.build_dataset(dataset_path, mode, batch_size)
|
||||||
shuffle = mode == "train"
|
shuffle = mode == "train"
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from ultralytics.utils.torch_utils import de_parallel
|
||||||
|
|
||||||
def on_pretrain_routine_end(trainer):
|
def on_pretrain_routine_end(trainer):
|
||||||
"""Callback."""
|
"""Callback."""
|
||||||
if RANK in (-1, 0):
|
if RANK in {-1, 0}:
|
||||||
# NOTE: for evaluation
|
# NOTE: for evaluation
|
||||||
names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
|
names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
|
||||||
de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)
|
de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)
|
||||||
|
|
|
||||||
|
|
@ -374,9 +374,9 @@ class AutoBackend(nn.Module):
|
||||||
metadata = yaml_load(metadata)
|
metadata = yaml_load(metadata)
|
||||||
if metadata:
|
if metadata:
|
||||||
for k, v in metadata.items():
|
for k, v in metadata.items():
|
||||||
if k in ("stride", "batch"):
|
if k in {"stride", "batch"}:
|
||||||
metadata[k] = int(v)
|
metadata[k] = int(v)
|
||||||
elif k in ("imgsz", "names", "kpt_shape") and isinstance(v, str):
|
elif k in {"imgsz", "names", "kpt_shape"} and isinstance(v, str):
|
||||||
metadata[k] = eval(v)
|
metadata[k] = eval(v)
|
||||||
stride = metadata["stride"]
|
stride = metadata["stride"]
|
||||||
task = metadata["task"]
|
task = metadata["task"]
|
||||||
|
|
@ -531,8 +531,8 @@ class AutoBackend(nn.Module):
|
||||||
self.names = {i: f"class{i}" for i in range(nc)}
|
self.names = {i: f"class{i}" for i in range(nc)}
|
||||||
else: # Lite or Edge TPU
|
else: # Lite or Edge TPU
|
||||||
details = self.input_details[0]
|
details = self.input_details[0]
|
||||||
integer = details["dtype"] in (np.int8, np.int16) # is TFLite quantized int8 or int16 model
|
is_int = details["dtype"] in {np.int8, np.int16} # is TFLite quantized int8 or int16 model
|
||||||
if integer:
|
if is_int:
|
||||||
scale, zero_point = details["quantization"]
|
scale, zero_point = details["quantization"]
|
||||||
im = (im / scale + zero_point).astype(details["dtype"]) # de-scale
|
im = (im / scale + zero_point).astype(details["dtype"]) # de-scale
|
||||||
self.interpreter.set_tensor(details["index"], im)
|
self.interpreter.set_tensor(details["index"], im)
|
||||||
|
|
@ -540,7 +540,7 @@ class AutoBackend(nn.Module):
|
||||||
y = []
|
y = []
|
||||||
for output in self.output_details:
|
for output in self.output_details:
|
||||||
x = self.interpreter.get_tensor(output["index"])
|
x = self.interpreter.get_tensor(output["index"])
|
||||||
if integer:
|
if is_int:
|
||||||
scale, zero_point = output["quantization"]
|
scale, zero_point = output["quantization"]
|
||||||
x = (x.astype(np.float32) - zero_point) * scale # re-scale
|
x = (x.astype(np.float32) - zero_point) * scale # re-scale
|
||||||
if x.ndim == 3: # if task is not classification, excluding masks (ndim=4) as well
|
if x.ndim == 3: # if task is not classification, excluding masks (ndim=4) as well
|
||||||
|
|
|
||||||
|
|
@ -296,7 +296,7 @@ class SpatialAttention(nn.Module):
|
||||||
def __init__(self, kernel_size=7):
|
def __init__(self, kernel_size=7):
|
||||||
"""Initialize Spatial-attention module with kernel size argument."""
|
"""Initialize Spatial-attention module with kernel size argument."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert kernel_size in (3, 7), "kernel size must be 3 or 7"
|
assert kernel_size in {3, 7}, "kernel size must be 3 or 7"
|
||||||
padding = 3 if kernel_size == 7 else 1
|
padding = 3 if kernel_size == 7 else 1
|
||||||
self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
|
self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
|
||||||
self.act = nn.Sigmoid()
|
self.act = nn.Sigmoid()
|
||||||
|
|
|
||||||
|
|
@ -54,13 +54,13 @@ class Detect(nn.Module):
|
||||||
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
|
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
|
||||||
self.shape = shape
|
self.shape = shape
|
||||||
|
|
||||||
if self.export and self.format in ("saved_model", "pb", "tflite", "edgetpu", "tfjs"): # avoid TF FlexSplitV ops
|
if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}: # avoid TF FlexSplitV ops
|
||||||
box = x_cat[:, : self.reg_max * 4]
|
box = x_cat[:, : self.reg_max * 4]
|
||||||
cls = x_cat[:, self.reg_max * 4 :]
|
cls = x_cat[:, self.reg_max * 4 :]
|
||||||
else:
|
else:
|
||||||
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
|
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
|
||||||
|
|
||||||
if self.export and self.format in ("tflite", "edgetpu"):
|
if self.export and self.format in {"tflite", "edgetpu"}:
|
||||||
# Precompute normalization factor to increase numerical stability
|
# Precompute normalization factor to increase numerical stability
|
||||||
# See https://github.com/ultralytics/ultralytics/issues/7371
|
# See https://github.com/ultralytics/ultralytics/issues/7371
|
||||||
grid_h = shape[2]
|
grid_h = shape[2]
|
||||||
|
|
@ -230,13 +230,13 @@ class WorldDetect(Detect):
|
||||||
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
|
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
|
||||||
self.shape = shape
|
self.shape = shape
|
||||||
|
|
||||||
if self.export and self.format in ("saved_model", "pb", "tflite", "edgetpu", "tfjs"): # avoid TF FlexSplitV ops
|
if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}: # avoid TF FlexSplitV ops
|
||||||
box = x_cat[:, : self.reg_max * 4]
|
box = x_cat[:, : self.reg_max * 4]
|
||||||
cls = x_cat[:, self.reg_max * 4 :]
|
cls = x_cat[:, self.reg_max * 4 :]
|
||||||
else:
|
else:
|
||||||
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
|
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
|
||||||
|
|
||||||
if self.export and self.format in ("tflite", "edgetpu"):
|
if self.export and self.format in {"tflite", "edgetpu"}:
|
||||||
# Precompute normalization factor to increase numerical stability
|
# Precompute normalization factor to increase numerical stability
|
||||||
# See https://github.com/ultralytics/ultralytics/issues/7371
|
# See https://github.com/ultralytics/ultralytics/issues/7371
|
||||||
grid_h = shape[2]
|
grid_h = shape[2]
|
||||||
|
|
|
||||||
|
|
@ -896,7 +896,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||||
) # num heads
|
) # num heads
|
||||||
|
|
||||||
args = [c1, c2, *args[1:]]
|
args = [c1, c2, *args[1:]]
|
||||||
if m in (BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3):
|
if m in {BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3}:
|
||||||
args.insert(2, n) # number of repeats
|
args.insert(2, n) # number of repeats
|
||||||
n = 1
|
n = 1
|
||||||
elif m is AIFI:
|
elif m is AIFI:
|
||||||
|
|
|
||||||
|
|
@ -81,7 +81,7 @@ class AIGym:
|
||||||
self.annotator = Annotator(im0, line_width=2)
|
self.annotator = Annotator(im0, line_width=2)
|
||||||
|
|
||||||
for ind, k in enumerate(reversed(self.keypoints)):
|
for ind, k in enumerate(reversed(self.keypoints)):
|
||||||
if self.pose_type in ["pushup", "pullup"]:
|
if self.pose_type in {"pushup", "pullup"}:
|
||||||
self.angle[ind] = self.annotator.estimate_pose_angle(
|
self.angle[ind] = self.annotator.estimate_pose_angle(
|
||||||
k[int(self.kpts_to_check[0])].cpu(),
|
k[int(self.kpts_to_check[0])].cpu(),
|
||||||
k[int(self.kpts_to_check[1])].cpu(),
|
k[int(self.kpts_to_check[1])].cpu(),
|
||||||
|
|
|
||||||
|
|
@ -153,7 +153,7 @@ class Heatmap:
|
||||||
self.cls_txtdisplay_gap = cls_txtdisplay_gap
|
self.cls_txtdisplay_gap = cls_txtdisplay_gap
|
||||||
|
|
||||||
# shape of heatmap, if not selected
|
# shape of heatmap, if not selected
|
||||||
if self.shape not in ["circle", "rect"]:
|
if self.shape not in {"circle", "rect"}:
|
||||||
print("Unknown shape value provided, 'circle' & 'rect' supported")
|
print("Unknown shape value provided, 'circle' & 'rect' supported")
|
||||||
print("Using Circular shape now")
|
print("Using Circular shape now")
|
||||||
self.shape = "circle"
|
self.shape = "circle"
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,7 @@ class STrack(BaseTrack):
|
||||||
"""Initialize new STrack instance."""
|
"""Initialize new STrack instance."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# xywh+idx or xywha+idx
|
# xywh+idx or xywha+idx
|
||||||
assert len(xywh) in [5, 6], f"expected 5 or 6 values but got {len(xywh)}"
|
assert len(xywh) in {5, 6}, f"expected 5 or 6 values but got {len(xywh)}"
|
||||||
self._tlwh = np.asarray(xywh2ltwh(xywh[:4]), dtype=np.float32)
|
self._tlwh = np.asarray(xywh2ltwh(xywh[:4]), dtype=np.float32)
|
||||||
self.kalman_filter = None
|
self.kalman_filter = None
|
||||||
self.mean, self.covariance = None, None
|
self.mean, self.covariance = None, None
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ def on_predict_start(predictor: object, persist: bool = False) -> None:
|
||||||
tracker = check_yaml(predictor.args.tracker)
|
tracker = check_yaml(predictor.args.tracker)
|
||||||
cfg = IterableSimpleNamespace(**yaml_load(tracker))
|
cfg = IterableSimpleNamespace(**yaml_load(tracker))
|
||||||
|
|
||||||
if cfg.tracker_type not in ["bytetrack", "botsort"]:
|
if cfg.tracker_type not in {"bytetrack", "botsort"}:
|
||||||
raise AssertionError(f"Only 'bytetrack' and 'botsort' are supported for now, but got '{cfg.tracker_type}'")
|
raise AssertionError(f"Only 'bytetrack' and 'botsort' are supported for now, but got '{cfg.tracker_type}'")
|
||||||
|
|
||||||
trackers = []
|
trackers = []
|
||||||
|
|
|
||||||
|
|
@ -94,7 +94,7 @@ class GMC:
|
||||||
array([[1, 2, 3],
|
array([[1, 2, 3],
|
||||||
[4, 5, 6]])
|
[4, 5, 6]])
|
||||||
"""
|
"""
|
||||||
if self.method in ["orb", "sift"]:
|
if self.method in {"orb", "sift"}:
|
||||||
return self.applyFeatures(raw_frame, detections)
|
return self.applyFeatures(raw_frame, detections)
|
||||||
elif self.method == "ecc":
|
elif self.method == "ecc":
|
||||||
return self.applyEcc(raw_frame)
|
return self.applyEcc(raw_frame)
|
||||||
|
|
|
||||||
|
|
@ -41,7 +41,7 @@ VERBOSE = str(os.getenv("YOLO_VERBOSE", True)).lower() == "true" # global verbo
|
||||||
TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}" if VERBOSE else None # tqdm bar format
|
TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}" if VERBOSE else None # tqdm bar format
|
||||||
LOGGING_NAME = "ultralytics"
|
LOGGING_NAME = "ultralytics"
|
||||||
MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"]) # environment booleans
|
MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"]) # environment booleans
|
||||||
ARM64 = platform.machine() in ("arm64", "aarch64") # ARM64 booleans
|
ARM64 = platform.machine() in {"arm64", "aarch64"} # ARM64 booleans
|
||||||
HELP_MSG = """
|
HELP_MSG = """
|
||||||
Usage examples for running YOLOv8:
|
Usage examples for running YOLOv8:
|
||||||
|
|
||||||
|
|
@ -359,7 +359,7 @@ def yaml_load(file="data.yaml", append_filename=False):
|
||||||
Returns:
|
Returns:
|
||||||
(dict): YAML data and file name.
|
(dict): YAML data and file name.
|
||||||
"""
|
"""
|
||||||
assert Path(file).suffix in (".yaml", ".yml"), f"Attempting to load non-YAML file {file} with yaml_load()"
|
assert Path(file).suffix in {".yaml", ".yml"}, f"Attempting to load non-YAML file {file} with yaml_load()"
|
||||||
with open(file, errors="ignore", encoding="utf-8") as f:
|
with open(file, errors="ignore", encoding="utf-8") as f:
|
||||||
s = f.read() # string
|
s = f.read() # string
|
||||||
|
|
||||||
|
|
@ -866,7 +866,7 @@ def set_sentry():
|
||||||
"""
|
"""
|
||||||
if "exc_info" in hint:
|
if "exc_info" in hint:
|
||||||
exc_type, exc_value, tb = hint["exc_info"]
|
exc_type, exc_value, tb = hint["exc_info"]
|
||||||
if exc_type in (KeyboardInterrupt, FileNotFoundError) or "out of memory" in str(exc_value):
|
if exc_type in {KeyboardInterrupt, FileNotFoundError} or "out of memory" in str(exc_value):
|
||||||
return None # do not send event
|
return None # do not send event
|
||||||
|
|
||||||
event["tags"] = {
|
event["tags"] = {
|
||||||
|
|
@ -879,7 +879,7 @@ def set_sentry():
|
||||||
|
|
||||||
if (
|
if (
|
||||||
SETTINGS["sync"]
|
SETTINGS["sync"]
|
||||||
and RANK in (-1, 0)
|
and RANK in {-1, 0}
|
||||||
and Path(ARGV[0]).name == "yolo"
|
and Path(ARGV[0]).name == "yolo"
|
||||||
and not TESTS_RUNNING
|
and not TESTS_RUNNING
|
||||||
and ONLINE
|
and ONLINE
|
||||||
|
|
|
||||||
|
|
@ -115,7 +115,7 @@ def benchmark(
|
||||||
|
|
||||||
# Predict
|
# Predict
|
||||||
assert model.task != "pose" or i != 7, "GraphDef Pose inference is not supported"
|
assert model.task != "pose" or i != 7, "GraphDef Pose inference is not supported"
|
||||||
assert i not in (9, 10), "inference not supported" # Edge TPU and TF.js are unsupported
|
assert i not in {9, 10}, "inference not supported" # Edge TPU and TF.js are unsupported
|
||||||
assert i != 5 or platform.system() == "Darwin", "inference only supported on macOS>=10.13" # CoreML
|
assert i != 5 or platform.system() == "Darwin", "inference only supported on macOS>=10.13" # CoreML
|
||||||
exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half)
|
exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half)
|
||||||
|
|
||||||
|
|
@ -220,7 +220,7 @@ class ProfileModels:
|
||||||
output = []
|
output = []
|
||||||
for file in files:
|
for file in files:
|
||||||
engine_file = file.with_suffix(".engine")
|
engine_file = file.with_suffix(".engine")
|
||||||
if file.suffix in (".pt", ".yaml", ".yml"):
|
if file.suffix in {".pt", ".yaml", ".yml"}:
|
||||||
model = YOLO(str(file))
|
model = YOLO(str(file))
|
||||||
model.fuse() # to report correct params and GFLOPs in model.info()
|
model.fuse() # to report correct params and GFLOPs in model.info()
|
||||||
model_info = model.info()
|
model_info = model.info()
|
||||||
|
|
|
||||||
|
|
@ -71,7 +71,7 @@ def _get_experiment_type(mode, project_name):
|
||||||
|
|
||||||
def _create_experiment(args):
|
def _create_experiment(args):
|
||||||
"""Ensures that the experiment object is only created in a single process during distributed training."""
|
"""Ensures that the experiment object is only created in a single process during distributed training."""
|
||||||
if RANK not in (-1, 0):
|
if RANK not in {-1, 0}:
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
comet_mode = _get_comet_mode()
|
comet_mode = _get_comet_mode()
|
||||||
|
|
|
||||||
|
|
@ -108,7 +108,7 @@ def on_train_end(trainer):
|
||||||
for f in trainer.save_dir.glob("*"): # log all other files in save_dir
|
for f in trainer.save_dir.glob("*"): # log all other files in save_dir
|
||||||
if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
|
if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
|
||||||
mlflow.log_artifact(str(f))
|
mlflow.log_artifact(str(f))
|
||||||
keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "False").lower() in ("true")
|
keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "False").lower() == "true"
|
||||||
if keep_run_active:
|
if keep_run_active:
|
||||||
LOGGER.info(f"{PREFIX}mlflow run still alive, remember to close it using mlflow.end_run()")
|
LOGGER.info(f"{PREFIX}mlflow run still alive, remember to close it using mlflow.end_run()")
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -237,7 +237,7 @@ def check_version(
|
||||||
result = False
|
result = False
|
||||||
elif op == "!=" and c == v:
|
elif op == "!=" and c == v:
|
||||||
result = False
|
result = False
|
||||||
elif op in (">=", "") and not (c >= v): # if no constraint passed assume '>=required'
|
elif op in {">=", ""} and not (c >= v): # if no constraint passed assume '>=required'
|
||||||
result = False
|
result = False
|
||||||
elif op == "<=" and not (c <= v):
|
elif op == "<=" and not (c <= v):
|
||||||
result = False
|
result = False
|
||||||
|
|
@ -632,7 +632,7 @@ def check_amp(model):
|
||||||
(bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
|
(bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
|
||||||
"""
|
"""
|
||||||
device = next(model.parameters()).device # get model device
|
device = next(model.parameters()).device # get model device
|
||||||
if device.type in ("cpu", "mps"):
|
if device.type in {"cpu", "mps"}:
|
||||||
return False # AMP only used on CUDA devices
|
return False # AMP only used on CUDA devices
|
||||||
|
|
||||||
def amp_allclose(m, im):
|
def amp_allclose(m, im):
|
||||||
|
|
|
||||||
|
|
@ -356,13 +356,13 @@ def safe_download(
|
||||||
raise ConnectionError(emojis(f"❌ Download failure for {url}. Retry limit reached.")) from e
|
raise ConnectionError(emojis(f"❌ Download failure for {url}. Retry limit reached.")) from e
|
||||||
LOGGER.warning(f"⚠️ Download failure, retrying {i + 1}/{retry} {url}...")
|
LOGGER.warning(f"⚠️ Download failure, retrying {i + 1}/{retry} {url}...")
|
||||||
|
|
||||||
if unzip and f.exists() and f.suffix in ("", ".zip", ".tar", ".gz"):
|
if unzip and f.exists() and f.suffix in {"", ".zip", ".tar", ".gz"}:
|
||||||
from zipfile import is_zipfile
|
from zipfile import is_zipfile
|
||||||
|
|
||||||
unzip_dir = (dir or f.parent).resolve() # unzip to dir if provided else unzip in place
|
unzip_dir = (dir or f.parent).resolve() # unzip to dir if provided else unzip in place
|
||||||
if is_zipfile(f):
|
if is_zipfile(f):
|
||||||
unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress) # unzip
|
unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress) # unzip
|
||||||
elif f.suffix in (".tar", ".gz"):
|
elif f.suffix in {".tar", ".gz"}:
|
||||||
LOGGER.info(f"Unzipping {f} to {unzip_dir}...")
|
LOGGER.info(f"Unzipping {f} to {unzip_dir}...")
|
||||||
subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True)
|
subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True)
|
||||||
if delete:
|
if delete:
|
||||||
|
|
|
||||||
|
|
@ -298,7 +298,7 @@ class ConfusionMatrix:
|
||||||
self.task = task
|
self.task = task
|
||||||
self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc))
|
self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc))
|
||||||
self.nc = nc # number of classes
|
self.nc = nc # number of classes
|
||||||
self.conf = 0.25 if conf in (None, 0.001) else conf # apply 0.25 if default val conf is passed
|
self.conf = 0.25 if conf in {None, 0.001} else conf # apply 0.25 if default val conf is passed
|
||||||
self.iou_thres = iou_thres
|
self.iou_thres = iou_thres
|
||||||
|
|
||||||
def process_cls_preds(self, preds, targets):
|
def process_cls_preds(self, preds, targets):
|
||||||
|
|
|
||||||
|
|
@ -904,7 +904,7 @@ def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False,
|
||||||
ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8) # actual results
|
ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8) # actual results
|
||||||
ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2) # smoothing line
|
ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2) # smoothing line
|
||||||
ax[i].set_title(s[j], fontsize=12)
|
ax[i].set_title(s[j], fontsize=12)
|
||||||
# if j in [8, 9, 10]: # share train and val loss y axes
|
# if j in {8, 9, 10}: # share train and val loss y axes
|
||||||
# ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
|
# ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOGGER.warning(f"WARNING: Plotting error for {f}: {e}")
|
LOGGER.warning(f"WARNING: Plotting error for {f}: {e}")
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,7 @@ TORCHVISION_0_13 = check_version(torchvision.__version__, "0.13.0")
|
||||||
def torch_distributed_zero_first(local_rank: int):
|
def torch_distributed_zero_first(local_rank: int):
|
||||||
"""Decorator to make all processes in distributed training wait for each local_master to do something."""
|
"""Decorator to make all processes in distributed training wait for each local_master to do something."""
|
||||||
initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
|
initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
|
||||||
if initialized and local_rank not in (-1, 0):
|
if initialized and local_rank not in {-1, 0}:
|
||||||
dist.barrier(device_ids=[local_rank])
|
dist.barrier(device_ids=[local_rank])
|
||||||
yield
|
yield
|
||||||
if initialized and local_rank == 0:
|
if initialized and local_rank == 0:
|
||||||
|
|
@ -109,7 +109,7 @@ def select_device(device="", batch=0, newline=False, verbose=True):
|
||||||
for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ":
|
for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ":
|
||||||
device = device.replace(remove, "") # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
|
device = device.replace(remove, "") # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
|
||||||
cpu = device == "cpu"
|
cpu = device == "cpu"
|
||||||
mps = device in ("mps", "mps:0") # Apple Metal Performance Shaders (MPS)
|
mps = device in {"mps", "mps:0"} # Apple Metal Performance Shaders (MPS)
|
||||||
if cpu or mps:
|
if cpu or mps:
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # force torch.cuda.is_available() = False
|
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # force torch.cuda.is_available() = False
|
||||||
elif device: # non-cpu device requested
|
elif device: # non-cpu device requested
|
||||||
|
|
@ -347,7 +347,7 @@ def initialize_weights(model):
|
||||||
elif t is nn.BatchNorm2d:
|
elif t is nn.BatchNorm2d:
|
||||||
m.eps = 1e-3
|
m.eps = 1e-3
|
||||||
m.momentum = 0.03
|
m.momentum = 0.03
|
||||||
elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
|
elif t in {nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU}:
|
||||||
m.inplace = True
|
m.inplace = True
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue