ultralytics 8.1.40 search in Python sets {} for speed (#9450)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
30484d5925
commit
ea527507fe
41 changed files with 97 additions and 93 deletions
|
|
@ -351,7 +351,7 @@ def test_labels_and_crops():
|
|||
crop_dirs = [p for p in (save_path / "crops").iterdir()]
|
||||
crop_files = [f for p in crop_dirs for f in p.glob("*")]
|
||||
# Crop directories match detections
|
||||
assert all([r.names.get(c) in [d.name for d in crop_dirs] for c in cls_idxs])
|
||||
assert all([r.names.get(c) in {d.name for d in crop_dirs} for c in cls_idxs])
|
||||
# Same number of crops as detections
|
||||
assert len([f for f in crop_files if im_name in f.name]) == len(r.boxes.data)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = "8.1.39"
|
||||
__version__ = "8.1.40"
|
||||
|
||||
from ultralytics.data.explorer.explorer import Explorer
|
||||
from ultralytics.models import RTDETR, SAM, YOLO, YOLOWorld
|
||||
|
|
|
|||
|
|
@ -272,7 +272,7 @@ def get_save_dir(args, name=None):
|
|||
|
||||
project = args.project or (ROOT.parent / "tests/tmp/runs" if TESTS_RUNNING else RUNS_DIR) / args.task
|
||||
name = name or args.name or f"{args.mode}"
|
||||
save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in (-1, 0) else True)
|
||||
save_dir = increment_path(Path(project) / name, exist_ok=args.exist_ok if RANK in {-1, 0} else True)
|
||||
|
||||
return Path(save_dir)
|
||||
|
||||
|
|
@ -566,10 +566,10 @@ def entrypoint(debug=""):
|
|||
task = model.task
|
||||
|
||||
# Mode
|
||||
if mode in ("predict", "track") and "source" not in overrides:
|
||||
if mode in {"predict", "track"} and "source" not in overrides:
|
||||
overrides["source"] = DEFAULT_CFG.source or ASSETS
|
||||
LOGGER.warning(f"WARNING ⚠️ 'source' argument is missing. Using default 'source={overrides['source']}'.")
|
||||
elif mode in ("train", "val"):
|
||||
elif mode in {"train", "val"}:
|
||||
if "data" not in overrides and "resume" not in overrides:
|
||||
overrides["data"] = DEFAULT_CFG.data or TASK2DATA.get(task or DEFAULT_CFG.task, DEFAULT_CFG.data)
|
||||
LOGGER.warning(f"WARNING ⚠️ 'data' argument is missing. Using default 'data={overrides['data']}'.")
|
||||
|
|
|
|||
|
|
@ -191,7 +191,7 @@ class Mosaic(BaseMixTransform):
|
|||
def __init__(self, dataset, imgsz=640, p=1.0, n=4):
|
||||
"""Initializes the object with a dataset, image size, probability, and border."""
|
||||
assert 0 <= p <= 1.0, f"The probability should be in range [0, 1], but got {p}."
|
||||
assert n in (4, 9), "grid must be equal to 4 or 9."
|
||||
assert n in {4, 9}, "grid must be equal to 4 or 9."
|
||||
super().__init__(dataset=dataset, p=p)
|
||||
self.dataset = dataset
|
||||
self.imgsz = imgsz
|
||||
|
|
@ -685,7 +685,7 @@ class RandomFlip:
|
|||
Default is 'horizontal'.
|
||||
flip_idx (array-like, optional): Index mapping for flipping keypoints, if any.
|
||||
"""
|
||||
assert direction in ["horizontal", "vertical"], f"Support direction `horizontal` or `vertical`, got {direction}"
|
||||
assert direction in {"horizontal", "vertical"}, f"Support direction `horizontal` or `vertical`, got {direction}"
|
||||
assert 0 <= p <= 1.0
|
||||
|
||||
self.p = p
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import psutil
|
|||
from torch.utils.data import Dataset
|
||||
|
||||
from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM
|
||||
from .utils import HELP_URL, IMG_FORMATS
|
||||
from .utils import HELP_URL, FORMATS_HELP_MSG, IMG_FORMATS
|
||||
|
||||
|
||||
class BaseDataset(Dataset):
|
||||
|
|
@ -118,7 +118,7 @@ class BaseDataset(Dataset):
|
|||
raise FileNotFoundError(f"{self.prefix}{p} does not exist")
|
||||
im_files = sorted(x.replace("/", os.sep) for x in f if x.split(".")[-1].lower() in IMG_FORMATS)
|
||||
# self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
|
||||
assert im_files, f"{self.prefix}No images found in {img_path}"
|
||||
assert im_files, f"{self.prefix}No images found in {img_path}. {FORMATS_HELP_MSG}"
|
||||
except Exception as e:
|
||||
raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
|
||||
if self.fraction < 1:
|
||||
|
|
|
|||
|
|
@ -481,7 +481,7 @@ def merge_multi_segment(segments):
|
|||
segments[i] = np.roll(segments[i], -idx[0], axis=0)
|
||||
segments[i] = np.concatenate([segments[i], segments[i][:1]])
|
||||
# Deal with the first segment and the last one
|
||||
if i in [0, len(idx_list) - 1]:
|
||||
if i in {0, len(idx_list) - 1}:
|
||||
s.append(segments[i])
|
||||
else:
|
||||
idx = [0, idx[1] - idx[0]]
|
||||
|
|
@ -489,7 +489,7 @@ def merge_multi_segment(segments):
|
|||
|
||||
else:
|
||||
for i in range(len(idx_list) - 1, -1, -1):
|
||||
if i not in [0, len(idx_list) - 1]:
|
||||
if i not in {0, len(idx_list) - 1}:
|
||||
idx = idx_list[i]
|
||||
nidx = abs(idx[1] - idx[0])
|
||||
s.append(segments[i][nidx:])
|
||||
|
|
|
|||
|
|
@ -77,7 +77,7 @@ class YOLODataset(BaseDataset):
|
|||
desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
|
||||
total = len(self.im_files)
|
||||
nkpt, ndim = self.data.get("kpt_shape", (0, 0))
|
||||
if self.use_keypoints and (nkpt <= 0 or ndim not in (2, 3)):
|
||||
if self.use_keypoints and (nkpt <= 0 or ndim not in {2, 3}):
|
||||
raise ValueError(
|
||||
"'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
|
||||
"keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
|
||||
|
|
@ -142,7 +142,7 @@ class YOLODataset(BaseDataset):
|
|||
|
||||
# Display cache
|
||||
nf, nm, ne, nc, n = cache.pop("results") # found, missing, empty, corrupt, total
|
||||
if exists and LOCAL_RANK in (-1, 0):
|
||||
if exists and LOCAL_RANK in {-1, 0}:
|
||||
d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
|
||||
TQDM(None, desc=self.prefix + d, total=n, initial=n) # display results
|
||||
if cache["msgs"]:
|
||||
|
|
@ -235,7 +235,7 @@ class YOLODataset(BaseDataset):
|
|||
value = values[i]
|
||||
if k == "img":
|
||||
value = torch.stack(value, 0)
|
||||
if k in ["masks", "keypoints", "bboxes", "cls", "segments", "obb"]:
|
||||
if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:
|
||||
value = torch.cat(value, 0)
|
||||
new_batch[k] = value
|
||||
new_batch["batch_idx"] = list(new_batch["batch_idx"])
|
||||
|
|
@ -334,7 +334,7 @@ class ClassificationDataset(torchvision.datasets.ImageFolder):
|
|||
assert cache["version"] == DATASET_CACHE_VERSION # matches current version
|
||||
assert cache["hash"] == get_hash([x[0] for x in self.samples]) # identical hash
|
||||
nf, nc, n, samples = cache.pop("results") # found, missing, empty, corrupt, total
|
||||
if LOCAL_RANK in (-1, 0):
|
||||
if LOCAL_RANK in {-1, 0}:
|
||||
d = f"{desc} {nf} images, {nc} corrupt"
|
||||
TQDM(None, desc=d, total=n, initial=n)
|
||||
if cache["msgs"]:
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ import requests
|
|||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
|
||||
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS, FORMATS_HELP_MSG
|
||||
from ultralytics.utils import LOGGER, is_colab, is_kaggle, ops
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
|
||||
|
|
@ -83,7 +83,7 @@ class LoadStreams:
|
|||
for i, s in enumerate(sources): # index, source
|
||||
# Start thread to read frames from video stream
|
||||
st = f"{i + 1}/{n}: {s}... "
|
||||
if urlparse(s).hostname in ("www.youtube.com", "youtube.com", "youtu.be"): # if source is YouTube video
|
||||
if urlparse(s).hostname in {"www.youtube.com", "youtube.com", "youtu.be"}: # if source is YouTube video
|
||||
# YouTube format i.e. 'https://www.youtube.com/watch?v=Zgi9g1ksQHc' or 'https://youtu.be/LNwODJXcvt4'
|
||||
s = get_best_youtube_url(s)
|
||||
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
|
||||
|
|
@ -291,8 +291,14 @@ class LoadImagesAndVideos:
|
|||
else:
|
||||
raise FileNotFoundError(f"{p} does not exist")
|
||||
|
||||
images = [x for x in files if x.split(".")[-1].lower() in IMG_FORMATS]
|
||||
videos = [x for x in files if x.split(".")[-1].lower() in VID_FORMATS]
|
||||
# Define files as images or videos
|
||||
images, videos = [], []
|
||||
for f in files:
|
||||
suffix = f.split(".")[-1].lower() # Get file extension without the dot and lowercase
|
||||
if suffix in IMG_FORMATS:
|
||||
images.append(f)
|
||||
elif suffix in VID_FORMATS:
|
||||
videos.append(f)
|
||||
ni, nv = len(images), len(videos)
|
||||
|
||||
self.files = images + videos
|
||||
|
|
@ -307,10 +313,7 @@ class LoadImagesAndVideos:
|
|||
else:
|
||||
self.cap = None
|
||||
if self.nf == 0:
|
||||
raise FileNotFoundError(
|
||||
f"No images or videos found in {p}. "
|
||||
f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
|
||||
)
|
||||
raise FileNotFoundError(f"No images or videos found in {p}. {FORMATS_HELP_MSG}")
|
||||
|
||||
def __iter__(self):
|
||||
"""Returns an iterator object for VideoStream or ImageFolder."""
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ def load_yolo_dota(data_root, split="train"):
|
|||
- train
|
||||
- val
|
||||
"""
|
||||
assert split in ["train", "val"]
|
||||
assert split in {"train", "val"}, f"Split must be 'train' or 'val', not {split}."
|
||||
im_dir = Path(data_root) / "images" / split
|
||||
assert im_dir.exists(), f"Can't find {im_dir}, please check your data root."
|
||||
im_files = glob(str(Path(data_root) / "images" / split / "*"))
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ HELP_URL = "See https://docs.ultralytics.com/datasets/detect for dataset formatt
|
|||
IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm"} # image suffixes
|
||||
VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"} # video suffixes
|
||||
PIN_MEMORY = str(os.getenv("PIN_MEMORY", True)).lower() == "true" # global pin_memory for dataloaders
|
||||
FORMATS_HELP_MSG = f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
|
||||
|
||||
|
||||
def img2label_paths(img_paths):
|
||||
|
|
@ -63,7 +64,7 @@ def exif_size(img: Image.Image):
|
|||
exif = img.getexif()
|
||||
if exif:
|
||||
rotation = exif.get(274, None) # the EXIF key for the orientation tag is 274
|
||||
if rotation in [6, 8]: # rotation 270 or 90
|
||||
if rotation in {6, 8}: # rotation 270 or 90
|
||||
s = s[1], s[0]
|
||||
return s
|
||||
|
||||
|
|
@ -79,8 +80,8 @@ def verify_image(args):
|
|||
shape = exif_size(im) # image size
|
||||
shape = (shape[1], shape[0]) # hw
|
||||
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
|
||||
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
|
||||
if im.format.lower() in ("jpg", "jpeg"):
|
||||
assert im.format.lower() in IMG_FORMATS, f"Invalid image format {im.format}. {FORMATS_HELP_MSG}"
|
||||
if im.format.lower() in {"jpg", "jpeg"}:
|
||||
with open(im_file, "rb") as f:
|
||||
f.seek(-2, 2)
|
||||
if f.read() != b"\xff\xd9": # corrupt JPEG
|
||||
|
|
@ -105,8 +106,8 @@ def verify_image_label(args):
|
|||
shape = exif_size(im) # image size
|
||||
shape = (shape[1], shape[0]) # hw
|
||||
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
|
||||
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
|
||||
if im.format.lower() in ("jpg", "jpeg"):
|
||||
assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}. {FORMATS_HELP_MSG}"
|
||||
if im.format.lower() in {"jpg", "jpeg"}:
|
||||
with open(im_file, "rb") as f:
|
||||
f.seek(-2, 2)
|
||||
if f.read() != b"\xff\xd9": # corrupt JPEG
|
||||
|
|
@ -336,7 +337,7 @@ def check_det_dataset(dataset, autodownload=True):
|
|||
else: # python script
|
||||
exec(s, {"yaml": data})
|
||||
dt = f"({round(time.time() - t, 1)}s)"
|
||||
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f"failure {dt} ❌"
|
||||
s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in {0, None} else f"failure {dt} ❌"
|
||||
LOGGER.info(f"Dataset download {s}\n")
|
||||
check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf") # download fonts
|
||||
|
||||
|
|
@ -366,7 +367,7 @@ def check_cls_dataset(dataset, split=""):
|
|||
# Download (optional if dataset=https://file.zip is passed directly)
|
||||
if str(dataset).startswith(("http:/", "https:/")):
|
||||
dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)
|
||||
elif Path(dataset).suffix in (".zip", ".tar", ".gz"):
|
||||
elif Path(dataset).suffix in {".zip", ".tar", ".gz"}:
|
||||
file = check_file(dataset)
|
||||
dataset = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)
|
||||
|
||||
|
|
|
|||
|
|
@ -159,7 +159,7 @@ class Exporter:
|
|||
_callbacks (dict, optional): Dictionary of callback functions. Defaults to None.
|
||||
"""
|
||||
self.args = get_cfg(cfg, overrides)
|
||||
if self.args.format.lower() in ("coreml", "mlmodel"): # fix attempt for protobuf<3.20.x errors
|
||||
if self.args.format.lower() in {"coreml", "mlmodel"}: # fix attempt for protobuf<3.20.x errors
|
||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" # must run before TensorBoard callback
|
||||
|
||||
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
||||
|
|
@ -171,9 +171,9 @@ class Exporter:
|
|||
self.run_callbacks("on_export_start")
|
||||
t = time.time()
|
||||
fmt = self.args.format.lower() # to lowercase
|
||||
if fmt in ("tensorrt", "trt"): # 'engine' aliases
|
||||
if fmt in {"tensorrt", "trt"}: # 'engine' aliases
|
||||
fmt = "engine"
|
||||
if fmt in ("mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"): # 'coreml' aliases
|
||||
if fmt in {"mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"}: # 'coreml' aliases
|
||||
fmt = "coreml"
|
||||
fmts = tuple(export_formats()["Argument"][1:]) # available export formats
|
||||
flags = [x == fmt for x in fmts]
|
||||
|
|
|
|||
|
|
@ -145,7 +145,7 @@ class Model(nn.Module):
|
|||
return
|
||||
|
||||
# Load or create new YOLO model
|
||||
if Path(model).suffix in (".yaml", ".yml"):
|
||||
if Path(model).suffix in {".yaml", ".yml"}:
|
||||
self._new(model, task=task, verbose=verbose)
|
||||
else:
|
||||
self._load(model, task=task)
|
||||
|
|
@ -666,7 +666,7 @@ class Model(nn.Module):
|
|||
self.trainer.hub_session = self.session # attach optional HUB session
|
||||
self.trainer.train()
|
||||
# Update model and cfg after training
|
||||
if RANK in (-1, 0):
|
||||
if RANK in {-1, 0}:
|
||||
ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
|
||||
self.model, _ = attempt_load_one_weight(ckpt)
|
||||
self.overrides = self.model.args
|
||||
|
|
|
|||
|
|
@ -470,7 +470,7 @@ class Boxes(BaseTensor):
|
|||
if boxes.ndim == 1:
|
||||
boxes = boxes[None, :]
|
||||
n = boxes.shape[-1]
|
||||
assert n in (6, 7), f"expected 6 or 7 values but got {n}" # xyxy, track_id, conf, cls
|
||||
assert n in {6, 7}, f"expected 6 or 7 values but got {n}" # xyxy, track_id, conf, cls
|
||||
super().__init__(boxes, orig_shape)
|
||||
self.is_track = n == 7
|
||||
self.orig_shape = orig_shape
|
||||
|
|
@ -687,7 +687,7 @@ class OBB(BaseTensor):
|
|||
if boxes.ndim == 1:
|
||||
boxes = boxes[None, :]
|
||||
n = boxes.shape[-1]
|
||||
assert n in (7, 8), f"expected 7 or 8 values but got {n}" # xywh, rotation, track_id, conf, cls
|
||||
assert n in {7, 8}, f"expected 7 or 8 values but got {n}" # xywh, rotation, track_id, conf, cls
|
||||
super().__init__(boxes, orig_shape)
|
||||
self.is_track = n == 8
|
||||
self.orig_shape = orig_shape
|
||||
|
|
|
|||
|
|
@ -107,7 +107,7 @@ class BaseTrainer:
|
|||
self.save_dir = get_save_dir(self.args)
|
||||
self.args.name = self.save_dir.name # update name for loggers
|
||||
self.wdir = self.save_dir / "weights" # weights dir
|
||||
if RANK in (-1, 0):
|
||||
if RANK in {-1, 0}:
|
||||
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
|
||||
self.args.save_dir = str(self.save_dir)
|
||||
yaml_save(self.save_dir / "args.yaml", vars(self.args)) # save run args
|
||||
|
|
@ -121,7 +121,7 @@ class BaseTrainer:
|
|||
print_args(vars(self.args))
|
||||
|
||||
# Device
|
||||
if self.device.type in ("cpu", "mps"):
|
||||
if self.device.type in {"cpu", "mps"}:
|
||||
self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading
|
||||
|
||||
# Model and Dataset
|
||||
|
|
@ -144,7 +144,7 @@ class BaseTrainer:
|
|||
|
||||
# Callbacks
|
||||
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
||||
if RANK in (-1, 0):
|
||||
if RANK in {-1, 0}:
|
||||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
def add_callback(self, event: str, callback):
|
||||
|
|
@ -251,7 +251,7 @@ class BaseTrainer:
|
|||
|
||||
# Check AMP
|
||||
self.amp = torch.tensor(self.args.amp).to(self.device) # True or False
|
||||
if self.amp and RANK in (-1, 0): # Single-GPU and DDP
|
||||
if self.amp and RANK in {-1, 0}: # Single-GPU and DDP
|
||||
callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them
|
||||
self.amp = torch.tensor(check_amp(self.model), device=self.device)
|
||||
callbacks.default_callbacks = callbacks_backup # restore callbacks
|
||||
|
|
@ -274,7 +274,7 @@ class BaseTrainer:
|
|||
# Dataloaders
|
||||
batch_size = self.batch_size // max(world_size, 1)
|
||||
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train")
|
||||
if RANK in (-1, 0):
|
||||
if RANK in {-1, 0}:
|
||||
# Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
|
||||
self.test_loader = self.get_dataloader(
|
||||
self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
|
||||
|
|
@ -340,7 +340,7 @@ class BaseTrainer:
|
|||
self._close_dataloader_mosaic()
|
||||
self.train_loader.reset()
|
||||
|
||||
if RANK in (-1, 0):
|
||||
if RANK in {-1, 0}:
|
||||
LOGGER.info(self.progress_string())
|
||||
pbar = TQDM(enumerate(self.train_loader), total=nb)
|
||||
self.tloss = None
|
||||
|
|
@ -392,7 +392,7 @@ class BaseTrainer:
|
|||
mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G" # (GB)
|
||||
loss_len = self.tloss.shape[0] if len(self.tloss.shape) else 1
|
||||
losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
|
||||
if RANK in (-1, 0):
|
||||
if RANK in {-1, 0}:
|
||||
pbar.set_description(
|
||||
("%11s" * 2 + "%11.4g" * (2 + loss_len))
|
||||
% (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
|
||||
|
|
@ -405,7 +405,7 @@ class BaseTrainer:
|
|||
|
||||
self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
|
||||
self.run_callbacks("on_train_epoch_end")
|
||||
if RANK in (-1, 0):
|
||||
if RANK in {-1, 0}:
|
||||
final_epoch = epoch + 1 >= self.epochs
|
||||
self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
|
||||
|
||||
|
|
@ -447,7 +447,7 @@ class BaseTrainer:
|
|||
break # must break all DDP ranks
|
||||
epoch += 1
|
||||
|
||||
if RANK in (-1, 0):
|
||||
if RANK in {-1, 0}:
|
||||
# Do final val with best.pt
|
||||
LOGGER.info(
|
||||
f"\n{epoch - self.start_epoch + 1} epochs completed in "
|
||||
|
|
@ -503,12 +503,12 @@ class BaseTrainer:
|
|||
try:
|
||||
if self.args.task == "classify":
|
||||
data = check_cls_dataset(self.args.data)
|
||||
elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
|
||||
elif self.args.data.split(".")[-1] in {"yaml", "yml"} or self.args.task in {
|
||||
"detect",
|
||||
"segment",
|
||||
"pose",
|
||||
"obb",
|
||||
):
|
||||
}:
|
||||
data = check_det_dataset(self.args.data)
|
||||
if "yaml_file" in data:
|
||||
self.args.data = data["yaml_file"] # for validating 'yolo train data=url.zip' usage
|
||||
|
|
@ -740,7 +740,7 @@ class BaseTrainer:
|
|||
else: # weight (with decay)
|
||||
g[0].append(param)
|
||||
|
||||
if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"):
|
||||
if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}:
|
||||
optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
|
||||
elif name == "RMSProp":
|
||||
optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
|
||||
|
|
|
|||
|
|
@ -139,14 +139,14 @@ class BaseValidator:
|
|||
self.args.batch = 1 # export.py models default to batch-size 1
|
||||
LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models")
|
||||
|
||||
if str(self.args.data).split(".")[-1] in ("yaml", "yml"):
|
||||
if str(self.args.data).split(".")[-1] in {"yaml", "yml"}:
|
||||
self.data = check_det_dataset(self.args.data)
|
||||
elif self.args.task == "classify":
|
||||
self.data = check_cls_dataset(self.args.data, split=self.args.split)
|
||||
else:
|
||||
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))
|
||||
|
||||
if self.device.type in ("cpu", "mps"):
|
||||
if self.device.type in {"cpu", "mps"}:
|
||||
self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
|
||||
if not pt:
|
||||
self.args.rect = False
|
||||
|
|
|
|||
|
|
@ -198,7 +198,7 @@ class Events:
|
|||
}
|
||||
self.enabled = (
|
||||
SETTINGS["sync"]
|
||||
and RANK in (-1, 0)
|
||||
and RANK in {-1, 0}
|
||||
and not TESTS_RUNNING
|
||||
and ONLINE
|
||||
and (is_pip_package() or get_git_origin_url() == "https://github.com/ultralytics/ultralytics.git")
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ class FastSAM(Model):
|
|||
"""Call the __init__ method of the parent class (YOLO) with the updated default model."""
|
||||
if str(model) == "FastSAM.pt":
|
||||
model = "FastSAM-x.pt"
|
||||
assert Path(model).suffix not in (".yaml", ".yml"), "FastSAM models only support pre-trained models."
|
||||
assert Path(model).suffix not in {".yaml", ".yml"}, "FastSAM models only support pre-trained models."
|
||||
super().__init__(model=model, task="segment")
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ class NAS(Model):
|
|||
|
||||
def __init__(self, model="yolo_nas_s.pt") -> None:
|
||||
"""Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model."""
|
||||
assert Path(model).suffix not in (".yaml", ".yml"), "YOLO-NAS models only support pre-trained models."
|
||||
assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models."
|
||||
super().__init__(model, task="detect")
|
||||
|
||||
@smart_inference_mode()
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ class SAM(Model):
|
|||
Raises:
|
||||
NotImplementedError: If the model file extension is not .pt or .pth.
|
||||
"""
|
||||
if model and Path(model).suffix not in (".pt", ".pth"):
|
||||
if model and Path(model).suffix not in {".pt", ".pth"}:
|
||||
raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
|
||||
super().__init__(model=model, task="segment")
|
||||
|
||||
|
|
|
|||
|
|
@ -112,7 +112,7 @@ class PatchMerging(nn.Module):
|
|||
self.out_dim = out_dim
|
||||
self.act = activation()
|
||||
self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
|
||||
stride_c = 1 if out_dim in [320, 448, 576] else 2
|
||||
stride_c = 1 if out_dim in {320, 448, 576} else 2
|
||||
self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
|
||||
self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
|
||||
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ class ClassificationTrainer(BaseTrainer):
|
|||
self.model, ckpt = attempt_load_one_weight(model, device="cpu")
|
||||
for p in self.model.parameters():
|
||||
p.requires_grad = True # for training
|
||||
elif model.split(".")[-1] in ("yaml", "yml"):
|
||||
elif model.split(".")[-1] in {"yaml", "yml"}:
|
||||
self.model = self.get_model(cfg=model)
|
||||
elif model in torchvision.models.__dict__:
|
||||
self.model = torchvision.models.__dict__[model](weights="IMAGENET1K_V1" if self.args.pretrained else None)
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ class DetectionTrainer(BaseTrainer):
|
|||
|
||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
|
||||
"""Construct and return dataloader."""
|
||||
assert mode in ["train", "val"]
|
||||
assert mode in {"train", "val"}, f"Mode must be 'train' or 'val', not {mode}."
|
||||
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
||||
dataset = self.build_dataset(dataset_path, mode, batch_size)
|
||||
shuffle = mode == "train"
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from ultralytics.utils.torch_utils import de_parallel
|
|||
|
||||
def on_pretrain_routine_end(trainer):
|
||||
"""Callback."""
|
||||
if RANK in (-1, 0):
|
||||
if RANK in {-1, 0}:
|
||||
# NOTE: for evaluation
|
||||
names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
|
||||
de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)
|
||||
|
|
|
|||
|
|
@ -374,9 +374,9 @@ class AutoBackend(nn.Module):
|
|||
metadata = yaml_load(metadata)
|
||||
if metadata:
|
||||
for k, v in metadata.items():
|
||||
if k in ("stride", "batch"):
|
||||
if k in {"stride", "batch"}:
|
||||
metadata[k] = int(v)
|
||||
elif k in ("imgsz", "names", "kpt_shape") and isinstance(v, str):
|
||||
elif k in {"imgsz", "names", "kpt_shape"} and isinstance(v, str):
|
||||
metadata[k] = eval(v)
|
||||
stride = metadata["stride"]
|
||||
task = metadata["task"]
|
||||
|
|
@ -531,8 +531,8 @@ class AutoBackend(nn.Module):
|
|||
self.names = {i: f"class{i}" for i in range(nc)}
|
||||
else: # Lite or Edge TPU
|
||||
details = self.input_details[0]
|
||||
integer = details["dtype"] in (np.int8, np.int16) # is TFLite quantized int8 or int16 model
|
||||
if integer:
|
||||
is_int = details["dtype"] in {np.int8, np.int16} # is TFLite quantized int8 or int16 model
|
||||
if is_int:
|
||||
scale, zero_point = details["quantization"]
|
||||
im = (im / scale + zero_point).astype(details["dtype"]) # de-scale
|
||||
self.interpreter.set_tensor(details["index"], im)
|
||||
|
|
@ -540,7 +540,7 @@ class AutoBackend(nn.Module):
|
|||
y = []
|
||||
for output in self.output_details:
|
||||
x = self.interpreter.get_tensor(output["index"])
|
||||
if integer:
|
||||
if is_int:
|
||||
scale, zero_point = output["quantization"]
|
||||
x = (x.astype(np.float32) - zero_point) * scale # re-scale
|
||||
if x.ndim == 3: # if task is not classification, excluding masks (ndim=4) as well
|
||||
|
|
|
|||
|
|
@ -296,7 +296,7 @@ class SpatialAttention(nn.Module):
|
|||
def __init__(self, kernel_size=7):
|
||||
"""Initialize Spatial-attention module with kernel size argument."""
|
||||
super().__init__()
|
||||
assert kernel_size in (3, 7), "kernel size must be 3 or 7"
|
||||
assert kernel_size in {3, 7}, "kernel size must be 3 or 7"
|
||||
padding = 3 if kernel_size == 7 else 1
|
||||
self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
|
||||
self.act = nn.Sigmoid()
|
||||
|
|
|
|||
|
|
@ -54,13 +54,13 @@ class Detect(nn.Module):
|
|||
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
|
||||
self.shape = shape
|
||||
|
||||
if self.export and self.format in ("saved_model", "pb", "tflite", "edgetpu", "tfjs"): # avoid TF FlexSplitV ops
|
||||
if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}: # avoid TF FlexSplitV ops
|
||||
box = x_cat[:, : self.reg_max * 4]
|
||||
cls = x_cat[:, self.reg_max * 4 :]
|
||||
else:
|
||||
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
|
||||
|
||||
if self.export and self.format in ("tflite", "edgetpu"):
|
||||
if self.export and self.format in {"tflite", "edgetpu"}:
|
||||
# Precompute normalization factor to increase numerical stability
|
||||
# See https://github.com/ultralytics/ultralytics/issues/7371
|
||||
grid_h = shape[2]
|
||||
|
|
@ -230,13 +230,13 @@ class WorldDetect(Detect):
|
|||
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
|
||||
self.shape = shape
|
||||
|
||||
if self.export and self.format in ("saved_model", "pb", "tflite", "edgetpu", "tfjs"): # avoid TF FlexSplitV ops
|
||||
if self.export and self.format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}: # avoid TF FlexSplitV ops
|
||||
box = x_cat[:, : self.reg_max * 4]
|
||||
cls = x_cat[:, self.reg_max * 4 :]
|
||||
else:
|
||||
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
|
||||
|
||||
if self.export and self.format in ("tflite", "edgetpu"):
|
||||
if self.export and self.format in {"tflite", "edgetpu"}:
|
||||
# Precompute normalization factor to increase numerical stability
|
||||
# See https://github.com/ultralytics/ultralytics/issues/7371
|
||||
grid_h = shape[2]
|
||||
|
|
|
|||
|
|
@ -896,7 +896,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
) # num heads
|
||||
|
||||
args = [c1, c2, *args[1:]]
|
||||
if m in (BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3):
|
||||
if m in {BottleneckCSP, C1, C2, C2f, C2fAttn, C3, C3TR, C3Ghost, C3x, RepC3}:
|
||||
args.insert(2, n) # number of repeats
|
||||
n = 1
|
||||
elif m is AIFI:
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ class AIGym:
|
|||
self.annotator = Annotator(im0, line_width=2)
|
||||
|
||||
for ind, k in enumerate(reversed(self.keypoints)):
|
||||
if self.pose_type in ["pushup", "pullup"]:
|
||||
if self.pose_type in {"pushup", "pullup"}:
|
||||
self.angle[ind] = self.annotator.estimate_pose_angle(
|
||||
k[int(self.kpts_to_check[0])].cpu(),
|
||||
k[int(self.kpts_to_check[1])].cpu(),
|
||||
|
|
|
|||
|
|
@ -153,7 +153,7 @@ class Heatmap:
|
|||
self.cls_txtdisplay_gap = cls_txtdisplay_gap
|
||||
|
||||
# shape of heatmap, if not selected
|
||||
if self.shape not in ["circle", "rect"]:
|
||||
if self.shape not in {"circle", "rect"}:
|
||||
print("Unknown shape value provided, 'circle' & 'rect' supported")
|
||||
print("Using Circular shape now")
|
||||
self.shape = "circle"
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ class STrack(BaseTrack):
|
|||
"""Initialize new STrack instance."""
|
||||
super().__init__()
|
||||
# xywh+idx or xywha+idx
|
||||
assert len(xywh) in [5, 6], f"expected 5 or 6 values but got {len(xywh)}"
|
||||
assert len(xywh) in {5, 6}, f"expected 5 or 6 values but got {len(xywh)}"
|
||||
self._tlwh = np.asarray(xywh2ltwh(xywh[:4]), dtype=np.float32)
|
||||
self.kalman_filter = None
|
||||
self.mean, self.covariance = None, None
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ def on_predict_start(predictor: object, persist: bool = False) -> None:
|
|||
tracker = check_yaml(predictor.args.tracker)
|
||||
cfg = IterableSimpleNamespace(**yaml_load(tracker))
|
||||
|
||||
if cfg.tracker_type not in ["bytetrack", "botsort"]:
|
||||
if cfg.tracker_type not in {"bytetrack", "botsort"}:
|
||||
raise AssertionError(f"Only 'bytetrack' and 'botsort' are supported for now, but got '{cfg.tracker_type}'")
|
||||
|
||||
trackers = []
|
||||
|
|
|
|||
|
|
@ -94,7 +94,7 @@ class GMC:
|
|||
array([[1, 2, 3],
|
||||
[4, 5, 6]])
|
||||
"""
|
||||
if self.method in ["orb", "sift"]:
|
||||
if self.method in {"orb", "sift"}:
|
||||
return self.applyFeatures(raw_frame, detections)
|
||||
elif self.method == "ecc":
|
||||
return self.applyEcc(raw_frame)
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ VERBOSE = str(os.getenv("YOLO_VERBOSE", True)).lower() == "true" # global verbo
|
|||
TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}" if VERBOSE else None # tqdm bar format
|
||||
LOGGING_NAME = "ultralytics"
|
||||
MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"]) # environment booleans
|
||||
ARM64 = platform.machine() in ("arm64", "aarch64") # ARM64 booleans
|
||||
ARM64 = platform.machine() in {"arm64", "aarch64"} # ARM64 booleans
|
||||
HELP_MSG = """
|
||||
Usage examples for running YOLOv8:
|
||||
|
||||
|
|
@ -359,7 +359,7 @@ def yaml_load(file="data.yaml", append_filename=False):
|
|||
Returns:
|
||||
(dict): YAML data and file name.
|
||||
"""
|
||||
assert Path(file).suffix in (".yaml", ".yml"), f"Attempting to load non-YAML file {file} with yaml_load()"
|
||||
assert Path(file).suffix in {".yaml", ".yml"}, f"Attempting to load non-YAML file {file} with yaml_load()"
|
||||
with open(file, errors="ignore", encoding="utf-8") as f:
|
||||
s = f.read() # string
|
||||
|
||||
|
|
@ -866,7 +866,7 @@ def set_sentry():
|
|||
"""
|
||||
if "exc_info" in hint:
|
||||
exc_type, exc_value, tb = hint["exc_info"]
|
||||
if exc_type in (KeyboardInterrupt, FileNotFoundError) or "out of memory" in str(exc_value):
|
||||
if exc_type in {KeyboardInterrupt, FileNotFoundError} or "out of memory" in str(exc_value):
|
||||
return None # do not send event
|
||||
|
||||
event["tags"] = {
|
||||
|
|
@ -879,7 +879,7 @@ def set_sentry():
|
|||
|
||||
if (
|
||||
SETTINGS["sync"]
|
||||
and RANK in (-1, 0)
|
||||
and RANK in {-1, 0}
|
||||
and Path(ARGV[0]).name == "yolo"
|
||||
and not TESTS_RUNNING
|
||||
and ONLINE
|
||||
|
|
|
|||
|
|
@ -115,7 +115,7 @@ def benchmark(
|
|||
|
||||
# Predict
|
||||
assert model.task != "pose" or i != 7, "GraphDef Pose inference is not supported"
|
||||
assert i not in (9, 10), "inference not supported" # Edge TPU and TF.js are unsupported
|
||||
assert i not in {9, 10}, "inference not supported" # Edge TPU and TF.js are unsupported
|
||||
assert i != 5 or platform.system() == "Darwin", "inference only supported on macOS>=10.13" # CoreML
|
||||
exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half)
|
||||
|
||||
|
|
@ -220,7 +220,7 @@ class ProfileModels:
|
|||
output = []
|
||||
for file in files:
|
||||
engine_file = file.with_suffix(".engine")
|
||||
if file.suffix in (".pt", ".yaml", ".yml"):
|
||||
if file.suffix in {".pt", ".yaml", ".yml"}:
|
||||
model = YOLO(str(file))
|
||||
model.fuse() # to report correct params and GFLOPs in model.info()
|
||||
model_info = model.info()
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ def _get_experiment_type(mode, project_name):
|
|||
|
||||
def _create_experiment(args):
|
||||
"""Ensures that the experiment object is only created in a single process during distributed training."""
|
||||
if RANK not in (-1, 0):
|
||||
if RANK not in {-1, 0}:
|
||||
return
|
||||
try:
|
||||
comet_mode = _get_comet_mode()
|
||||
|
|
|
|||
|
|
@ -108,7 +108,7 @@ def on_train_end(trainer):
|
|||
for f in trainer.save_dir.glob("*"): # log all other files in save_dir
|
||||
if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
|
||||
mlflow.log_artifact(str(f))
|
||||
keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "False").lower() in ("true")
|
||||
keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "False").lower() == "true"
|
||||
if keep_run_active:
|
||||
LOGGER.info(f"{PREFIX}mlflow run still alive, remember to close it using mlflow.end_run()")
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -237,7 +237,7 @@ def check_version(
|
|||
result = False
|
||||
elif op == "!=" and c == v:
|
||||
result = False
|
||||
elif op in (">=", "") and not (c >= v): # if no constraint passed assume '>=required'
|
||||
elif op in {">=", ""} and not (c >= v): # if no constraint passed assume '>=required'
|
||||
result = False
|
||||
elif op == "<=" and not (c <= v):
|
||||
result = False
|
||||
|
|
@ -632,7 +632,7 @@ def check_amp(model):
|
|||
(bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
|
||||
"""
|
||||
device = next(model.parameters()).device # get model device
|
||||
if device.type in ("cpu", "mps"):
|
||||
if device.type in {"cpu", "mps"}:
|
||||
return False # AMP only used on CUDA devices
|
||||
|
||||
def amp_allclose(m, im):
|
||||
|
|
|
|||
|
|
@ -356,13 +356,13 @@ def safe_download(
|
|||
raise ConnectionError(emojis(f"❌ Download failure for {url}. Retry limit reached.")) from e
|
||||
LOGGER.warning(f"⚠️ Download failure, retrying {i + 1}/{retry} {url}...")
|
||||
|
||||
if unzip and f.exists() and f.suffix in ("", ".zip", ".tar", ".gz"):
|
||||
if unzip and f.exists() and f.suffix in {"", ".zip", ".tar", ".gz"}:
|
||||
from zipfile import is_zipfile
|
||||
|
||||
unzip_dir = (dir or f.parent).resolve() # unzip to dir if provided else unzip in place
|
||||
if is_zipfile(f):
|
||||
unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress) # unzip
|
||||
elif f.suffix in (".tar", ".gz"):
|
||||
elif f.suffix in {".tar", ".gz"}:
|
||||
LOGGER.info(f"Unzipping {f} to {unzip_dir}...")
|
||||
subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True)
|
||||
if delete:
|
||||
|
|
|
|||
|
|
@ -298,7 +298,7 @@ class ConfusionMatrix:
|
|||
self.task = task
|
||||
self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc))
|
||||
self.nc = nc # number of classes
|
||||
self.conf = 0.25 if conf in (None, 0.001) else conf # apply 0.25 if default val conf is passed
|
||||
self.conf = 0.25 if conf in {None, 0.001} else conf # apply 0.25 if default val conf is passed
|
||||
self.iou_thres = iou_thres
|
||||
|
||||
def process_cls_preds(self, preds, targets):
|
||||
|
|
|
|||
|
|
@ -904,7 +904,7 @@ def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False,
|
|||
ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8) # actual results
|
||||
ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2) # smoothing line
|
||||
ax[i].set_title(s[j], fontsize=12)
|
||||
# if j in [8, 9, 10]: # share train and val loss y axes
|
||||
# if j in {8, 9, 10}: # share train and val loss y axes
|
||||
# ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
|
||||
except Exception as e:
|
||||
LOGGER.warning(f"WARNING: Plotting error for {f}: {e}")
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ TORCHVISION_0_13 = check_version(torchvision.__version__, "0.13.0")
|
|||
def torch_distributed_zero_first(local_rank: int):
|
||||
"""Decorator to make all processes in distributed training wait for each local_master to do something."""
|
||||
initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
|
||||
if initialized and local_rank not in (-1, 0):
|
||||
if initialized and local_rank not in {-1, 0}:
|
||||
dist.barrier(device_ids=[local_rank])
|
||||
yield
|
||||
if initialized and local_rank == 0:
|
||||
|
|
@ -109,7 +109,7 @@ def select_device(device="", batch=0, newline=False, verbose=True):
|
|||
for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ":
|
||||
device = device.replace(remove, "") # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
|
||||
cpu = device == "cpu"
|
||||
mps = device in ("mps", "mps:0") # Apple Metal Performance Shaders (MPS)
|
||||
mps = device in {"mps", "mps:0"} # Apple Metal Performance Shaders (MPS)
|
||||
if cpu or mps:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # force torch.cuda.is_available() = False
|
||||
elif device: # non-cpu device requested
|
||||
|
|
@ -347,7 +347,7 @@ def initialize_weights(model):
|
|||
elif t is nn.BatchNorm2d:
|
||||
m.eps = 1e-3
|
||||
m.momentum = 0.03
|
||||
elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
|
||||
elif t in {nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU}:
|
||||
m.inplace = True
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue