Expand Model method type hinting (#8279)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
fbed8499da
commit
42744a1717
6 changed files with 104 additions and 50 deletions
|
|
@ -5,6 +5,10 @@ import sys
|
|||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import PIL
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
|
||||
from ultralytics.hub.utils import HUB_WEB_ROOT
|
||||
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
|
||||
|
|
@ -78,7 +82,12 @@ class Model(nn.Module):
|
|||
NotImplementedError: If a specific model task or mode is not supported.
|
||||
"""
|
||||
|
||||
def __init__(self, model: Union[str, Path] = "yolov8n.pt", task=None, verbose=False) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[str, Path] = "yolov8n.pt",
|
||||
task: str = None,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes a new instance of the YOLO model class.
|
||||
|
||||
|
|
@ -135,7 +144,12 @@ class Model(nn.Module):
|
|||
|
||||
self.model_name = model
|
||||
|
||||
def __call__(self, source=None, stream=False, **kwargs):
|
||||
def __call__(
|
||||
self,
|
||||
source: Union[str, Path, int, list, tuple, PIL.Image.Image, np.ndarray, torch.Tensor] = None,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
"""
|
||||
An alias for the predict method, enabling the model instance to be callable.
|
||||
|
||||
|
|
@ -143,8 +157,9 @@ class Model(nn.Module):
|
|||
with the required arguments for prediction.
|
||||
|
||||
Args:
|
||||
source (str | int | PIL.Image | np.ndarray, optional): The source of the image for making predictions.
|
||||
Accepts various types, including file paths, URLs, PIL images, and numpy arrays. Defaults to None.
|
||||
source (str | Path | int | PIL.Image | np.ndarray, optional): The source of the image for making
|
||||
predictions. Accepts various types, including file paths, URLs, PIL images, and numpy arrays.
|
||||
Defaults to None.
|
||||
stream (bool, optional): If True, treats the input source as a continuous stream for predictions.
|
||||
Defaults to False.
|
||||
**kwargs (dict): Additional keyword arguments for configuring the prediction process.
|
||||
|
|
@ -163,7 +178,7 @@ class Model(nn.Module):
|
|||
return session if session.client.authenticated else None
|
||||
|
||||
@staticmethod
|
||||
def is_triton_model(model):
|
||||
def is_triton_model(model: str) -> bool:
|
||||
"""Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>"""
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
|
|
@ -171,7 +186,7 @@ class Model(nn.Module):
|
|||
return url.netloc and url.path and url.scheme in {"http", "grpc"}
|
||||
|
||||
@staticmethod
|
||||
def is_hub_model(model):
|
||||
def is_hub_model(model: str) -> bool:
|
||||
"""Check if the provided model is a HUB model."""
|
||||
return any(
|
||||
(
|
||||
|
|
@ -181,7 +196,7 @@ class Model(nn.Module):
|
|||
)
|
||||
)
|
||||
|
||||
def _new(self, cfg: str, task=None, model=None, verbose=False):
|
||||
def _new(self, cfg: str, task=None, model=None, verbose=False) -> None:
|
||||
"""
|
||||
Initializes a new model and infers the task type from the model definitions.
|
||||
|
||||
|
|
@ -202,7 +217,7 @@ class Model(nn.Module):
|
|||
self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # combine default and model args (prefer model args)
|
||||
self.model.task = self.task
|
||||
|
||||
def _load(self, weights: str, task=None):
|
||||
def _load(self, weights: str, task=None) -> None:
|
||||
"""
|
||||
Initializes a new model and infers the task type from the model head.
|
||||
|
||||
|
|
@ -224,7 +239,7 @@ class Model(nn.Module):
|
|||
self.overrides["model"] = weights
|
||||
self.overrides["task"] = self.task
|
||||
|
||||
def _check_is_pytorch_model(self):
|
||||
def _check_is_pytorch_model(self) -> None:
|
||||
"""Raises TypeError is model is not a PyTorch model."""
|
||||
pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt"
|
||||
pt_module = isinstance(self.model, nn.Module)
|
||||
|
|
@ -237,7 +252,7 @@ class Model(nn.Module):
|
|||
f"argument directly in your inference command, i.e. 'model.predict(source=..., device=0)'"
|
||||
)
|
||||
|
||||
def reset_weights(self):
|
||||
def reset_weights(self) -> "Model":
|
||||
"""
|
||||
Resets the model parameters to randomly initialized values, effectively discarding all training information.
|
||||
|
||||
|
|
@ -259,7 +274,7 @@ class Model(nn.Module):
|
|||
p.requires_grad = True
|
||||
return self
|
||||
|
||||
def load(self, weights="yolov8n.pt"):
|
||||
def load(self, weights: Union[str, Path] = "yolov8n.pt") -> "Model":
|
||||
"""
|
||||
Loads parameters from the specified weights file into the model.
|
||||
|
||||
|
|
@ -281,24 +296,22 @@ class Model(nn.Module):
|
|||
self.model.load(weights)
|
||||
return self
|
||||
|
||||
def save(self, filename="model.pt"):
|
||||
def save(self, filename: Union[str, Path] = "saved_model.pt") -> None:
|
||||
"""
|
||||
Saves the current model state to a file.
|
||||
|
||||
This method exports the model's checkpoint (ckpt) to the specified filename.
|
||||
|
||||
Args:
|
||||
filename (str): The name of the file to save the model to. Defaults to 'model.pt'.
|
||||
filename (str | Path): The name of the file to save the model to. Defaults to 'saved_model.pt'.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the model is not a PyTorch model.
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
import torch
|
||||
|
||||
torch.save(self.ckpt, filename)
|
||||
|
||||
def info(self, detailed=False, verbose=True):
|
||||
def info(self, detailed: bool = False, verbose: bool = True):
|
||||
"""
|
||||
Logs or returns model information.
|
||||
|
||||
|
|
@ -330,7 +343,12 @@ class Model(nn.Module):
|
|||
self._check_is_pytorch_model()
|
||||
self.model.fuse()
|
||||
|
||||
def embed(self, source=None, stream=False, **kwargs):
|
||||
def embed(
|
||||
self,
|
||||
source: Union[str, Path, int, list, tuple, PIL.Image.Image, np.ndarray, torch.Tensor] = None,
|
||||
stream: bool = False,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
"""
|
||||
Generates image embeddings based on the provided source.
|
||||
|
||||
|
|
@ -353,7 +371,13 @@ class Model(nn.Module):
|
|||
kwargs["embed"] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed
|
||||
return self.predict(source, stream, **kwargs)
|
||||
|
||||
def predict(self, source=None, stream=False, predictor=None, **kwargs):
|
||||
def predict(
|
||||
self,
|
||||
source: Union[str, Path, int, list, tuple, PIL.Image.Image, np.ndarray, torch.Tensor] = None,
|
||||
stream: bool = False,
|
||||
predictor=None,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
"""
|
||||
Performs predictions on the given image source using the YOLO model.
|
||||
|
||||
|
|
@ -405,7 +429,13 @@ class Model(nn.Module):
|
|||
self.predictor.set_prompts(prompts)
|
||||
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
|
||||
|
||||
def track(self, source=None, stream=False, persist=False, **kwargs):
|
||||
def track(
|
||||
self,
|
||||
source: Union[str, Path, int, list, tuple, PIL.Image.Image, np.ndarray, torch.Tensor] = None,
|
||||
stream: bool = False,
|
||||
persist: bool = False,
|
||||
**kwargs,
|
||||
) -> list:
|
||||
"""
|
||||
Conducts object tracking on the specified input source using the registered trackers.
|
||||
|
||||
|
|
@ -438,7 +468,11 @@ class Model(nn.Module):
|
|||
kwargs["mode"] = "track"
|
||||
return self.predict(source=source, stream=stream, **kwargs)
|
||||
|
||||
def val(self, validator=None, **kwargs):
|
||||
def val(
|
||||
self,
|
||||
validator=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Validates the model using a specified dataset and validation configuration.
|
||||
|
||||
|
|
@ -471,7 +505,10 @@ class Model(nn.Module):
|
|||
self.metrics = validator.metrics
|
||||
return validator.metrics
|
||||
|
||||
def benchmark(self, **kwargs):
|
||||
def benchmark(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Benchmarks the model across various export formats to evaluate performance.
|
||||
|
||||
|
|
@ -509,7 +546,10 @@ class Model(nn.Module):
|
|||
verbose=kwargs.get("verbose"),
|
||||
)
|
||||
|
||||
def export(self, **kwargs):
|
||||
def export(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Exports the model to a different format suitable for deployment.
|
||||
|
||||
|
|
@ -537,7 +577,11 @@ class Model(nn.Module):
|
|||
args = {**self.overrides, **custom, **kwargs, "mode": "export"} # highest priority args on the right
|
||||
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
|
||||
|
||||
def train(self, trainer=None, **kwargs):
|
||||
def train(
|
||||
self,
|
||||
trainer=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Trains the model using the specified dataset and training configuration.
|
||||
|
||||
|
|
@ -607,7 +651,13 @@ class Model(nn.Module):
|
|||
self.metrics = getattr(self.trainer.validator, "metrics", None) # TODO: no metrics returned by DDP
|
||||
return self.metrics
|
||||
|
||||
def tune(self, use_ray=False, iterations=10, *args, **kwargs):
|
||||
def tune(
|
||||
self,
|
||||
use_ray=False,
|
||||
iterations=10,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Conducts hyperparameter tuning for the model, with an option to use Ray Tune.
|
||||
|
||||
|
|
@ -640,7 +690,7 @@ class Model(nn.Module):
|
|||
args = {**self.overrides, **custom, **kwargs, "mode": "train"} # highest priority args on the right
|
||||
return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)
|
||||
|
||||
def _apply(self, fn):
|
||||
def _apply(self, fn) -> "Model":
|
||||
"""Apply to(), cpu(), cuda(), half(), float() to model tensors that are not parameters or registered buffers."""
|
||||
self._check_is_pytorch_model()
|
||||
self = super()._apply(fn) # noqa
|
||||
|
|
@ -649,7 +699,7 @@ class Model(nn.Module):
|
|||
return self
|
||||
|
||||
@property
|
||||
def names(self):
|
||||
def names(self) -> list:
|
||||
"""
|
||||
Retrieves the class names associated with the loaded model.
|
||||
|
||||
|
|
@ -664,7 +714,7 @@ class Model(nn.Module):
|
|||
return check_class_names(self.model.names) if hasattr(self.model, "names") else None
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
def device(self) -> torch.device:
|
||||
"""
|
||||
Retrieves the device on which the model's parameters are allocated.
|
||||
|
||||
|
|
@ -688,7 +738,7 @@ class Model(nn.Module):
|
|||
"""
|
||||
return self.model.transforms if hasattr(self.model, "transforms") else None
|
||||
|
||||
def add_callback(self, event: str, func):
|
||||
def add_callback(self, event: str, func) -> None:
|
||||
"""
|
||||
Adds a callback function for a specified event.
|
||||
|
||||
|
|
@ -704,7 +754,7 @@ class Model(nn.Module):
|
|||
"""
|
||||
self.callbacks[event].append(func)
|
||||
|
||||
def clear_callback(self, event: str):
|
||||
def clear_callback(self, event: str) -> None:
|
||||
"""
|
||||
Clears all callback functions registered for a specified event.
|
||||
|
||||
|
|
@ -718,7 +768,7 @@ class Model(nn.Module):
|
|||
"""
|
||||
self.callbacks[event] = []
|
||||
|
||||
def reset_callbacks(self):
|
||||
def reset_callbacks(self) -> None:
|
||||
"""
|
||||
Resets all callbacks to their default functions.
|
||||
|
||||
|
|
@ -729,7 +779,7 @@ class Model(nn.Module):
|
|||
self.callbacks[event] = [callbacks.default_callbacks[event][0]]
|
||||
|
||||
@staticmethod
|
||||
def _reset_ckpt_args(args):
|
||||
def _reset_ckpt_args(args: dict) -> dict:
|
||||
"""Reset arguments when loading a PyTorch model."""
|
||||
include = {"imgsz", "data", "task", "single_cls"} # only remember these arguments when loading a PyTorch model
|
||||
return {k: v for k, v in args.items() if k in include}
|
||||
|
|
@ -739,7 +789,7 @@ class Model(nn.Module):
|
|||
# name = self.__class__.__name__
|
||||
# raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
||||
|
||||
def _smart_load(self, key):
|
||||
def _smart_load(self, key: str):
|
||||
"""Load model/trainer/validator/predictor."""
|
||||
try:
|
||||
return self.task_map[self.task][key]
|
||||
|
|
@ -751,7 +801,7 @@ class Model(nn.Module):
|
|||
) from e
|
||||
|
||||
@property
|
||||
def task_map(self):
|
||||
def task_map(self) -> dict:
|
||||
"""
|
||||
Map head to model, trainer, validator, and predictor classes.
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue