Update .pre-commit-config.yaml (#1026)

This commit is contained in:
Glenn Jocher 2023-02-17 22:26:40 +01:00 committed by GitHub
parent 9047d737f4
commit edd3ff1669
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
76 changed files with 928 additions and 935 deletions

View file

@ -28,7 +28,7 @@ class InfiniteDataLoader(dataloader.DataLoader):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
self.iterator = super().__iter__()
def __len__(self):
@ -61,9 +61,9 @@ def seed_worker(worker_id):
random.seed(worker_seed)
def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, rank=-1, mode="train"):
assert mode in ["train", "val"]
shuffle = mode == "train"
def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, rank=-1, mode='train'):
assert mode in ['train', 'val']
shuffle = mode == 'train'
if cfg.rect and shuffle:
LOGGER.warning("WARNING ⚠️ 'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
shuffle = False
@ -72,21 +72,21 @@ def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, ra
img_path=img_path,
imgsz=cfg.imgsz,
batch_size=batch,
augment=mode == "train", # augmentation
augment=mode == 'train', # augmentation
hyp=cfg, # TODO: probably add a get_hyps_from_cfg function
rect=cfg.rect or rect, # rectangular batches
cache=cfg.cache or None,
single_cls=cfg.single_cls or False,
stride=int(stride),
pad=0.0 if mode == "train" else 0.5,
prefix=colorstr(f"{mode}: "),
use_segments=cfg.task == "segment",
use_keypoints=cfg.task == "keypoint",
pad=0.0 if mode == 'train' else 0.5,
prefix=colorstr(f'{mode}: '),
use_segments=cfg.task == 'segment',
use_keypoints=cfg.task == 'keypoint',
names=names)
batch = min(batch, len(dataset))
nd = torch.cuda.device_count() # number of CUDA devices
workers = cfg.workers if mode == "train" else cfg.workers * 2
workers = cfg.workers if mode == 'train' else cfg.workers * 2
nw = min([os.cpu_count() // max(nd, 1), batch if batch > 1 else 0, workers]) # number of workers
sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
loader = DataLoader if cfg.image_weights or cfg.close_mosaic else InfiniteDataLoader # allow attribute updates
@ -98,7 +98,7 @@ def build_dataloader(cfg, batch, img_path, stride=32, rect=False, names=None, ra
num_workers=nw,
sampler=sampler,
pin_memory=PIN_MEMORY,
collate_fn=getattr(dataset, "collate_fn", None),
collate_fn=getattr(dataset, 'collate_fn', None),
worker_init_fn=seed_worker,
generator=generator), dataset
@ -151,7 +151,7 @@ def check_source(source):
from_img = True
else:
raise Exception(
"Unsupported type encountered! See docs for supported types https://docs.ultralytics.com/predict")
'Unsupported type encountered! See docs for supported types https://docs.ultralytics.com/predict')
return source, webcam, screenshot, from_img, in_memory