update segment training (#57)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
Laughing 2022-11-29 05:30:08 -06:00 committed by GitHub
parent d0b0fe2592
commit 3a241e4cea
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 460 additions and 144 deletions

View file

@ -5,7 +5,7 @@ import numpy as np
import torch
from torch.utils.data import DataLoader, dataloader, distributed
from ..utils import LOGGER
from ..utils import LOGGER, colorstr
from ..utils.torch_utils import torch_distributed_zero_first
from .dataset import ClassificationDataset, YOLODataset
from .utils import PIN_MEMORY, RANK
@ -52,53 +52,36 @@ def seed_worker(worker_id):
random.seed(worker_seed)
# TODO: we can inject most args from a config file
def build_dataloader(
img_path,
img_size, #
batch_size, #
single_cls=False, #
hyp=None, #
augment=False,
cache=False, #
image_weights=False, #
stride=32,
label_path=None,
pad=0.0,
rect=False,
rank=-1,
workers=8,
prefix="",
shuffle=False,
use_segments=False,
use_keypoints=False,
):
if rect and shuffle:
def build_dataloader(cfg, batch_size, img_path, stride=32, label_path=None, rank=-1, mode="train"):
assert mode in ["train", "val"]
shuffle = mode == "train"
if cfg.rect and shuffle:
LOGGER.warning("WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False")
shuffle = False
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = YOLODataset(
img_path=img_path,
img_size=img_size,
batch_size=batch_size,
label_path=label_path,
augment=augment, # augmentation
hyp=hyp,
rect=rect, # rectangular batches
cache=cache,
single_cls=single_cls,
img_size=cfg.img_size,
batch_size=batch_size,
augment=True if mode == "train" else False, # augmentation
hyp=cfg.get("augment_hyp", None),
rect=cfg.rect if mode == "train" else True, # rectangular batches
cache=None if cfg.noval else cfg.get("cache", None),
single_cls=cfg.get("single_cls", False),
stride=int(stride),
pad=pad,
prefix=prefix,
use_segments=use_segments,
use_keypoints=use_keypoints,
pad=0.0 if mode == "train" else 0.5,
prefix=colorstr(f"{mode}: "),
use_segments=cfg.task == "segment",
use_keypoints=cfg.task == "keypoint",
)
batch_size = min(batch_size, len(dataset))
nd = torch.cuda.device_count() # number of CUDA devices
workers = cfg.workers if mode == "train" else cfg.workers * 2
nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers]) # number of workers
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
loader = DataLoader if image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
loader = DataLoader if cfg.image_weights else InfiniteDataLoader # only DataLoader allows for attribute updates
generator = torch.Generator()
generator.manual_seed(6148914691236517205 + RANK)
return (
@ -118,6 +101,7 @@ def build_dataloader(
# build classification
# TODO: using cfg like `build_dataloader`
def build_classification_dataloader(path,
imgsz=224,
batch_size=16,