ultralytics 8.0.32 HUB and TensorFlow fixes (#870)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
f5d003d05a
commit
c9893810c7
14 changed files with 118 additions and 85 deletions
|
|
@ -61,7 +61,7 @@ def seed_worker(worker_id):
|
|||
random.seed(worker_seed)
|
||||
|
||||
|
||||
def build_dataloader(cfg, batch_size, img_path, stride=32, rect=False, label_path=None, rank=-1, mode="train"):
|
||||
def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, rank=-1, mode="train"):
|
||||
assert mode in ["train", "val"]
|
||||
shuffle = mode == "train"
|
||||
if cfg.rect and shuffle:
|
||||
|
|
@ -70,9 +70,8 @@ def build_dataloader(cfg, batch_size, img_path, stride=32, rect=False, label_pat
|
|||
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
||||
dataset = YOLODataset(
|
||||
img_path=img_path,
|
||||
label_path=label_path,
|
||||
imgsz=cfg.imgsz,
|
||||
batch_size=batch_size,
|
||||
batch_size=batch,
|
||||
augment=mode == "train", # augmentation
|
||||
hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
|
||||
rect=cfg.rect or rect, # rectangular batches
|
||||
|
|
@ -82,18 +81,19 @@ def build_dataloader(cfg, batch_size, img_path, stride=32, rect=False, label_pat
|
|||
pad=0.0 if mode == "train" else 0.5,
|
||||
prefix=colorstr(f"{mode}: "),
|
||||
use_segments=cfg.task == "segment",
|
||||
use_keypoints=cfg.task == "keypoint")
|
||||
use_keypoints=cfg.task == "keypoint",
|
||||
names=names)
|
||||
|
||||
batch_size = min(batch_size, len(dataset))
|
||||
batch = min(batch, len(dataset))
|
||||
nd = torch.cuda.device_count() # number of CUDA devices
|
||||
workers = cfg.workers if mode == "train" else cfg.workers * 2
|
||||
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
|
||||
nw = min([os.cpu_count() // max(nd, 1), batch if batch > 1 else 0, workers]) # number of workers
|
||||
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
|
||||
loader = DataLoader if cfg.image_weights or cfg.close_mosaic else InfiniteDataLoader # allow attribute updates
|
||||
generator = torch.Generator()
|
||||
generator.manual_seed(6148914691236517205 + RANK)
|
||||
return loader(dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
batch_size=batch,
|
||||
shuffle=shuffle and sampler is None,
|
||||
num_workers=nw,
|
||||
sampler=sampler,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue