Update .pre-commit-config.yaml (#1026)

This commit is contained in:
Glenn Jocher 2023-02-17 22:26:40 +01:00 committed by GitHub
parent 9047d737f4
commit edd3ff1669
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
76 changed files with 928 additions and 935 deletions

View file

@ -16,13 +16,13 @@ from ultralytics.yolo.utils.torch_utils import smart_inference_mode
# Map head to model, trainer, validator, and predictor classes
MODEL_MAP = {
"classify": [
'classify': [
ClassificationModel, 'yolo.TYPE.classify.ClassificationTrainer', 'yolo.TYPE.classify.ClassificationValidator',
'yolo.TYPE.classify.ClassificationPredictor'],
"detect": [
'detect': [
DetectionModel, 'yolo.TYPE.detect.DetectionTrainer', 'yolo.TYPE.detect.DetectionValidator',
'yolo.TYPE.detect.DetectionPredictor'],
"segment": [
'segment': [
SegmentationModel, 'yolo.TYPE.segment.SegmentationTrainer', 'yolo.TYPE.segment.SegmentationValidator',
'yolo.TYPE.segment.SegmentationPredictor']}
@ -34,7 +34,7 @@ class YOLO:
A python interface which emulates a model-like behaviour by wrapping trainers.
"""
def __init__(self, model='yolov8n.pt', type="v8") -> None:
def __init__(self, model='yolov8n.pt', type='v8') -> None:
"""
Initializes the YOLO object.
@ -94,7 +94,7 @@ class YOLO:
suffix = Path(weights).suffix
if suffix == '.pt':
self.model, self.ckpt = attempt_load_one_weight(weights)
self.task = self.model.args["task"]
self.task = self.model.args['task']
self.overrides = self.model.args
self._reset_ckpt_args(self.overrides)
else:
@ -111,7 +111,7 @@ class YOLO:
"""
if not isinstance(self.model, nn.Module):
raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. "
f"PyTorch models can be used to train, val, predict and export, i.e. "
f'PyTorch models can be used to train, val, predict and export, i.e. '
f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
@ -155,11 +155,11 @@ class YOLO:
(List[ultralytics.yolo.engine.results.Results]): The prediction results.
"""
overrides = self.overrides.copy()
overrides["conf"] = 0.25
overrides['conf'] = 0.25
overrides.update(kwargs)
overrides["mode"] = kwargs.get("mode", "predict")
assert overrides["mode"] in ['track', 'predict']
overrides["save"] = kwargs.get("save", False) # not save files by default
overrides['mode'] = kwargs.get('mode', 'predict')
assert overrides['mode'] in ['track', 'predict']
overrides['save'] = kwargs.get('save', False) # not save files by default
if not self.predictor:
self.predictor = self.PredictorClass(overrides=overrides)
self.predictor.setup_model(model=self.model)
@ -173,7 +173,7 @@ class YOLO:
from ultralytics.tracker.track import register_tracker
register_tracker(self)
# bytetrack-based method needs low confidence predictions as input
conf = kwargs.get("conf") or 0.1
conf = kwargs.get('conf') or 0.1
kwargs['conf'] = conf
kwargs['mode'] = 'track'
return self.predict(source=source, stream=stream, **kwargs)
@ -188,9 +188,9 @@ class YOLO:
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
"""
overrides = self.overrides.copy()
overrides["rect"] = True # rect batches as default
overrides['rect'] = True # rect batches as default
overrides.update(kwargs)
overrides["mode"] = "val"
overrides['mode'] = 'val'
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.data = data or args.data
args.task = self.task
@ -234,18 +234,18 @@ class YOLO:
self._check_is_pytorch_model()
overrides = self.overrides.copy()
overrides.update(kwargs)
if kwargs.get("cfg"):
if kwargs.get('cfg'):
LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
overrides = yaml_load(check_yaml(kwargs["cfg"]), append_filename=True)
overrides["task"] = self.task
overrides["mode"] = "train"
if not overrides.get("data"):
overrides = yaml_load(check_yaml(kwargs['cfg']), append_filename=True)
overrides['task'] = self.task
overrides['mode'] = 'train'
if not overrides.get('data'):
raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
if overrides.get("resume"):
overrides["resume"] = self.ckpt_path
if overrides.get('resume'):
overrides['resume'] = self.ckpt_path
self.trainer = self.TrainerClass(overrides=overrides)
if not overrides.get("resume"): # manually set model only if not resuming
if not overrides.get('resume'): # manually set model only if not resuming
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
self.model = self.trainer.model
self.trainer.train()
@ -267,9 +267,9 @@ class YOLO:
def _assign_ops_from_task(self):
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[self.task]
trainer_class = eval(train_lit.replace("TYPE", f"{self.type}"))
validator_class = eval(val_lit.replace("TYPE", f"{self.type}"))
predictor_class = eval(pred_lit.replace("TYPE", f"{self.type}"))
trainer_class = eval(train_lit.replace('TYPE', f'{self.type}'))
validator_class = eval(val_lit.replace('TYPE', f'{self.type}'))
predictor_class = eval(pred_lit.replace('TYPE', f'{self.type}'))
return model_class, trainer_class, validator_class, predictor_class
@property
@ -292,7 +292,7 @@ class YOLO:
Returns metrics if computed
"""
if not self.metrics_data:
LOGGER.info("No metrics data found! Run training or validation operation first.")
LOGGER.info('No metrics data found! Run training or validation operation first.')
return self.metrics_data