ultralytics 8.1.40 search in Python sets {} for speed (#9450)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
30484d5925
commit
ea527507fe
41 changed files with 97 additions and 93 deletions
|
|
@ -107,7 +107,7 @@ class BaseTrainer:
|
|||
self.save_dir = get_save_dir(self.args)
|
||||
self.args.name = self.save_dir.name # update name for loggers
|
||||
self.wdir = self.save_dir / "weights" # weights dir
|
||||
if RANK in (-1, 0):
|
||||
if RANK in {-1, 0}:
|
||||
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
|
||||
self.args.save_dir = str(self.save_dir)
|
||||
yaml_save(self.save_dir / "args.yaml", vars(self.args)) # save run args
|
||||
|
|
@ -121,7 +121,7 @@ class BaseTrainer:
|
|||
print_args(vars(self.args))
|
||||
|
||||
# Device
|
||||
if self.device.type in ("cpu", "mps"):
|
||||
if self.device.type in {"cpu", "mps"}:
|
||||
self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading
|
||||
|
||||
# Model and Dataset
|
||||
|
|
@ -144,7 +144,7 @@ class BaseTrainer:
|
|||
|
||||
# Callbacks
|
||||
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
||||
if RANK in (-1, 0):
|
||||
if RANK in {-1, 0}:
|
||||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
def add_callback(self, event: str, callback):
|
||||
|
|
@ -251,7 +251,7 @@ class BaseTrainer:
|
|||
|
||||
# Check AMP
|
||||
self.amp = torch.tensor(self.args.amp).to(self.device) # True or False
|
||||
if self.amp and RANK in (-1, 0): # Single-GPU and DDP
|
||||
if self.amp and RANK in {-1, 0}: # Single-GPU and DDP
|
||||
callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them
|
||||
self.amp = torch.tensor(check_amp(self.model), device=self.device)
|
||||
callbacks.default_callbacks = callbacks_backup # restore callbacks
|
||||
|
|
@ -274,7 +274,7 @@ class BaseTrainer:
|
|||
# Dataloaders
|
||||
batch_size = self.batch_size // max(world_size, 1)
|
||||
self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=RANK, mode="train")
|
||||
if RANK in (-1, 0):
|
||||
if RANK in {-1, 0}:
|
||||
# Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
|
||||
self.test_loader = self.get_dataloader(
|
||||
self.testset, batch_size=batch_size if self.args.task == "obb" else batch_size * 2, rank=-1, mode="val"
|
||||
|
|
@ -340,7 +340,7 @@ class BaseTrainer:
|
|||
self._close_dataloader_mosaic()
|
||||
self.train_loader.reset()
|
||||
|
||||
if RANK in (-1, 0):
|
||||
if RANK in {-1, 0}:
|
||||
LOGGER.info(self.progress_string())
|
||||
pbar = TQDM(enumerate(self.train_loader), total=nb)
|
||||
self.tloss = None
|
||||
|
|
@ -392,7 +392,7 @@ class BaseTrainer:
|
|||
mem = f"{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G" # (GB)
|
||||
loss_len = self.tloss.shape[0] if len(self.tloss.shape) else 1
|
||||
losses = self.tloss if loss_len > 1 else torch.unsqueeze(self.tloss, 0)
|
||||
if RANK in (-1, 0):
|
||||
if RANK in {-1, 0}:
|
||||
pbar.set_description(
|
||||
("%11s" * 2 + "%11.4g" * (2 + loss_len))
|
||||
% (f"{epoch + 1}/{self.epochs}", mem, *losses, batch["cls"].shape[0], batch["img"].shape[-1])
|
||||
|
|
@ -405,7 +405,7 @@ class BaseTrainer:
|
|||
|
||||
self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
|
||||
self.run_callbacks("on_train_epoch_end")
|
||||
if RANK in (-1, 0):
|
||||
if RANK in {-1, 0}:
|
||||
final_epoch = epoch + 1 >= self.epochs
|
||||
self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
|
||||
|
||||
|
|
@ -447,7 +447,7 @@ class BaseTrainer:
|
|||
break # must break all DDP ranks
|
||||
epoch += 1
|
||||
|
||||
if RANK in (-1, 0):
|
||||
if RANK in {-1, 0}:
|
||||
# Do final val with best.pt
|
||||
LOGGER.info(
|
||||
f"\n{epoch - self.start_epoch + 1} epochs completed in "
|
||||
|
|
@ -503,12 +503,12 @@ class BaseTrainer:
|
|||
try:
|
||||
if self.args.task == "classify":
|
||||
data = check_cls_dataset(self.args.data)
|
||||
elif self.args.data.split(".")[-1] in ("yaml", "yml") or self.args.task in (
|
||||
elif self.args.data.split(".")[-1] in {"yaml", "yml"} or self.args.task in {
|
||||
"detect",
|
||||
"segment",
|
||||
"pose",
|
||||
"obb",
|
||||
):
|
||||
}:
|
||||
data = check_det_dataset(self.args.data)
|
||||
if "yaml_file" in data:
|
||||
self.args.data = data["yaml_file"] # for validating 'yolo train data=url.zip' usage
|
||||
|
|
@ -740,7 +740,7 @@ class BaseTrainer:
|
|||
else: # weight (with decay)
|
||||
g[0].append(param)
|
||||
|
||||
if name in ("Adam", "Adamax", "AdamW", "NAdam", "RAdam"):
|
||||
if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}:
|
||||
optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
|
||||
elif name == "RMSProp":
|
||||
optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue