Docstring additions (#122)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2022-12-31 13:42:45 +01:00 committed by GitHub
parent c9f3e469cb
commit df4fc14c10
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 291 additions and 73 deletions

View file

@ -1,6 +1,6 @@
import torch
from ultralytics import yolo # noqa required for python usage
from ultralytics import yolo # noqa
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_weights
from ultralytics.yolo.configs import get_config
from ultralytics.yolo.engine.exporter import Exporter
@ -9,7 +9,7 @@ from ultralytics.yolo.utils.checks import check_imgsz, check_yaml
from ultralytics.yolo.utils.files import yaml_load
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode
# map head: [model, trainer, validator, predictor]
# Map head to model, trainer, validator, and predictor classes
MODEL_MAP = {
"classify": [
ClassificationModel, 'yolo.TYPE.classify.ClassificationTrainer', 'yolo.TYPE.classify.ClassificationValidator',
@ -24,39 +24,44 @@ MODEL_MAP = {
class YOLO:
"""
Python interface which emulates a model-like behaviour by wrapping trainers.
YOLO
A python interface which emulates a model-like behaviour by wrapping trainers.
"""
__init_key = object()
__init_key = object() # used to ensure proper initialization
def __init__(self, init_key=None, type="v8") -> None:
"""
Initializes the YOLO object.
Args:
type (str): Type/version of models to use
init_key (object): used to ensure proper initialization. Defaults to None.
type (str): Type/version of models to use. Defaults to "v8".
"""
if init_key != YOLO.__init_key:
raise SyntaxError(HELP_MSG)
self.type = type
self.ModelClass = None
self.TrainerClass = None
self.ValidatorClass = None
self.PredictorClass = None
self.model = None
self.trainer = None
self.task = None
self.ModelClass = None # model class
self.TrainerClass = None # trainer class
self.ValidatorClass = None # validator class
self.PredictorClass = None # predictor class
self.model = None # model object
self.trainer = None # trainer object
self.task = None # task type
self.ckpt = None # if loaded from *.pt
self.cfg = None # if loaded from *.yaml
self.overrides = {}
self.init_disabled = False
self.overrides = {} # overrides for trainer object
self.init_disabled = False # disable model initialization
@classmethod
def new(cls, cfg: str, verbose=True):
"""
Initializes a new model and infers the task type from the model definitions
Initializes a new model and infers the task type from the model definitions.
Args:
cfg (str): model configuration file
verbsoe (bool): display model info on load
verbose (bool): display model info on load
"""
cfg = check_yaml(cfg) # check YAML
cfg_dict = yaml_load(cfg) # model dict