Update .pre-commit-config.yaml (#1026)

This commit is contained in:
Glenn Jocher 2023-02-17 22:26:40 +01:00 committed by GitHub
parent 9047d737f4
commit edd3ff1669
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
76 changed files with 928 additions and 935 deletions

View file

@ -62,7 +62,7 @@ class BaseValidator:
self.jdict = None
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
name = self.args.name or f"{self.args.mode}"
name = self.args.name or f'{self.args.mode}'
self.save_dir = save_dir or increment_path(Path(project) / name,
exist_ok=self.args.exist_ok if RANK in {-1, 0} else True)
(self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
@ -92,7 +92,7 @@ class BaseValidator:
else:
callbacks.add_integration_callbacks(self)
self.run_callbacks('on_val_start')
assert model is not None, "Either trainer or model is needed for validation"
assert model is not None, 'Either trainer or model is needed for validation'
self.device = select_device(self.args.device, self.args.batch)
self.args.half &= self.device.type != 'cpu'
model = AutoBackend(model, device=self.device, dnn=self.args.dnn, data=self.args.data, fp16=self.args.half)
@ -108,7 +108,7 @@ class BaseValidator:
self.logger.info(
f'Forcing --batch-size 1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
if isinstance(self.args.data, str) and self.args.data.endswith(".yaml"):
if isinstance(self.args.data, str) and self.args.data.endswith('.yaml'):
self.data = check_det_dataset(self.args.data)
elif self.args.task == 'classify':
self.data = check_cls_dataset(self.args.data)
@ -142,7 +142,7 @@ class BaseValidator:
# inference
with dt[1]:
preds = model(batch["img"])
preds = model(batch['img'])
# loss
with dt[2]:
@ -166,14 +166,14 @@ class BaseValidator:
self.run_callbacks('on_val_end')
if self.training:
model.float()
results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')}
return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats
else:
self.logger.info('Speed: %.1fms pre-process, %.1fms inference, %.1fms loss, %.1fms post-process per image' %
self.speed)
if self.args.save_json and self.jdict:
with open(str(self.save_dir / "predictions.json"), 'w') as f:
self.logger.info(f"Saving {f.name}...")
with open(str(self.save_dir / 'predictions.json'), 'w') as f:
self.logger.info(f'Saving {f.name}...')
json.dump(self.jdict, f) # flatten and save
stats = self.eval_json(stats) # update stats
return stats
@ -183,7 +183,7 @@ class BaseValidator:
callback(self)
def get_dataloader(self, dataset_path, batch_size):
raise NotImplementedError("get_dataloader function not implemented for this validator")
raise NotImplementedError('get_dataloader function not implemented for this validator')
def preprocess(self, batch):
return batch