ultralytics 8.1.40 search in Python sets {} for speed (#9450)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
30484d5925
commit
ea527507fe
41 changed files with 97 additions and 93 deletions
|
|
@ -159,7 +159,7 @@ class Exporter:
|
|||
_callbacks (dict, optional): Dictionary of callback functions. Defaults to None.
|
||||
"""
|
||||
self.args = get_cfg(cfg, overrides)
|
||||
if self.args.format.lower() in ("coreml", "mlmodel"): # fix attempt for protobuf<3.20.x errors
|
||||
if self.args.format.lower() in {"coreml", "mlmodel"}: # fix attempt for protobuf<3.20.x errors
|
||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" # must run before TensorBoard callback
|
||||
|
||||
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
||||
|
|
@ -171,9 +171,9 @@ class Exporter:
|
|||
self.run_callbacks("on_export_start")
|
||||
t = time.time()
|
||||
fmt = self.args.format.lower() # to lowercase
|
||||
if fmt in ("tensorrt", "trt"): # 'engine' aliases
|
||||
if fmt in {"tensorrt", "trt"}: # 'engine' aliases
|
||||
fmt = "engine"
|
||||
if fmt in ("mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"): # 'coreml' aliases
|
||||
if fmt in {"mlmodel", "mlpackage", "mlprogram", "apple", "ios", "coreml"}: # 'coreml' aliases
|
||||
fmt = "coreml"
|
||||
fmts = tuple(export_formats()["Argument"][1:]) # available export formats
|
||||
flags = [x == fmt for x in fmts]
|
||||
|
|
|
|||
|
|
@ -145,7 +145,7 @@ class Model(nn.Module):
|
|||
return
|
||||
|
||||
# Load or create new YOLO model
|
||||
if Path(model).suffix in (".yaml", ".yml"):
|
||||
if Path(model).suffix in {".yaml", ".yml"}:
|
||||
self._new(model, task=task, verbose=verbose)
|
||||
else:
|
||||
self._load(model, task=task)
|
||||
|
|
@ -666,7 +666,7 @@ class Model(nn.Module):
|
|||
self.trainer.hub_session = self.session # attach optional HUB session
|
||||
self.trainer.train()
|
||||
# Update model and cfg after training
|
||||
if RANK in (-1, 0):
|
||||
if RANK in {-1, 0}:
|
||||
ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
|
||||
self.model, _ = attempt_load_one_weight(ckpt)
|
||||
self.overrides = self.model.args
|
||||
|
|
|
|||
|
|
@ -470,7 +470,7 @@ class Boxes(BaseTensor):
|
|||
if boxes.ndim == 1:
|
||||
boxes = boxes[None, :]
|
||||
n = boxes.shape[-1]
|
||||
assert n in (6, 7), f"expected 6 or 7 values but got {n}" # xyxy, track_id, conf, cls
|
||||
assert n in {6, 7}, f"expected 6 or 7 values but got {n}" # xyxy, track_id, conf, cls
|
||||
super().__init__(boxes, orig_shape)
|
||||
self.is_track = n == 7
|
||||
self.orig_shape = orig_shape
|
||||
|
|
@ -687,7 +687,7 @@ class OBB(BaseTensor):
|
|||
if boxes.ndim == 1:
|
||||
boxes = boxes[None, :]
|
||||
n = boxes.shape[-1]
|
||||
assert n in (7, 8), f"expected 7 or 8 values but got {n}" # xywh, rotation, track_id, conf, cls
|
||||
assert n in {7, 8}, f"expected 7 or 8 values but got {n}" # xywh, rotation, track_id, conf, cls
|
||||
super().__init__(boxes, orig_shape)
|
||||
self.is_track = n == 8
|
||||
self.orig_shape = orig_shape
|
||||
|
|
|
|||
|
|
@ -107,7 +107,7 @@ class BaseTrainer:
|
|||
self.save_dir = get_save_dir(self.args)
|
||||
self.args.name = self.save_dir.name # update name for loggers
|
||||
self.wdir = self.save_dir / "weights" # weights dir
|
||||
if RANK in (-1, 0):
|
||||
if RANK in {-1, 0}:
|
||||
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
|
||||
self.args.save_dir = str(self.save_dir)
|
||||
yaml_save(self.save_dir / "args.yaml", vars(self.args)) # save run args
|
||||
|
|
@ -121,7 +121,7 @@ class BaseTrainer:
|
|||
print_args(vars(self.args))
|
||||
|
||||
# Device
|
||||
if self.device.type in ("cpu", "mps"):
|
||||
if self.device.type in {"cpu", "mps"}:
|
||||
self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading
|
||||
|
||||
# Model and Dataset
|
||||
|
|
@ -144,7 +144,7 @@ class BaseTrainer:
|
|||
|
||||
# Callbacks
|
||||
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
||||
if RANK in (-1, 0):
|
||||
if RANK in {-1, 0}:
|
||||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
def add_callback(self, event: str, callback):
|
||||
|
|
@ -251,7 +251,7 @@ class BaseTrainer:
|
|||
|
||||
# Check AMP
|
||||
self.amp = torch.tensor(self.args.amp).to(self.device) # True or False
|
||||
if self.amp and RANK in (-1, 0): # Single-GPU and DDP
|
||||
if self.amp and RANK in {-1, 0}: # Single-GPU and DDP
|
||||
callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them
|
||||
self.amp = torch.tensor(check_amp(self.model), device=self.device)
|
||||
callbacks.default_callbacks = callbacks_backup # restore callbacks
|
||||
|
|
@ -274,7 +274,7 @@ class BaseTrainer:
|
|||
# Dataloaders
|
||||
batch_size = self.batch_size // max(world_size, 1)
|
||||
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train")
|
||||
if RANK in (-1, 0):
|
||||
if RANK in {-1, 0}:
|
||||
# Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
|
||||
self.test_loader = self.get_dataloader(
|
||||
self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
|
||||
|
|
@ -340,7 +340,7 @@ class BaseTrainer:
|
|||
self._close_dataloader_mosaic()
|
||||
self.train_loader.reset()
|
||||
|
||||
if RANK in (-1, 0):
|
||||
if RANK in {-1, 0}:
|
||||
LOGGER.info(self.progress_string())
|
||||
pbar = TQDM(enumerate(self.train_loader), total=nb)
|
||||
self.tloss = None
|
||||
|
|
@ -392,7 +392,7 @@ class BaseTrainer:
|
|||
mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G" # (GB)
|
||||
loss_len = self.tloss.shape[0] if len(self.tloss.shape) else 1
|
||||
losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
|
||||
if RANK in (-1, 0):
|
||||
if RANK in {-1, 0}:
|
||||
pbar.set_description(
|
||||
("%11s" * 2 + "%11.4g" * (2 + loss_len))
|
||||
% (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
|
||||
|
|
@ -405,7 +405,7 @@ class BaseTrainer:
|
|||
|
||||
self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
|
||||
self.run_callbacks("on_train_epoch_end")
|
||||
if RANK in (-1, 0):
|
||||
if RANK in {-1, 0}:
|
||||
final_epoch = epoch + 1 >= self.epochs
|
||||
self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
|
||||
|
||||
|
|
@ -447,7 +447,7 @@ class BaseTrainer:
|
|||
break # must break all DDP ranks
|
||||
epoch += 1
|
||||
|
||||
if RANK in (-1, 0):
|
||||
if RANK in {-1, 0}:
|
||||
# Do final val with best.pt
|
||||
LOGGER.info(
|
||||
f"\n{epoch - self.start_epoch + 1} epochs completed in "
|
||||
|
|
@ -503,12 +503,12 @@ class BaseTrainer:
|
|||
try:
|
||||
if self.args.task == "classify":
|
||||
data = check_cls_dataset(self.args.data)
|
||||
elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
|
||||
elif self.args.data.split(".")[-1] in {"yaml", "yml"} or self.args.task in {
|
||||
"detect",
|
||||
"segment",
|
||||
"pose",
|
||||
"obb",
|
||||
):
|
||||
}:
|
||||
data = check_det_dataset(self.args.data)
|
||||
if "yaml_file" in data:
|
||||
self.args.data = data["yaml_file"] # for validating 'yolo train data=url.zip' usage
|
||||
|
|
@ -740,7 +740,7 @@ class BaseTrainer:
|
|||
else: # weight (with decay)
|
||||
g[0].append(param)
|
||||
|
||||
if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"):
|
||||
if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}:
|
||||
optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
|
||||
elif name == "RMSProp":
|
||||
optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
|
||||
|
|
|
|||
|
|
@ -139,14 +139,14 @@ class BaseValidator:
|
|||
self.args.batch = 1 # export.py models default to batch-size 1
|
||||
LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models")
|
||||
|
||||
if str(self.args.data).split(".")[-1] in ("yaml", "yml"):
|
||||
if str(self.args.data).split(".")[-1] in {"yaml", "yml"}:
|
||||
self.data = check_det_dataset(self.args.data)
|
||||
elif self.args.task == "classify":
|
||||
self.data = check_cls_dataset(self.args.data, split=self.args.split)
|
||||
else:
|
||||
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))
|
||||
|
||||
if self.device.type in ("cpu", "mps"):
|
||||
if self.device.type in {"cpu", "mps"}:
|
||||
self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
|
||||
if not pt:
|
||||
self.args.rect = False
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue