ultralytics 8.0.47 Docker and reformat updates (#1153)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
d4be4cb24b
commit
a58f766f94
41 changed files with 224 additions and 201 deletions
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from ultralytics import yolo # noqa
|
||||
from ultralytics.nn.tasks import (ClassificationModel, DetectionModel, SegmentationModel, attempt_load_one_weight,
|
||||
|
|
@ -68,7 +67,7 @@ class YOLO:
|
|||
list(ultralytics.yolo.engine.results.Results): The prediction results.
|
||||
"""
|
||||
|
||||
def __init__(self, model='yolov8n.pt') -> None:
|
||||
def __init__(self, model='yolov8n.pt', task=None) -> None:
|
||||
"""
|
||||
Initializes the YOLO model.
|
||||
|
||||
|
|
@ -91,9 +90,9 @@ class YOLO:
|
|||
if not suffix and Path(model).stem in GITHUB_ASSET_STEMS:
|
||||
model, suffix = Path(model).with_suffix('.pt'), '.pt' # add suffix, i.e. yolov8n -> yolov8n.pt
|
||||
if suffix == '.yaml':
|
||||
self._new(model)
|
||||
self._new(model, task)
|
||||
else:
|
||||
self._load(model)
|
||||
self._load(model, task)
|
||||
|
||||
def __call__(self, source=None, stream=False, **kwargs):
|
||||
return self.predict(source, stream, **kwargs)
|
||||
|
|
@ -102,17 +101,18 @@ class YOLO:
|
|||
name = self.__class__.__name__
|
||||
raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
||||
|
||||
def _new(self, cfg: str, verbose=True):
|
||||
def _new(self, cfg: str, task=None, verbose=True):
|
||||
"""
|
||||
Initializes a new model and infers the task type from the model definitions.
|
||||
|
||||
Args:
|
||||
cfg (str): model configuration file
|
||||
task (str) or (None): model task
|
||||
verbose (bool): display model info on load
|
||||
"""
|
||||
self.cfg = check_yaml(cfg) # check YAML
|
||||
cfg_dict = yaml_load(self.cfg, append_filename=True) # model dict
|
||||
self.task = guess_model_task(cfg_dict)
|
||||
self.task = task or guess_model_task(cfg_dict)
|
||||
self.model = TASK_MAP[self.task][0](cfg_dict, verbose=verbose and RANK == -1) # build model
|
||||
self.overrides['model'] = self.cfg
|
||||
|
||||
|
|
@ -121,12 +121,13 @@ class YOLO:
|
|||
self.model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
|
||||
self.model.task = self.task
|
||||
|
||||
def _load(self, weights: str, task=''):
|
||||
def _load(self, weights: str, task=None):
|
||||
"""
|
||||
Initializes a new model and infers the task type from the model head.
|
||||
|
||||
Args:
|
||||
weights (str): model checkpoint to be loaded
|
||||
task (str) or (None): model task
|
||||
"""
|
||||
suffix = Path(weights).suffix
|
||||
if suffix == '.pt':
|
||||
|
|
@ -137,7 +138,7 @@ class YOLO:
|
|||
else:
|
||||
weights = check_file(weights)
|
||||
self.model, self.ckpt = weights, None
|
||||
self.task = guess_model_task(weights)
|
||||
self.task = task or guess_model_task(weights)
|
||||
self.ckpt_path = weights
|
||||
self.overrides['model'] = weights
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue