YOLOv5 updates (#90)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2022-12-25 14:33:18 +01:00 committed by GitHub
parent ebd3cfb2fd
commit 98815d560f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
27 changed files with 281 additions and 161 deletions

View file

@ -64,7 +64,7 @@ def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank
label_path=label_path,
imgsz=cfg.imgsz,
batch_size=batch_size,
augment=True if mode == "train" else False, # augmentation
augment=mode == "train", # augmentation
hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
rect=cfg.rect if mode == "train" else True, # rectangular batches
cache=None if cfg.noval else cfg.get("cache", None),
@ -73,31 +73,25 @@ def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank
pad=0.0 if mode == "train" else 0.5,
prefix=colorstr(f"{mode}: "),
use_segments=cfg.task == "segment",
use_keypoints=cfg.task == "keypoint",
)
use_keypoints=cfg.task == "keypoint")
batch_size = min(batch_size, len(dataset))
nd = torch.cuda.device_count() # number of CUDA devices
workers = cfg.workers if mode == "train" else cfg.workers * 2
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
loader = DataLoader if cfg.image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
loader = DataLoader if cfg.image_weights or cfg.close_mosaic else InfiniteDataLoader # allow attribute updates
generator = torch.Generator()
generator.manual_seed(6148914691236517205 + RANK)
return (
loader(
dataset=dataset,
batch_size=batch_size,
shuffle=shuffle and sampler is None,
num_workers=nw,
sampler=sampler,
pin_memory=PIN_MEMORY,
collate_fn=getattr(dataset, "collate_fn", None),
worker_init_fn=seed_worker,
generator=generator,
),
dataset,
)
return loader(dataset=dataset,
batch_size=batch_size,
shuffle=shuffle and sampler is None,
num_workers=nw,
sampler=sampler,
pin_memory=PIN_MEMORY,
collate_fn=getattr(dataset, "collate_fn", None),
worker_init_fn=seed_worker,
generator=generator), dataset
# build classification