YOLOv5 updates (#90)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
ebd3cfb2fd
commit
98815d560f
27 changed files with 281 additions and 161 deletions
|
|
@ -64,7 +64,7 @@ def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank
|
|||
label_path=label_path,
|
||||
imgsz=cfg.imgsz,
|
||||
batch_size=batch_size,
|
||||
augment=True if mode == "train" else False, # augmentation
|
||||
augment=mode == "train", # augmentation
|
||||
hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
|
||||
rect=cfg.rect if mode == "train" else True, # rectangular batches
|
||||
cache=None if cfg.noval else cfg.get("cache", None),
|
||||
|
|
@ -73,31 +73,25 @@ def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank
|
|||
pad=0.0 if mode == "train" else 0.5,
|
||||
prefix=colorstr(f"{mode}: "),
|
||||
use_segments=cfg.task == "segment",
|
||||
use_keypoints=cfg.task == "keypoint",
|
||||
)
|
||||
use_keypoints=cfg.task == "keypoint")
|
||||
|
||||
batch_size = min(batch_size, len(dataset))
|
||||
nd = torch.cuda.device_count() # number of CUDA devices
|
||||
workers = cfg.workers if mode == "train" else cfg.workers * 2
|
||||
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
|
||||
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
||||
loader = DataLoader if cfg.image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
|
||||
loader = DataLoader if cfg.image_weights or cfg.close_mosaic else InfiniteDataLoader # allow attribute updates
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(6148914691236517205 + RANK)
|
||||
return (
|
||||
loader(
|
||||
dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle and sampler is None,
|
||||
num_workers=nw,
|
||||
sampler=sampler,
|
||||
pin_memory=PIN_MEMORY,
|
||||
collate_fn=getattr(dataset, "collate_fn", None),
|
||||
worker_init_fn=seed_worker,
|
||||
generator=generator,
|
||||
),
|
||||
dataset,
|
||||
)
|
||||
return loader(dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle and sampler is None,
|
||||
num_workers=nw,
|
||||
sampler=sampler,
|
||||
pin_memory=PIN_MEMORY,
|
||||
collate_fn=getattr(dataset, "collate_fn", None),
|
||||
worker_init_fn=seed_worker,
|
||||
generator=generator), dataset
|
||||
|
||||
|
||||
# build classification
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue