Update .pre-commit-config.yaml (#1026)
This commit is contained in:
parent
9047d737f4
commit
edd3ff1669
76 changed files with 928 additions and 935 deletions
|
|
@ -16,13 +16,13 @@ from ultralytics.yolo.utils.torch_utils import smart_inference_mode
|
|||
|
||||
# Map head to model, trainer, validator, and predictor classes
|
||||
MODEL_MAP = {
|
||||
"classify": [
|
||||
'classify': [
|
||||
ClassificationModel, 'yolo.TYPE.classify.ClassificationTrainer', 'yolo.TYPE.classify.ClassificationValidator',
|
||||
'yolo.TYPE.classify.ClassificationPredictor'],
|
||||
"detect": [
|
||||
'detect': [
|
||||
DetectionModel, 'yolo.TYPE.detect.DetectionTrainer', 'yolo.TYPE.detect.DetectionValidator',
|
||||
'yolo.TYPE.detect.DetectionPredictor'],
|
||||
"segment": [
|
||||
'segment': [
|
||||
SegmentationModel, 'yolo.TYPE.segment.SegmentationTrainer', 'yolo.TYPE.segment.SegmentationValidator',
|
||||
'yolo.TYPE.segment.SegmentationPredictor']}
|
||||
|
||||
|
|
@ -34,7 +34,7 @@ class YOLO:
|
|||
A python interface which emulates a model-like behaviour by wrapping trainers.
|
||||
"""
|
||||
|
||||
def __init__(self, model='yolov8n.pt', type="v8") -> None:
|
||||
def __init__(self, model='yolov8n.pt', type='v8') -> None:
|
||||
"""
|
||||
Initializes the YOLO object.
|
||||
|
||||
|
|
@ -94,7 +94,7 @@ class YOLO:
|
|||
suffix = Path(weights).suffix
|
||||
if suffix == '.pt':
|
||||
self.model, self.ckpt = attempt_load_one_weight(weights)
|
||||
self.task = self.model.args["task"]
|
||||
self.task = self.model.args['task']
|
||||
self.overrides = self.model.args
|
||||
self._reset_ckpt_args(self.overrides)
|
||||
else:
|
||||
|
|
@ -111,7 +111,7 @@ class YOLO:
|
|||
"""
|
||||
if not isinstance(self.model, nn.Module):
|
||||
raise TypeError(f"model='{self.model}' must be a *.pt PyTorch model, but is a different type. "
|
||||
f"PyTorch models can be used to train, val, predict and export, i.e. "
|
||||
f'PyTorch models can be used to train, val, predict and export, i.e. '
|
||||
f"'yolo export model=yolov8n.pt', but exported formats like ONNX, TensorRT etc. only "
|
||||
f"support 'predict' and 'val' modes, i.e. 'yolo predict model=yolov8n.onnx'.")
|
||||
|
||||
|
|
@ -155,11 +155,11 @@ class YOLO:
|
|||
(List[ultralytics.yolo.engine.results.Results]): The prediction results.
|
||||
"""
|
||||
overrides = self.overrides.copy()
|
||||
overrides["conf"] = 0.25
|
||||
overrides['conf'] = 0.25
|
||||
overrides.update(kwargs)
|
||||
overrides["mode"] = kwargs.get("mode", "predict")
|
||||
assert overrides["mode"] in ['track', 'predict']
|
||||
overrides["save"] = kwargs.get("save", False) # not save files by default
|
||||
overrides['mode'] = kwargs.get('mode', 'predict')
|
||||
assert overrides['mode'] in ['track', 'predict']
|
||||
overrides['save'] = kwargs.get('save', False) # not save files by default
|
||||
if not self.predictor:
|
||||
self.predictor = self.PredictorClass(overrides=overrides)
|
||||
self.predictor.setup_model(model=self.model)
|
||||
|
|
@ -173,7 +173,7 @@ class YOLO:
|
|||
from ultralytics.tracker.track import register_tracker
|
||||
register_tracker(self)
|
||||
# bytetrack-based method needs low confidence predictions as input
|
||||
conf = kwargs.get("conf") or 0.1
|
||||
conf = kwargs.get('conf') or 0.1
|
||||
kwargs['conf'] = conf
|
||||
kwargs['mode'] = 'track'
|
||||
return self.predict(source=source, stream=stream, **kwargs)
|
||||
|
|
@ -188,9 +188,9 @@ class YOLO:
|
|||
**kwargs : Any other args accepted by the validators. To see all args check 'configuration' section in docs
|
||||
"""
|
||||
overrides = self.overrides.copy()
|
||||
overrides["rect"] = True # rect batches as default
|
||||
overrides['rect'] = True # rect batches as default
|
||||
overrides.update(kwargs)
|
||||
overrides["mode"] = "val"
|
||||
overrides['mode'] = 'val'
|
||||
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
|
||||
args.data = data or args.data
|
||||
args.task = self.task
|
||||
|
|
@ -234,18 +234,18 @@ class YOLO:
|
|||
self._check_is_pytorch_model()
|
||||
overrides = self.overrides.copy()
|
||||
overrides.update(kwargs)
|
||||
if kwargs.get("cfg"):
|
||||
if kwargs.get('cfg'):
|
||||
LOGGER.info(f"cfg file passed. Overriding default params with {kwargs['cfg']}.")
|
||||
overrides = yaml_load(check_yaml(kwargs["cfg"]), append_filename=True)
|
||||
overrides["task"] = self.task
|
||||
overrides["mode"] = "train"
|
||||
if not overrides.get("data"):
|
||||
overrides = yaml_load(check_yaml(kwargs['cfg']), append_filename=True)
|
||||
overrides['task'] = self.task
|
||||
overrides['mode'] = 'train'
|
||||
if not overrides.get('data'):
|
||||
raise AttributeError("Dataset required but missing, i.e. pass 'data=coco128.yaml'")
|
||||
if overrides.get("resume"):
|
||||
overrides["resume"] = self.ckpt_path
|
||||
if overrides.get('resume'):
|
||||
overrides['resume'] = self.ckpt_path
|
||||
|
||||
self.trainer = self.TrainerClass(overrides=overrides)
|
||||
if not overrides.get("resume"): # manually set model only if not resuming
|
||||
if not overrides.get('resume'): # manually set model only if not resuming
|
||||
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
|
||||
self.model = self.trainer.model
|
||||
self.trainer.train()
|
||||
|
|
@ -267,9 +267,9 @@ class YOLO:
|
|||
|
||||
def _assign_ops_from_task(self):
|
||||
model_class, train_lit, val_lit, pred_lit = MODEL_MAP[self.task]
|
||||
trainer_class = eval(train_lit.replace("TYPE", f"{self.type}"))
|
||||
validator_class = eval(val_lit.replace("TYPE", f"{self.type}"))
|
||||
predictor_class = eval(pred_lit.replace("TYPE", f"{self.type}"))
|
||||
trainer_class = eval(train_lit.replace('TYPE', f'{self.type}'))
|
||||
validator_class = eval(val_lit.replace('TYPE', f'{self.type}'))
|
||||
predictor_class = eval(pred_lit.replace('TYPE', f'{self.type}'))
|
||||
return model_class, trainer_class, validator_class, predictor_class
|
||||
|
||||
@property
|
||||
|
|
@ -292,7 +292,7 @@ class YOLO:
|
|||
Returns metrics if computed
|
||||
"""
|
||||
if not self.metrics_data:
|
||||
LOGGER.info("No metrics data found! Run training or validation operation first.")
|
||||
LOGGER.info('No metrics data found! Run training or validation operation first.')
|
||||
|
||||
return self.metrics_data
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue