Engine Model and Results Docs improvements (#14564)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Felipe Parodi <fparodi@pennmedicine.upenn.edu>
This commit is contained in:
parent
e59376b55f
commit
291883a23f
123 changed files with 3789 additions and 1368 deletions
|
|
@ -187,11 +187,11 @@ CFG_BOOL_KEYS = { # boolean-only arguments
|
|||
|
||||
def cfg2dict(cfg):
|
||||
"""
|
||||
Convert a configuration object to a dictionary, whether it is a file path, a string, or a SimpleNamespace object.
|
||||
Converts a configuration object to a dictionary.
|
||||
|
||||
Args:
|
||||
cfg (str | Path | Dict | SimpleNamespace): Configuration object to be converted to a dictionary. This may be a
|
||||
path to a configuration file, a dictionary, or a SimpleNamespace object.
|
||||
cfg (str | Path | Dict | SimpleNamespace): Configuration object to be converted. Can be a file path,
|
||||
a string, a dictionary, or a SimpleNamespace object.
|
||||
|
||||
Returns:
|
||||
(Dict): Configuration object in dictionary format.
|
||||
|
|
@ -209,8 +209,9 @@ def cfg2dict(cfg):
|
|||
>>> config_dict = cfg2dict({'param1': 'value1', 'param2': 'value2'})
|
||||
|
||||
Notes:
|
||||
- If `cfg` is a path or a string, it will be loaded as YAML and converted to a dictionary.
|
||||
- If `cfg` is a SimpleNamespace object, it will be converted to a dictionary using `vars()`.
|
||||
- If cfg is a path or string, it's loaded as YAML and converted to a dictionary.
|
||||
- If cfg is a SimpleNamespace object, it's converted to a dictionary using vars().
|
||||
- If cfg is already a dictionary, it's returned unchanged.
|
||||
"""
|
||||
if isinstance(cfg, (str, Path)):
|
||||
cfg = yaml_load(cfg) # load dict
|
||||
|
|
@ -224,24 +225,23 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
|
|||
Load and merge configuration data from a file or dictionary, with optional overrides.
|
||||
|
||||
Args:
|
||||
cfg (str | Path | Dict | SimpleNamespace): Configuration data source.
|
||||
cfg (str | Path | Dict | SimpleNamespace): Configuration data source. Can be a file path, dictionary, or
|
||||
SimpleNamespace object.
|
||||
overrides (Dict | None): Dictionary containing key-value pairs to override the base configuration.
|
||||
|
||||
Returns:
|
||||
(SimpleNamespace): Namespace containing the merged training arguments.
|
||||
(SimpleNamespace): Namespace containing the merged configuration arguments.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics import get_cfg
|
||||
>>> config = get_cfg() # Load default configuration
|
||||
>>> config = get_cfg('path/to/config.yaml', overrides={'epochs': 50, 'batch_size': 16})
|
||||
|
||||
Notes:
|
||||
- If both `cfg` and `overrides` are provided, the values in `overrides` will take precedence.
|
||||
- Special handling ensures alignment and correctness of the configuration, such as converting numeric `project`
|
||||
and `name` to strings and validating configuration keys and values.
|
||||
|
||||
Examples:
|
||||
Load default configuration:
|
||||
>>> from ultralytics import get_cfg
|
||||
>>> config = get_cfg()
|
||||
|
||||
Load from a custom file with overrides:
|
||||
>>> config = get_cfg('path/to/config.yaml', overrides={'epochs': 50, 'batch_size': 16})
|
||||
- Special handling ensures alignment and correctness of the configuration, such as converting numeric
|
||||
`project` and `name` to strings and validating configuration keys and values.
|
||||
- The function performs type and value checks on the configuration data.
|
||||
"""
|
||||
cfg = cfg2dict(cfg)
|
||||
|
||||
|
|
@ -270,24 +270,31 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
|
|||
|
||||
def check_cfg(cfg, hard=True):
|
||||
"""
|
||||
Checks configuration argument types and values for the Ultralytics library, ensuring correctness and converting them
|
||||
if necessary.
|
||||
Checks configuration argument types and values for the Ultralytics library.
|
||||
|
||||
This function validates the types and values of configuration arguments, ensuring correctness and converting
|
||||
them if necessary. It checks for specific key types defined in global variables such as CFG_FLOAT_KEYS,
|
||||
CFG_FRACTION_KEYS, CFG_INT_KEYS, and CFG_BOOL_KEYS.
|
||||
|
||||
Args:
|
||||
cfg (Dict): Configuration dictionary to validate.
|
||||
hard (bool): If True, raises exceptions for invalid types and values; if False, attempts to convert them.
|
||||
|
||||
Examples:
|
||||
Validate a configuration with a mix of valid and invalid values:
|
||||
>>> config = {
|
||||
... 'epochs': 50, # valid integer
|
||||
... 'lr0': 0.01, # valid float
|
||||
... 'momentum': 1.2, # invalid float (out of 0.0-1.0 range)
|
||||
... 'save': 'true', # invalid bool
|
||||
... 'epochs': 50, # valid integer
|
||||
... 'lr0': 0.01, # valid float
|
||||
... 'momentum': 1.2, # invalid float (out of 0.0-1.0 range)
|
||||
... 'save': 'true', # invalid bool
|
||||
... }
|
||||
>>> check_cfg(config, hard=False)
|
||||
>>> print(config)
|
||||
{'epochs': 50, 'lr0': 0.01, 'momentum': 1.2, 'save': False} # corrected 'save' key and retained other values
|
||||
{'epochs': 50, 'lr0': 0.01, 'momentum': 1.2, 'save': False} # corrected 'save' key
|
||||
|
||||
Notes:
|
||||
- The function modifies the input dictionary in-place.
|
||||
- None values are ignored as they may be from optional arguments.
|
||||
- Fraction keys are checked to be within the range [0.0, 1.0].
|
||||
"""
|
||||
for k, v in cfg.items():
|
||||
if v is not None: # None values may be from optional args
|
||||
|
|
@ -328,16 +335,15 @@ def get_save_dir(args, name=None):
|
|||
Returns the directory path for saving outputs, derived from arguments or default settings.
|
||||
|
||||
Args:
|
||||
args (SimpleNamespace): Namespace object containing configurations such as 'project', 'name', 'task', 'mode', and
|
||||
'save_dir'.
|
||||
name (str | None): Optional name for the output directory. If not provided, it defaults to 'args.name' or the
|
||||
'args.mode'.
|
||||
args (SimpleNamespace): Namespace object containing configurations such as 'project', 'name', 'task',
|
||||
'mode', and 'save_dir'.
|
||||
name (str | None): Optional name for the output directory. If not provided, it defaults to 'args.name'
|
||||
or the 'args.mode'.
|
||||
|
||||
Returns:
|
||||
(Path): Directory path where outputs should be saved.
|
||||
|
||||
Examples:
|
||||
Generate a save directory using provided arguments
|
||||
>>> from types import SimpleNamespace
|
||||
>>> args = SimpleNamespace(project='my_project', task='detect', mode='train', exist_ok=True)
|
||||
>>> save_dir = get_save_dir(args)
|
||||
|
|
@ -369,6 +375,11 @@ def _handle_deprecation(custom):
|
|||
>>> _handle_deprecation(custom_config)
|
||||
>>> print(custom_config)
|
||||
{'show_boxes': True, 'show_labels': True, 'line_width': 2}
|
||||
|
||||
Notes:
|
||||
This function modifies the input dictionary in-place, replacing deprecated keys with their current
|
||||
equivalents. It also handles value conversions where necessary, such as inverting boolean values for
|
||||
'hide_labels' and 'hide_conf'.
|
||||
"""
|
||||
|
||||
for key in custom.copy().keys():
|
||||
|
|
@ -390,32 +401,29 @@ def _handle_deprecation(custom):
|
|||
|
||||
def check_dict_alignment(base: Dict, custom: Dict, e=None):
|
||||
"""
|
||||
Check for key alignment between custom and base configuration dictionaries, handling deprecated keys and providing
|
||||
informative error messages for mismatched keys.
|
||||
Checks alignment between custom and base configuration dictionaries, handling deprecated keys and providing error
|
||||
messages for mismatched keys.
|
||||
|
||||
Args:
|
||||
base (Dict): The base configuration dictionary containing valid keys.
|
||||
custom (Dict): The custom configuration dictionary to be checked for alignment.
|
||||
e (Exception | None): Optional error instance passed by the calling function. Default is None.
|
||||
e (Exception | None): Optional error instance passed by the calling function.
|
||||
|
||||
Raises:
|
||||
SystemExit: Terminates the program execution if mismatched keys are found.
|
||||
|
||||
Notes:
|
||||
- The function suggests corrections for mismatched keys based on similarity to valid keys.
|
||||
- Deprecated keys in the custom configuration are automatically replaced with their updated equivalents.
|
||||
- Detailed error messages are printed for each mismatched key to help users identify and correct their custom
|
||||
configurations.
|
||||
SystemExit: If mismatched keys are found between the custom and base dictionaries.
|
||||
|
||||
Examples:
|
||||
>>> base_cfg = {'epochs': 50, 'lr0': 0.01, 'batch_size': 16}
|
||||
>>> custom_cfg = {'epoch': 100, 'lr': 0.02, 'batch_size': 32}
|
||||
|
||||
>>> try:
|
||||
... check_dict_alignment(base_cfg, custom_cfg)
|
||||
... except SystemExit:
|
||||
... # Handle the error or correct the configuration
|
||||
... pass
|
||||
... print("Mismatched keys found")
|
||||
|
||||
Notes:
|
||||
- Suggests corrections for mismatched keys based on similarity to valid keys.
|
||||
- Automatically replaces deprecated keys in the custom configuration with updated equivalents.
|
||||
- Prints detailed error messages for each mismatched key to help users correct their configurations.
|
||||
"""
|
||||
custom = _handle_deprecation(custom)
|
||||
base_keys, custom_keys = (set(x.keys()) for x in (base, custom))
|
||||
|
|
@ -434,7 +442,10 @@ def check_dict_alignment(base: Dict, custom: Dict, e=None):
|
|||
|
||||
def merge_equals_args(args: List[str]) -> List[str]:
|
||||
"""
|
||||
Merges arguments around isolated '=' in a list of strings.
|
||||
Merges arguments around isolated '=' in a list of strings, handling three cases:
|
||||
1. ['arg', '=', 'val'] becomes ['arg=val'],
|
||||
2. ['arg=', 'val'] becomes ['arg=val'],
|
||||
3. ['arg', '=val'] becomes ['arg=val'].
|
||||
|
||||
Args:
|
||||
args (List[str]): A list of strings where each element represents an argument.
|
||||
|
|
@ -443,20 +454,9 @@ def merge_equals_args(args: List[str]) -> List[str]:
|
|||
(List[str]): A list of strings where the arguments around isolated '=' are merged.
|
||||
|
||||
Examples:
|
||||
Merge arguments where equals sign is separated:
|
||||
>>> args = ["arg1", "=", "value"]
|
||||
>>> args = ["arg1", "=", "value", "arg2=", "value2", "arg3", "=value3"]
|
||||
>>> merge_equals_args(args)
|
||||
["arg1=value"]
|
||||
|
||||
Merge arguments where equals sign is at the end of the first argument:
|
||||
>>> args = ["arg1=", "value"]
|
||||
>>> merge_equals_args(args)
|
||||
["arg1=value"]
|
||||
|
||||
Merge arguments where equals sign is at the beginning of the second argument:
|
||||
>>> args = ["arg1", "=value"]
|
||||
>>> merge_equals_args(args)
|
||||
["arg1=value"]
|
||||
['arg1=value', 'arg2=value2', 'arg3=value3']
|
||||
"""
|
||||
new_args = []
|
||||
for i, arg in enumerate(args):
|
||||
|
|
@ -475,18 +475,24 @@ def merge_equals_args(args: List[str]) -> List[str]:
|
|||
|
||||
def handle_yolo_hub(args: List[str]) -> None:
|
||||
"""
|
||||
Handle Ultralytics HUB command-line interface (CLI) commands.
|
||||
Handles Ultralytics HUB command-line interface (CLI) commands for authentication.
|
||||
|
||||
This function processes Ultralytics HUB CLI commands such as login and logout. It should be called when executing a
|
||||
script with arguments related to HUB authentication.
|
||||
|
||||
Args:
|
||||
args (List[str]): A list of command line arguments.
|
||||
args (List[str]): A list of command line arguments. The first argument should be either 'login'
|
||||
or 'logout'. For 'login', an optional second argument can be the API key.
|
||||
|
||||
Examples:
|
||||
```bash
|
||||
yolo hub login YOUR_API_KEY
|
||||
```
|
||||
|
||||
Notes:
|
||||
- The function imports the 'hub' module from ultralytics to perform login and logout operations.
|
||||
- For the 'login' command, if no API key is provided, an empty string is passed to the login function.
|
||||
- The 'logout' command does not require any additional arguments.
|
||||
"""
|
||||
from ultralytics import hub
|
||||
|
||||
|
|
@ -501,21 +507,26 @@ def handle_yolo_hub(args: List[str]) -> None:
|
|||
|
||||
def handle_yolo_settings(args: List[str]) -> None:
|
||||
"""
|
||||
Handle YOLO settings command-line interface (CLI) commands.
|
||||
Handles YOLO settings command-line interface (CLI) commands.
|
||||
|
||||
This function processes YOLO settings CLI commands such as reset. It should be called when executing a script with
|
||||
arguments related to YOLO settings management.
|
||||
This function processes YOLO settings CLI commands such as reset and updating individual settings. It should be
|
||||
called when executing a script with arguments related to YOLO settings management.
|
||||
|
||||
Args:
|
||||
args (List[str]): A list of command line arguments for YOLO settings management.
|
||||
|
||||
Examples:
|
||||
Reset YOLO settings:
|
||||
>>> yolo settings reset
|
||||
>>> handle_yolo_settings(["reset"]) # Reset YOLO settings
|
||||
>>> handle_yolo_settings(["default_cfg_path=yolov8n.yaml"]) # Update a specific setting
|
||||
|
||||
Notes:
|
||||
For more information on handling YOLO settings, visit:
|
||||
https://docs.ultralytics.com/quickstart/#ultralytics-settings
|
||||
- If no arguments are provided, the function will display the current settings.
|
||||
- The 'reset' command will delete the existing settings file and create new default settings.
|
||||
- Other arguments are treated as key-value pairs to update specific settings.
|
||||
- The function will check for alignment between the provided settings and the existing ones.
|
||||
- After processing, the updated settings will be displayed.
|
||||
- For more information on handling YOLO settings, visit:
|
||||
https://docs.ultralytics.com/quickstart/#ultralytics-settings
|
||||
"""
|
||||
url = "https://docs.ultralytics.com/quickstart/#ultralytics-settings" # help URL
|
||||
try:
|
||||
|
|
@ -539,12 +550,17 @@ def handle_explorer():
|
|||
"""
|
||||
Open the Ultralytics Explorer GUI for dataset exploration and analysis.
|
||||
|
||||
This function launches a graphical user interface that provides tools for interacting with and analyzing datasets
|
||||
using the Ultralytics Explorer API.
|
||||
This function launches a graphical user interface that provides tools for interacting with and analyzing
|
||||
datasets using the Ultralytics Explorer API. It checks for the required 'streamlit' package and informs
|
||||
the user that the Explorer dashboard is loading.
|
||||
|
||||
Examples:
|
||||
Start the Ultralytics Explorer:
|
||||
>>> handle_explorer()
|
||||
|
||||
Notes:
|
||||
- Requires 'streamlit' package version 1.29.0 or higher.
|
||||
- The function does not take any arguments or return any values.
|
||||
- It is typically called from the command line interface using the 'yolo explorer' command.
|
||||
"""
|
||||
checks.check_requirements("streamlit>=1.29.0")
|
||||
LOGGER.info("💡 Loading Explorer dashboard...")
|
||||
|
|
@ -553,18 +569,18 @@ def handle_explorer():
|
|||
|
||||
def handle_streamlit_inference():
|
||||
"""
|
||||
Open the Ultralytics Live Inference streamlit app for real-time object detection.
|
||||
Open the Ultralytics Live Inference Streamlit app for real-time object detection.
|
||||
|
||||
This function initializes and runs a Streamlit application designed for performing live object detection using
|
||||
Ultralytics models.
|
||||
|
||||
References:
|
||||
- Streamlit documentation: https://docs.streamlit.io/
|
||||
- Ultralytics: https://docs.ultralytics.com
|
||||
Ultralytics models. It checks for the required Streamlit package and launches the app.
|
||||
|
||||
Examples:
|
||||
To run the live inference Streamlit app, execute:
|
||||
>>> handle_streamlit_inference()
|
||||
|
||||
Notes:
|
||||
- Requires Streamlit version 1.29.0 or higher.
|
||||
- The app is launched using the 'streamlit run' command.
|
||||
- The Streamlit app file is located in the Ultralytics package directory.
|
||||
"""
|
||||
checks.check_requirements("streamlit>=1.29.0")
|
||||
LOGGER.info("💡 Loading Ultralytics Live Inference app...")
|
||||
|
|
@ -573,20 +589,32 @@ def handle_streamlit_inference():
|
|||
|
||||
def parse_key_value_pair(pair):
|
||||
"""
|
||||
Parse a 'key=value' pair and return the key and value.
|
||||
Parses a key-value pair string into separate key and value components.
|
||||
|
||||
Args:
|
||||
pair (str): The 'key=value' string to be parsed.
|
||||
pair (str): A string containing a key-value pair in the format "key=value".
|
||||
|
||||
Returns:
|
||||
(tuple[str, str]): A tuple containing the key and value as separate strings.
|
||||
(tuple): A tuple containing two elements:
|
||||
- key (str): The parsed key.
|
||||
- value (str): The parsed value.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the value is missing or empty.
|
||||
|
||||
Examples:
|
||||
>>> key, value = parse_key_value_pair("model=yolov8n.pt")
|
||||
>>> key
|
||||
'model'
|
||||
>>> value
|
||||
'yolov8n.pt
|
||||
>>> print(f"Key: {key}, Value: {value}")
|
||||
Key: model, Value: yolov8n.pt
|
||||
|
||||
>>> key, value = parse_key_value_pair("epochs=100")
|
||||
>>> print(f"Key: {key}, Value: {value}")
|
||||
Key: epochs, Value: 100
|
||||
|
||||
Notes:
|
||||
- The function splits the input string on the first '=' character.
|
||||
- Leading and trailing whitespace is removed from both key and value.
|
||||
- An assertion error is raised if the value is empty after stripping.
|
||||
"""
|
||||
k, v = pair.split("=", 1) # split on first '=' sign
|
||||
k, v = k.strip(), v.strip() # remove spaces
|
||||
|
|
@ -596,17 +624,19 @@ def parse_key_value_pair(pair):
|
|||
|
||||
def smart_value(v):
|
||||
"""
|
||||
Convert a string representation of a value into its appropriate Python type (int, float, bool, None, etc.).
|
||||
Converts a string representation of a value to its appropriate Python type.
|
||||
|
||||
This function attempts to convert a given string into a Python object of the most appropriate type. It handles
|
||||
conversions to None, bool, int, float, and other types that can be evaluated safely.
|
||||
|
||||
Args:
|
||||
v (str): String representation of the value to be converted.
|
||||
v (str): The string representation of the value to be converted.
|
||||
|
||||
Returns:
|
||||
(Any): The converted value, which can be of type int, float, bool, None, or the original string if no conversion
|
||||
(Any): The converted value. The type can be None, bool, int, float, or the original string if no conversion
|
||||
is applicable.
|
||||
|
||||
Examples:
|
||||
Convert a string to various types:
|
||||
>>> smart_value("42")
|
||||
42
|
||||
>>> smart_value("3.14")
|
||||
|
|
@ -617,6 +647,11 @@ def smart_value(v):
|
|||
None
|
||||
>>> smart_value("some_string")
|
||||
'some_string'
|
||||
|
||||
Notes:
|
||||
- The function uses a case-insensitive comparison for boolean and None values.
|
||||
- For other types, it attempts to use Python's eval() function, which can be unsafe if used with untrusted input.
|
||||
- If no conversion is possible, the original string is returned.
|
||||
"""
|
||||
v_lower = v.lower()
|
||||
if v_lower == "none":
|
||||
|
|
@ -639,7 +674,7 @@ def entrypoint(debug=""):
|
|||
executing the corresponding tasks such as training, validation, prediction, exporting models, and more.
|
||||
|
||||
Args:
|
||||
debug (str, optional): Space-separated string of command-line arguments for debugging purposes.
|
||||
debug (str): Space-separated string of command-line arguments for debugging purposes.
|
||||
|
||||
Examples:
|
||||
Train a detection model for 10 epochs with an initial learning_rate of 0.01:
|
||||
|
|
@ -652,9 +687,9 @@ def entrypoint(debug=""):
|
|||
>>> entrypoint("val model=yolov8n.pt data=coco8.yaml batch=1 imgsz=640")
|
||||
|
||||
Notes:
|
||||
- For a list of all available commands and their arguments, see the provided help messages and the Ultralytics
|
||||
documentation at https://docs.ultralytics.com.
|
||||
- If no arguments are passed, the function will display the usage help message.
|
||||
- For a list of all available commands and their arguments, see the provided help messages and the
|
||||
Ultralytics documentation at https://docs.ultralytics.com.
|
||||
"""
|
||||
args = (debug.split(" ") if debug else ARGV)[1:]
|
||||
if not args: # no arguments passed
|
||||
|
|
@ -793,16 +828,24 @@ def entrypoint(debug=""):
|
|||
# Special modes --------------------------------------------------------------------------------------------------------
|
||||
def copy_default_cfg():
|
||||
"""
|
||||
Copy and create a new default configuration file with '_copy' appended to its name, providing a usage example.
|
||||
Copies the default configuration file and creates a new one with '_copy' appended to its name.
|
||||
|
||||
This function duplicates the existing default configuration file and appends '_copy' to its name in the current
|
||||
working directory.
|
||||
This function duplicates the existing default configuration file (DEFAULT_CFG_PATH) and saves it
|
||||
with '_copy' appended to its name in the current working directory. It provides a convenient way
|
||||
to create a custom configuration file based on the default settings.
|
||||
|
||||
Examples:
|
||||
Copy the default configuration file and use it in a YOLO command:
|
||||
>>> copy_default_cfg()
|
||||
>>> # Example YOLO command with this new custom cfg:
|
||||
>>> # yolo cfg='default_copy.yaml' imgsz=320 batch=8
|
||||
# Output: default.yaml copied to /path/to/current/directory/default_copy.yaml
|
||||
# Example YOLO command with this new custom cfg:
|
||||
# yolo cfg='/path/to/current/directory/default_copy.yaml' imgsz=320 batch=8
|
||||
|
||||
Notes:
|
||||
- The new configuration file is created in the current working directory.
|
||||
- After copying, the function prints a message with the new file's location and an example
|
||||
YOLO command demonstrating how to use the new configuration file.
|
||||
- This function is useful for users who want to modify the default configuration without
|
||||
altering the original file.
|
||||
"""
|
||||
new_file = Path.cwd() / DEFAULT_CFG_PATH.name.replace(".yaml", "_copy.yaml")
|
||||
shutil.copy2(DEFAULT_CFG_PATH, new_file)
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -30,26 +30,18 @@ class Model(nn.Module):
|
|||
|
||||
This class provides a common interface for various operations related to YOLO models, such as training,
|
||||
validation, prediction, exporting, and benchmarking. It handles different types of models, including those
|
||||
loaded from local files, Ultralytics HUB, or Triton Server. The class is designed to be flexible and
|
||||
extendable for different tasks and model configurations.
|
||||
|
||||
Args:
|
||||
model (Union[str, Path], optional): Path or name of the model to load or create. This can be a local file
|
||||
path, a model name from Ultralytics HUB, or a Triton Server model. Defaults to 'yolov8n.pt'.
|
||||
task (Any, optional): The task type associated with the YOLO model. This can be used to specify the model's
|
||||
application domain, such as object detection, segmentation, etc. Defaults to None.
|
||||
verbose (bool, optional): If True, enables verbose output during the model's operations. Defaults to False.
|
||||
loaded from local files, Ultralytics HUB, or Triton Server.
|
||||
|
||||
Attributes:
|
||||
callbacks (dict): A dictionary of callback functions for various events during model operations.
|
||||
callbacks (Dict): A dictionary of callback functions for various events during model operations.
|
||||
predictor (BasePredictor): The predictor object used for making predictions.
|
||||
model (nn.Module): The underlying PyTorch model.
|
||||
trainer (BaseTrainer): The trainer object used for training the model.
|
||||
ckpt (dict): The checkpoint data if the model is loaded from a *.pt file.
|
||||
ckpt (Dict): The checkpoint data if the model is loaded from a *.pt file.
|
||||
cfg (str): The configuration of the model if loaded from a *.yaml file.
|
||||
ckpt_path (str): The path to the checkpoint file.
|
||||
overrides (dict): A dictionary of overrides for model configuration.
|
||||
metrics (dict): The latest training/validation metrics.
|
||||
overrides (Dict): A dictionary of overrides for model configuration.
|
||||
metrics (Dict): The latest training/validation metrics.
|
||||
session (HUBTrainingSession): The Ultralytics HUB session, if applicable.
|
||||
task (str): The type of task the model is intended for.
|
||||
model_name (str): The name of the model.
|
||||
|
|
@ -75,19 +67,14 @@ class Model(nn.Module):
|
|||
add_callback: Adds a callback function for an event.
|
||||
clear_callback: Clears all callbacks for an event.
|
||||
reset_callbacks: Resets all callbacks to their default functions.
|
||||
is_triton_model: Checks if a model is a Triton Server model.
|
||||
is_hub_model: Checks if a model is an Ultralytics HUB model.
|
||||
_reset_ckpt_args: Resets checkpoint arguments when loading a PyTorch model.
|
||||
_smart_load: Loads the appropriate module based on the model task.
|
||||
task_map: Provides a mapping from model tasks to corresponding classes.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the specified model file does not exist or is inaccessible.
|
||||
ValueError: If the model file or configuration is invalid or unsupported.
|
||||
ImportError: If required dependencies for specific model types (like HUB SDK) are not installed.
|
||||
TypeError: If the model is not a PyTorch model when required.
|
||||
AttributeError: If required attributes or methods are not implemented or available.
|
||||
NotImplementedError: If a specific model task or mode is not supported.
|
||||
Examples:
|
||||
>>> from ultralytics import YOLO
|
||||
>>> model = YOLO('yolov8n.pt')
|
||||
>>> results = model.predict('image.jpg')
|
||||
>>> model.train(data='coco128.yaml', epochs=3)
|
||||
>>> metrics = model.val()
|
||||
>>> model.export(format='onnx')
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -99,22 +86,27 @@ class Model(nn.Module):
|
|||
"""
|
||||
Initializes a new instance of the YOLO model class.
|
||||
|
||||
This constructor sets up the model based on the provided model path or name. It handles various types of model
|
||||
sources, including local files, Ultralytics HUB models, and Triton Server models. The method initializes several
|
||||
important attributes of the model and prepares it for operations like training, prediction, or export.
|
||||
This constructor sets up the model based on the provided model path or name. It handles various types of
|
||||
model sources, including local files, Ultralytics HUB models, and Triton Server models. The method
|
||||
initializes several important attributes of the model and prepares it for operations like training,
|
||||
prediction, or export.
|
||||
|
||||
Args:
|
||||
model (Union[str, Path], optional): The path or model file to load or create. This can be a local
|
||||
file path, a model name from Ultralytics HUB, or a Triton Server model. Defaults to 'yolov8n.pt'.
|
||||
task (Any, optional): The task type associated with the YOLO model, specifying its application domain.
|
||||
Defaults to None.
|
||||
verbose (bool, optional): If True, enables verbose output during the model's initialization and subsequent
|
||||
operations. Defaults to False.
|
||||
model (Union[str, Path]): Path or name of the model to load or create. Can be a local file path, a
|
||||
model name from Ultralytics HUB, or a Triton Server model.
|
||||
task (str | None): The task type associated with the YOLO model, specifying its application domain.
|
||||
verbose (bool): If True, enables verbose output during the model's initialization and subsequent
|
||||
operations.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the specified model file does not exist or is inaccessible.
|
||||
ValueError: If the model file or configuration is invalid or unsupported.
|
||||
ImportError: If required dependencies for specific model types (like HUB SDK) are not installed.
|
||||
|
||||
Examples:
|
||||
>>> model = Model("yolov8n.pt")
|
||||
>>> model = Model("path/to/model.yaml", task="detect")
|
||||
>>> model = Model("hub_model", verbose=True)
|
||||
"""
|
||||
super().__init__()
|
||||
self.callbacks = callbacks.get_default_callbacks()
|
||||
|
|
@ -155,27 +147,50 @@ class Model(nn.Module):
|
|||
**kwargs,
|
||||
) -> list:
|
||||
"""
|
||||
An alias for the predict method, enabling the model instance to be callable.
|
||||
Alias for the predict method, enabling the model instance to be callable for predictions.
|
||||
|
||||
This method simplifies the process of making predictions by allowing the model instance to be called directly
|
||||
with the required arguments for prediction.
|
||||
This method simplifies the process of making predictions by allowing the model instance to be called
|
||||
directly with the required arguments.
|
||||
|
||||
Args:
|
||||
source (str | Path | int | PIL.Image | np.ndarray, optional): The source of the image for making
|
||||
predictions. Accepts various types, including file paths, URLs, PIL images, and numpy arrays.
|
||||
Defaults to None.
|
||||
stream (bool, optional): If True, treats the input source as a continuous stream for predictions.
|
||||
Defaults to False.
|
||||
**kwargs (any): Additional keyword arguments for configuring the prediction process.
|
||||
source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | List | Tuple): The source of
|
||||
the image(s) to make predictions on. Can be a file path, URL, PIL image, numpy array, PyTorch
|
||||
tensor, or a list/tuple of these.
|
||||
stream (bool): If True, treat the input source as a continuous stream for predictions.
|
||||
**kwargs (Any): Additional keyword arguments to configure the prediction process.
|
||||
|
||||
Returns:
|
||||
(List[ultralytics.engine.results.Results]): A list of prediction results, encapsulated in the Results class.
|
||||
(List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a
|
||||
Results object.
|
||||
|
||||
Examples:
|
||||
>>> model = YOLO('yolov8n.pt')
|
||||
>>> results = model('https://ultralytics.com/images/bus.jpg')
|
||||
>>> for r in results:
|
||||
... print(f"Detected {len(r)} objects in image")
|
||||
"""
|
||||
return self.predict(source, stream, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def is_triton_model(model: str) -> bool:
|
||||
"""Is model a Triton Server URL string, i.e. <scheme>://<netloc>/<endpoint>/<task_name>"""
|
||||
"""
|
||||
Checks if the given model string is a Triton Server URL.
|
||||
|
||||
This static method determines whether the provided model string represents a valid Triton Server URL by
|
||||
parsing its components using urllib.parse.urlsplit().
|
||||
|
||||
Args:
|
||||
model (str): The model string to be checked.
|
||||
|
||||
Returns:
|
||||
(bool): True if the model string is a valid Triton Server URL, False otherwise.
|
||||
|
||||
Examples:
|
||||
>>> Model.is_triton_model('http://localhost:8000/v2/models/yolov8n')
|
||||
True
|
||||
>>> Model.is_triton_model('yolov8n.pt')
|
||||
False
|
||||
"""
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
url = urlsplit(model)
|
||||
|
|
@ -183,7 +198,30 @@ class Model(nn.Module):
|
|||
|
||||
@staticmethod
|
||||
def is_hub_model(model: str) -> bool:
|
||||
"""Check if the provided model is a HUB model."""
|
||||
"""
|
||||
Check if the provided model is an Ultralytics HUB model.
|
||||
|
||||
This static method determines whether the given model string represents a valid Ultralytics HUB model
|
||||
identifier. It checks for three possible formats: a full HUB URL, an API key and model ID combination,
|
||||
or a standalone model ID.
|
||||
|
||||
Args:
|
||||
model (str): The model identifier to check. This can be a URL, an API key and model ID
|
||||
combination, or a standalone model ID.
|
||||
|
||||
Returns:
|
||||
(bool): True if the model is a valid Ultralytics HUB model, False otherwise.
|
||||
|
||||
Examples:
|
||||
>>> Model.is_hub_model("https://hub.ultralytics.com/models/example_model")
|
||||
True
|
||||
>>> Model.is_hub_model("api_key_example_model_id")
|
||||
True
|
||||
>>> Model.is_hub_model("example_model_id")
|
||||
True
|
||||
>>> Model.is_hub_model("not_a_hub_model.pt")
|
||||
False
|
||||
"""
|
||||
return any(
|
||||
(
|
||||
model.startswith(f"{HUB_WEB_ROOT}/models/"), # i.e. https://hub.ultralytics.com/models/MODEL_ID
|
||||
|
|
@ -196,11 +234,24 @@ class Model(nn.Module):
|
|||
"""
|
||||
Initializes a new model and infers the task type from the model definitions.
|
||||
|
||||
This method creates a new model instance based on the provided configuration file. It loads the model
|
||||
configuration, infers the task type if not specified, and initializes the model using the appropriate
|
||||
class from the task map.
|
||||
|
||||
Args:
|
||||
cfg (str): model configuration file
|
||||
task (str | None): model task
|
||||
model (BaseModel): Customized model.
|
||||
verbose (bool): display model info on load
|
||||
cfg (str): Path to the model configuration file in YAML format.
|
||||
task (str | None): The specific task for the model. If None, it will be inferred from the config.
|
||||
model (torch.nn.Module | None): A custom model instance. If provided, it will be used instead of creating
|
||||
a new one.
|
||||
verbose (bool): If True, displays model information during loading.
|
||||
|
||||
Raises:
|
||||
ValueError: If the configuration file is invalid or the task cannot be inferred.
|
||||
ImportError: If the required dependencies for the specified task are not installed.
|
||||
|
||||
Examples:
|
||||
>>> model = Model()
|
||||
>>> model._new('yolov8n.yaml', task='detect', verbose=True)
|
||||
"""
|
||||
cfg_dict = yaml_model_load(cfg)
|
||||
self.cfg = cfg
|
||||
|
|
@ -216,11 +267,23 @@ class Model(nn.Module):
|
|||
|
||||
def _load(self, weights: str, task=None) -> None:
|
||||
"""
|
||||
Initializes a new model and infers the task type from the model head.
|
||||
Loads a model from a checkpoint file or initializes it from a weights file.
|
||||
|
||||
This method handles loading models from either .pt checkpoint files or other weight file formats. It sets
|
||||
up the model, task, and related attributes based on the loaded weights.
|
||||
|
||||
Args:
|
||||
weights (str): model checkpoint to be loaded
|
||||
task (str | None): model task
|
||||
weights (str): Path to the model weights file to be loaded.
|
||||
task (str | None): The task associated with the model. If None, it will be inferred from the model.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the specified weights file does not exist or is inaccessible.
|
||||
ValueError: If the weights file format is unsupported or invalid.
|
||||
|
||||
Examples:
|
||||
>>> model = Model()
|
||||
>>> model._load('yolov8n.pt')
|
||||
>>> model._load('path/to/weights.pth', task='detect')
|
||||
"""
|
||||
if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):
|
||||
weights = checks.check_file(weights) # automatically download and return local filename
|
||||
|
|
@ -241,7 +304,22 @@ class Model(nn.Module):
|
|||
self.model_name = weights
|
||||
|
||||
def _check_is_pytorch_model(self) -> None:
|
||||
"""Raises TypeError is model is not a PyTorch model."""
|
||||
"""
|
||||
Checks if the model is a PyTorch model and raises a TypeError if it's not.
|
||||
|
||||
This method verifies that the model is either a PyTorch module or a .pt file. It's used to ensure that
|
||||
certain operations that require a PyTorch model are only performed on compatible model types.
|
||||
|
||||
Raises:
|
||||
TypeError: If the model is not a PyTorch module or a .pt file. The error message provides detailed
|
||||
information about supported model formats and operations.
|
||||
|
||||
Examples:
|
||||
>>> model = Model("yolov8n.pt")
|
||||
>>> model._check_is_pytorch_model() # No error raised
|
||||
>>> model = Model("yolov8n.onnx")
|
||||
>>> model._check_is_pytorch_model() # Raises TypeError
|
||||
"""
|
||||
pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt"
|
||||
pt_module = isinstance(self.model, nn.Module)
|
||||
if not (pt_module or pt_str):
|
||||
|
|
@ -255,17 +333,21 @@ class Model(nn.Module):
|
|||
|
||||
def reset_weights(self) -> "Model":
|
||||
"""
|
||||
Resets the model parameters to randomly initialized values, effectively discarding all training information.
|
||||
Resets the model's weights to their initial state.
|
||||
|
||||
This method iterates through all modules in the model and resets their parameters if they have a
|
||||
'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True, enabling them
|
||||
to be updated during training.
|
||||
'reset_parameters' method. It also ensures that all parameters have 'requires_grad' set to True,
|
||||
enabling them to be updated during training.
|
||||
|
||||
Returns:
|
||||
self (ultralytics.engine.model.Model): The instance of the class with reset weights.
|
||||
(Model): The instance of the class with reset weights.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the model is not a PyTorch model.
|
||||
|
||||
Examples:
|
||||
>>> model = Model('yolov8n.pt')
|
||||
>>> model.reset_weights()
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
for m in self.model.modules():
|
||||
|
|
@ -283,13 +365,18 @@ class Model(nn.Module):
|
|||
name and shape and transfers them to the model.
|
||||
|
||||
Args:
|
||||
weights (str | Path): Path to the weights file or a weights object. Defaults to 'yolov8n.pt'.
|
||||
weights (Union[str, Path]): Path to the weights file or a weights object.
|
||||
|
||||
Returns:
|
||||
self (ultralytics.engine.model.Model): The instance of the class with loaded weights.
|
||||
(Model): The instance of the class with loaded weights.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the model is not a PyTorch model.
|
||||
|
||||
Examples:
|
||||
>>> model = Model()
|
||||
>>> model.load('yolov8n.pt')
|
||||
>>> model.load(Path('path/to/weights.pt'))
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
if isinstance(weights, (str, Path)):
|
||||
|
|
@ -301,14 +388,19 @@ class Model(nn.Module):
|
|||
"""
|
||||
Saves the current model state to a file.
|
||||
|
||||
This method exports the model's checkpoint (ckpt) to the specified filename.
|
||||
This method exports the model's checkpoint (ckpt) to the specified filename. It includes metadata such as
|
||||
the date, Ultralytics version, license information, and a link to the documentation.
|
||||
|
||||
Args:
|
||||
filename (str | Path): The name of the file to save the model to. Defaults to 'saved_model.pt'.
|
||||
use_dill (bool): Whether to try using dill for serialization if available. Defaults to True.
|
||||
filename (Union[str, Path]): The name of the file to save the model to.
|
||||
use_dill (bool): Whether to try using dill for serialization if available.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the model is not a PyTorch model.
|
||||
|
||||
Examples:
|
||||
>>> model = Model('yolov8n.pt')
|
||||
>>> model.save('my_model.pt')
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
from copy import deepcopy
|
||||
|
|
@ -329,30 +421,47 @@ class Model(nn.Module):
|
|||
"""
|
||||
Logs or returns model information.
|
||||
|
||||
This method provides an overview or detailed information about the model, depending on the arguments passed.
|
||||
It can control the verbosity of the output.
|
||||
This method provides an overview or detailed information about the model, depending on the arguments
|
||||
passed. It can control the verbosity of the output and return the information as a list.
|
||||
|
||||
Args:
|
||||
detailed (bool): If True, shows detailed information about the model. Defaults to False.
|
||||
verbose (bool): If True, prints the information. If False, returns the information. Defaults to True.
|
||||
detailed (bool): If True, shows detailed information about the model layers and parameters.
|
||||
verbose (bool): If True, prints the information. If False, returns the information as a list.
|
||||
|
||||
Returns:
|
||||
(list): Various types of information about the model, depending on the 'detailed' and 'verbose' parameters.
|
||||
(List[str]): A list of strings containing various types of information about the model, including
|
||||
model summary, layer details, and parameter counts. Empty if verbose is True.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the model is not a PyTorch model.
|
||||
TypeError: If the model is not a PyTorch model.
|
||||
|
||||
Examples:
|
||||
>>> model = Model('yolov8n.pt')
|
||||
>>> model.info() # Prints model summary
|
||||
>>> info_list = model.info(detailed=True, verbose=False) # Returns detailed info as a list
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
return self.model.info(detailed=detailed, verbose=verbose)
|
||||
|
||||
def fuse(self):
|
||||
"""
|
||||
Fuses Conv2d and BatchNorm2d layers in the model.
|
||||
Fuses Conv2d and BatchNorm2d layers in the model for optimized inference.
|
||||
|
||||
This method optimizes the model by fusing Conv2d and BatchNorm2d layers, which can improve inference speed.
|
||||
This method iterates through the model's modules and fuses consecutive Conv2d and BatchNorm2d layers
|
||||
into a single layer. This fusion can significantly improve inference speed by reducing the number of
|
||||
operations and memory accesses required during forward passes.
|
||||
|
||||
The fusion process typically involves folding the BatchNorm2d parameters (mean, variance, weight, and
|
||||
bias) into the preceding Conv2d layer's weights and biases. This results in a single Conv2d layer that
|
||||
performs both convolution and normalization in one step.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the model is not a PyTorch model.
|
||||
TypeError: If the model is not a PyTorch nn.Module.
|
||||
|
||||
Examples:
|
||||
>>> model = Model("yolov8n.pt")
|
||||
>>> model.fuse()
|
||||
>>> # Model is now fused and ready for optimized inference
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
self.model.fuse()
|
||||
|
|
@ -366,20 +475,26 @@ class Model(nn.Module):
|
|||
"""
|
||||
Generates image embeddings based on the provided source.
|
||||
|
||||
This method is a wrapper around the 'predict()' method, focusing on generating embeddings from an image source.
|
||||
It allows customization of the embedding process through various keyword arguments.
|
||||
This method is a wrapper around the 'predict()' method, focusing on generating embeddings from an image
|
||||
source. It allows customization of the embedding process through various keyword arguments.
|
||||
|
||||
Args:
|
||||
source (str | int | PIL.Image | np.ndarray): The source of the image for generating embeddings.
|
||||
The source can be a file path, URL, PIL image, numpy array, etc. Defaults to None.
|
||||
stream (bool): If True, predictions are streamed. Defaults to False.
|
||||
**kwargs (any): Additional keyword arguments for configuring the embedding process.
|
||||
source (str | Path | int | List | Tuple | np.ndarray | torch.Tensor): The source of the image for
|
||||
generating embeddings. Can be a file path, URL, PIL image, numpy array, etc.
|
||||
stream (bool): If True, predictions are streamed.
|
||||
**kwargs (Any): Additional keyword arguments for configuring the embedding process.
|
||||
|
||||
Returns:
|
||||
(List[torch.Tensor]): A list containing the image embeddings.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the model is not a PyTorch model.
|
||||
|
||||
Examples:
|
||||
>>> model = YOLO('yolov8n.pt')
|
||||
>>> image = 'https://ultralytics.com/images/bus.jpg'
|
||||
>>> embeddings = model.embed(image)
|
||||
>>> print(embeddings[0].shape)
|
||||
"""
|
||||
if not kwargs.get("embed"):
|
||||
kwargs["embed"] = [len(self.model.model) - 2] # embed second-to-last layer if no indices passed
|
||||
|
|
@ -397,28 +512,31 @@ class Model(nn.Module):
|
|||
|
||||
This method facilitates the prediction process, allowing various configurations through keyword arguments.
|
||||
It supports predictions with custom predictors or the default predictor method. The method handles different
|
||||
types of image sources and can operate in a streaming mode. It also provides support for SAM-type models
|
||||
through 'prompts'.
|
||||
|
||||
The method sets up a new predictor if not already present and updates its arguments with each call.
|
||||
It also issues a warning and uses default assets if the 'source' is not provided. The method determines if it
|
||||
is being called from the command line interface and adjusts its behavior accordingly, including setting defaults
|
||||
for confidence threshold and saving behavior.
|
||||
types of image sources and can operate in a streaming mode.
|
||||
|
||||
Args:
|
||||
source (str | int | PIL.Image | np.ndarray, optional): The source of the image for making predictions.
|
||||
Accepts various types, including file paths, URLs, PIL images, and numpy arrays. Defaults to ASSETS.
|
||||
stream (bool, optional): Treats the input source as a continuous stream for predictions. Defaults to False.
|
||||
predictor (BasePredictor, optional): An instance of a custom predictor class for making predictions.
|
||||
If None, the method uses a default predictor. Defaults to None.
|
||||
**kwargs (any): Additional keyword arguments for configuring the prediction process. These arguments allow
|
||||
for further customization of the prediction behavior.
|
||||
source (str | Path | int | List[str] | List[Path] | List[int] | np.ndarray | torch.Tensor): The source
|
||||
of the image(s) to make predictions on. Accepts various types including file paths, URLs, PIL
|
||||
images, numpy arrays, and torch tensors.
|
||||
stream (bool): If True, treats the input source as a continuous stream for predictions.
|
||||
predictor (BasePredictor | None): An instance of a custom predictor class for making predictions.
|
||||
If None, the method uses a default predictor.
|
||||
**kwargs (Any): Additional keyword arguments for configuring the prediction process.
|
||||
|
||||
Returns:
|
||||
(List[ultralytics.engine.results.Results]): A list of prediction results, encapsulated in the Results class.
|
||||
(List[ultralytics.engine.results.Results]): A list of prediction results, each encapsulated in a
|
||||
Results object.
|
||||
|
||||
Raises:
|
||||
AttributeError: If the predictor is not properly set up.
|
||||
Examples:
|
||||
>>> model = YOLO('yolov8n.pt')
|
||||
>>> results = model.predict(source='path/to/image.jpg', conf=0.25)
|
||||
>>> for r in results:
|
||||
... print(r.boxes.data) # print detection bounding boxes
|
||||
|
||||
Notes:
|
||||
- If 'source' is not provided, it defaults to the ASSETS constant with a warning.
|
||||
- The method sets up a new predictor if not already present and updates its arguments with each call.
|
||||
- For SAM-type models, 'prompts' can be passed as a keyword argument.
|
||||
"""
|
||||
if source is None:
|
||||
source = ASSETS
|
||||
|
|
@ -453,26 +571,33 @@ class Model(nn.Module):
|
|||
"""
|
||||
Conducts object tracking on the specified input source using the registered trackers.
|
||||
|
||||
This method performs object tracking using the model's predictors and optionally registered trackers. It is
|
||||
capable of handling different types of input sources such as file paths or video streams. The method supports
|
||||
customization of the tracking process through various keyword arguments. It registers trackers if they are not
|
||||
already present and optionally persists them based on the 'persist' flag.
|
||||
|
||||
The method sets a default confidence threshold specifically for ByteTrack-based tracking, which requires low
|
||||
confidence predictions as input. The tracking mode is explicitly set in the keyword arguments.
|
||||
This method performs object tracking using the model's predictors and optionally registered trackers. It handles
|
||||
various input sources such as file paths or video streams, and supports customization through keyword arguments.
|
||||
The method registers trackers if not already present and can persist them between calls.
|
||||
|
||||
Args:
|
||||
source (str, optional): The input source for object tracking. It can be a file path, URL, or video stream.
|
||||
stream (bool, optional): Treats the input source as a continuous video stream. Defaults to False.
|
||||
persist (bool, optional): Persists the trackers between different calls to this method. Defaults to False.
|
||||
**kwargs (any): Additional keyword arguments for configuring the tracking process. These arguments allow
|
||||
for further customization of the tracking behavior.
|
||||
source (Union[str, Path, int, List, Tuple, np.ndarray, torch.Tensor], optional): Input source for object
|
||||
tracking. Can be a file path, URL, or video stream.
|
||||
stream (bool): If True, treats the input source as a continuous video stream. Defaults to False.
|
||||
persist (bool): If True, persists trackers between different calls to this method. Defaults to False.
|
||||
**kwargs (Any): Additional keyword arguments for configuring the tracking process.
|
||||
|
||||
Returns:
|
||||
(List[ultralytics.engine.results.Results]): A list of tracking results, encapsulated in the Results class.
|
||||
(List[ultralytics.engine.results.Results]): A list of tracking results, each encapsulated in a Results object.
|
||||
|
||||
Raises:
|
||||
AttributeError: If the predictor does not have registered trackers.
|
||||
|
||||
Examples:
|
||||
>>> model = YOLO('yolov8n.pt')
|
||||
>>> results = model.track(source='path/to/video.mp4', show=True)
|
||||
>>> for r in results:
|
||||
... print(r.boxes.id) # print tracking IDs
|
||||
|
||||
Notes:
|
||||
- This method sets a default confidence threshold of 0.1 for ByteTrack-based tracking.
|
||||
- The tracking mode is explicitly set in the keyword arguments.
|
||||
- Batch size is set to 1 for tracking in videos.
|
||||
"""
|
||||
if not hasattr(self.predictor, "trackers"):
|
||||
from ultralytics.trackers import register_tracker
|
||||
|
|
@ -491,26 +616,25 @@ class Model(nn.Module):
|
|||
"""
|
||||
Validates the model using a specified dataset and validation configuration.
|
||||
|
||||
This method facilitates the model validation process, allowing for a range of customization through various
|
||||
settings and configurations. It supports validation with a custom validator or the default validation approach.
|
||||
The method combines default configurations, method-specific defaults, and user-provided arguments to configure
|
||||
the validation process. After validation, it updates the model's metrics with the results obtained from the
|
||||
validator.
|
||||
|
||||
The method supports various arguments that allow customization of the validation process. For a comprehensive
|
||||
list of all configurable options, users should refer to the 'configuration' section in the documentation.
|
||||
This method facilitates the model validation process, allowing for customization through various settings. It
|
||||
supports validation with a custom validator or the default validation approach. The method combines default
|
||||
configurations, method-specific defaults, and user-provided arguments to configure the validation process.
|
||||
|
||||
Args:
|
||||
validator (BaseValidator, optional): An instance of a custom validator class for validating the model. If
|
||||
None, the method uses a default validator. Defaults to None.
|
||||
**kwargs (any): Arbitrary keyword arguments representing the validation configuration. These arguments are
|
||||
used to customize various aspects of the validation process.
|
||||
validator (ultralytics.engine.validator.BaseValidator | None): An instance of a custom validator class for
|
||||
validating the model.
|
||||
**kwargs (Any): Arbitrary keyword arguments for customizing the validation process.
|
||||
|
||||
Returns:
|
||||
(ultralytics.utils.metrics.DetMetrics): Validation metrics obtained from the validation process.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the model is not a PyTorch model.
|
||||
|
||||
Examples:
|
||||
>>> model = YOLO('yolov8n.pt')
|
||||
>>> results = model.val(data='coco128.yaml', imgsz=640)
|
||||
>>> print(results.box.map) # Print mAP50-95
|
||||
"""
|
||||
custom = {"rect": True} # method defaults
|
||||
args = {**self.overrides, **custom, **kwargs, "mode": "val"} # highest priority args on the right
|
||||
|
|
@ -528,23 +652,31 @@ class Model(nn.Module):
|
|||
Benchmarks the model across various export formats to evaluate performance.
|
||||
|
||||
This method assesses the model's performance in different export formats, such as ONNX, TorchScript, etc.
|
||||
It uses the 'benchmark' function from the ultralytics.utils.benchmarks module. The benchmarking is configured
|
||||
using a combination of default configuration values, model-specific arguments, method-specific defaults, and
|
||||
any additional user-provided keyword arguments.
|
||||
|
||||
The method supports various arguments that allow customization of the benchmarking process, such as dataset
|
||||
choice, image size, precision modes, device selection, and verbosity. For a comprehensive list of all
|
||||
configurable options, users should refer to the 'configuration' section in the documentation.
|
||||
It uses the 'benchmark' function from the ultralytics.utils.benchmarks module. The benchmarking is
|
||||
configured using a combination of default configuration values, model-specific arguments, method-specific
|
||||
defaults, and any additional user-provided keyword arguments.
|
||||
|
||||
Args:
|
||||
**kwargs (any): Arbitrary keyword arguments to customize the benchmarking process. These are combined with
|
||||
default configurations, model-specific arguments, and method defaults.
|
||||
**kwargs (Any): Arbitrary keyword arguments to customize the benchmarking process. These are combined with
|
||||
default configurations, model-specific arguments, and method defaults. Common options include:
|
||||
- data (str): Path to the dataset for benchmarking.
|
||||
- imgsz (int | List[int]): Image size for benchmarking.
|
||||
- half (bool): Whether to use half-precision (FP16) mode.
|
||||
- int8 (bool): Whether to use int8 precision mode.
|
||||
- device (str): Device to run the benchmark on (e.g., 'cpu', 'cuda').
|
||||
- verbose (bool): Whether to print detailed benchmark information.
|
||||
|
||||
Returns:
|
||||
(dict): A dictionary containing the results of the benchmarking process.
|
||||
(Dict): A dictionary containing the results of the benchmarking process, including metrics for
|
||||
different export formats.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the model is not a PyTorch model.
|
||||
|
||||
Examples:
|
||||
>>> model = YOLO('yolov8n.pt')
|
||||
>>> results = model.benchmark(data='coco8.yaml', imgsz=640, half=True)
|
||||
>>> print(results)
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
from ultralytics.utils.benchmarks import benchmark
|
||||
|
|
@ -570,20 +702,31 @@ class Model(nn.Module):
|
|||
|
||||
This method facilitates the export of the model to various formats (e.g., ONNX, TorchScript) for deployment
|
||||
purposes. It uses the 'Exporter' class for the export process, combining model-specific overrides, method
|
||||
defaults, and any additional arguments provided. The combined arguments are used to configure export settings.
|
||||
|
||||
The method supports a wide range of arguments to customize the export process. For a comprehensive list of all
|
||||
possible arguments, refer to the 'configuration' section in the documentation.
|
||||
defaults, and any additional arguments provided.
|
||||
|
||||
Args:
|
||||
**kwargs (any): Arbitrary keyword arguments to customize the export process. These are combined with the
|
||||
model's overrides and method defaults.
|
||||
**kwargs (Dict): Arbitrary keyword arguments to customize the export process. These are combined with
|
||||
the model's overrides and method defaults. Common arguments include:
|
||||
format (str): Export format (e.g., 'onnx', 'engine', 'coreml').
|
||||
half (bool): Export model in half-precision.
|
||||
int8 (bool): Export model in int8 precision.
|
||||
device (str): Device to run the export on.
|
||||
workspace (int): Maximum memory workspace size for TensorRT engines.
|
||||
nms (bool): Add Non-Maximum Suppression (NMS) module to model.
|
||||
simplify (bool): Simplify ONNX model.
|
||||
|
||||
Returns:
|
||||
(str): The exported model filename in the specified format, or an object related to the export process.
|
||||
(str): The path to the exported model file.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the model is not a PyTorch model.
|
||||
ValueError: If an unsupported export format is specified.
|
||||
RuntimeError: If the export process fails due to errors.
|
||||
|
||||
Examples:
|
||||
>>> model = YOLO('yolov8n.pt')
|
||||
>>> model.export(format='onnx', dynamic=True, simplify=True)
|
||||
'path/to/exported/model.onnx'
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
from .exporter import Exporter
|
||||
|
|
@ -606,29 +749,38 @@ class Model(nn.Module):
|
|||
"""
|
||||
Trains the model using the specified dataset and training configuration.
|
||||
|
||||
This method facilitates model training with a range of customizable settings and configurations. It supports
|
||||
training with a custom trainer or the default training approach defined in the method. The method handles
|
||||
different scenarios, such as resuming training from a checkpoint, integrating with Ultralytics HUB, and
|
||||
updating model and configuration after training.
|
||||
This method facilitates model training with a range of customizable settings. It supports training with a
|
||||
custom trainer or the default training approach. The method handles scenarios such as resuming training
|
||||
from a checkpoint, integrating with Ultralytics HUB, and updating model and configuration after training.
|
||||
|
||||
When using Ultralytics HUB, if the session already has a loaded model, the method prioritizes HUB training
|
||||
arguments and issues a warning if local arguments are provided. It checks for pip updates and combines default
|
||||
configurations, method-specific defaults, and user-provided arguments to configure the training process. After
|
||||
training, it updates the model and its configurations, and optionally attaches metrics.
|
||||
When using Ultralytics HUB, if the session has a loaded model, the method prioritizes HUB training
|
||||
arguments and warns if local arguments are provided. It checks for pip updates and combines default
|
||||
configurations, method-specific defaults, and user-provided arguments to configure the training process.
|
||||
|
||||
Args:
|
||||
trainer (BaseTrainer, optional): An instance of a custom trainer class for training the model. If None, the
|
||||
method uses a default trainer. Defaults to None.
|
||||
**kwargs (any): Arbitrary keyword arguments representing the training configuration. These arguments are
|
||||
used to customize various aspects of the training process.
|
||||
trainer (BaseTrainer | None): Custom trainer instance for model training. If None, uses default.
|
||||
**kwargs (Any): Arbitrary keyword arguments for training configuration. Common options include:
|
||||
data (str): Path to dataset configuration file.
|
||||
epochs (int): Number of training epochs.
|
||||
batch_size (int): Batch size for training.
|
||||
imgsz (int): Input image size.
|
||||
device (str): Device to run training on (e.g., 'cuda', 'cpu').
|
||||
workers (int): Number of worker threads for data loading.
|
||||
optimizer (str): Optimizer to use for training.
|
||||
lr0 (float): Initial learning rate.
|
||||
patience (int): Epochs to wait for no observable improvement for early stopping of training.
|
||||
|
||||
Returns:
|
||||
(dict | None): Training metrics if available and training is successful; otherwise, None.
|
||||
(Dict | None): Training metrics if available and training is successful; otherwise, None.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the model is not a PyTorch model.
|
||||
PermissionError: If there is a permission issue with the HUB session.
|
||||
ModuleNotFoundError: If the HUB SDK is not installed.
|
||||
|
||||
Examples:
|
||||
>>> model = YOLO('yolov8n.pt')
|
||||
>>> results = model.train(data='coco128.yaml', epochs=3)
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
if hasattr(self.session, "model") and self.session.model.id: # Ultralytics HUB session with loaded model
|
||||
|
|
@ -682,14 +834,19 @@ class Model(nn.Module):
|
|||
Args:
|
||||
use_ray (bool): If True, uses Ray Tune for hyperparameter tuning. Defaults to False.
|
||||
iterations (int): The number of tuning iterations to perform. Defaults to 10.
|
||||
*args (list): Variable length argument list for additional arguments.
|
||||
**kwargs (any): Arbitrary keyword arguments. These are combined with the model's overrides and defaults.
|
||||
*args (List): Variable length argument list for additional arguments.
|
||||
**kwargs (Dict): Arbitrary keyword arguments. These are combined with the model's overrides and defaults.
|
||||
|
||||
Returns:
|
||||
(dict): A dictionary containing the results of the hyperparameter search.
|
||||
(Dict): A dictionary containing the results of the hyperparameter search.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the model is not a PyTorch model.
|
||||
|
||||
Examples:
|
||||
>>> model = YOLO('yolov8n.pt')
|
||||
>>> results = model.tune(use_ray=True, iterations=20)
|
||||
>>> print(results)
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
if use_ray:
|
||||
|
|
@ -704,7 +861,27 @@ class Model(nn.Module):
|
|||
return Tuner(args=args, _callbacks=self.callbacks)(model=self, iterations=iterations)
|
||||
|
||||
def _apply(self, fn) -> "Model":
|
||||
"""Apply to(), cpu(), cuda(), half(), float() to model tensors that are not parameters or registered buffers."""
|
||||
"""
|
||||
Applies a function to model tensors that are not parameters or registered buffers.
|
||||
|
||||
This method extends the functionality of the parent class's _apply method by additionally resetting the
|
||||
predictor and updating the device in the model's overrides. It's typically used for operations like
|
||||
moving the model to a different device or changing its precision.
|
||||
|
||||
Args:
|
||||
fn (Callable): A function to be applied to the model's tensors. This is typically a method like
|
||||
to(), cpu(), cuda(), half(), or float().
|
||||
|
||||
Returns:
|
||||
(Model): The model instance with the function applied and updated attributes.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the model is not a PyTorch model.
|
||||
|
||||
Examples:
|
||||
>>> model = Model("yolov8n.pt")
|
||||
>>> model = model._apply(lambda t: t.cuda()) # Move model to GPU
|
||||
"""
|
||||
self._check_is_pytorch_model()
|
||||
self = super()._apply(fn) # noqa
|
||||
self.predictor = None # reset predictor as device may have changed
|
||||
|
|
@ -717,10 +894,19 @@ class Model(nn.Module):
|
|||
Retrieves the class names associated with the loaded model.
|
||||
|
||||
This property returns the class names if they are defined in the model. It checks the class names for validity
|
||||
using the 'check_class_names' function from the ultralytics.nn.autobackend module.
|
||||
using the 'check_class_names' function from the ultralytics.nn.autobackend module. If the predictor is not
|
||||
initialized, it sets it up before retrieving the names.
|
||||
|
||||
Returns:
|
||||
(list | None): The class names of the model if available, otherwise None.
|
||||
(List[str]): A list of class names associated with the model.
|
||||
|
||||
Raises:
|
||||
AttributeError: If the model or predictor does not have a 'names' attribute.
|
||||
|
||||
Examples:
|
||||
>>> model = YOLO('yolov8n.pt')
|
||||
>>> print(model.names)
|
||||
['person', 'bicycle', 'car', ...]
|
||||
"""
|
||||
from ultralytics.nn.autobackend import check_class_names
|
||||
|
||||
|
|
@ -736,11 +922,22 @@ class Model(nn.Module):
|
|||
"""
|
||||
Retrieves the device on which the model's parameters are allocated.
|
||||
|
||||
This property is used to determine whether the model's parameters are on CPU or GPU. It only applies to models
|
||||
that are instances of nn.Module.
|
||||
This property determines the device (CPU or GPU) where the model's parameters are currently stored. It is
|
||||
applicable only to models that are instances of nn.Module.
|
||||
|
||||
Returns:
|
||||
(torch.device | None): The device (CPU/GPU) of the model if it is a PyTorch model, otherwise None.
|
||||
(torch.device): The device (CPU/GPU) of the model.
|
||||
|
||||
Raises:
|
||||
AttributeError: If the model is not a PyTorch nn.Module instance.
|
||||
|
||||
Examples:
|
||||
>>> model = YOLO("yolov8n.pt")
|
||||
>>> print(model.device)
|
||||
device(type='cuda', index=0) # if CUDA is available
|
||||
>>> model = model.to("cpu")
|
||||
>>> print(model.device)
|
||||
device(type='cpu')
|
||||
"""
|
||||
return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None
|
||||
|
||||
|
|
@ -749,10 +946,20 @@ class Model(nn.Module):
|
|||
"""
|
||||
Retrieves the transformations applied to the input data of the loaded model.
|
||||
|
||||
This property returns the transformations if they are defined in the model.
|
||||
This property returns the transformations if they are defined in the model. The transforms
|
||||
typically include preprocessing steps like resizing, normalization, and data augmentation
|
||||
that are applied to input data before it is fed into the model.
|
||||
|
||||
Returns:
|
||||
(object | None): The transform object of the model if available, otherwise None.
|
||||
|
||||
Examples:
|
||||
>>> model = YOLO('yolov8n.pt')
|
||||
>>> transforms = model.transforms
|
||||
>>> if transforms:
|
||||
... print(f"Model transforms: {transforms}")
|
||||
... else:
|
||||
... print("No transforms defined for this model.")
|
||||
"""
|
||||
return self.model.transforms if hasattr(self.model, "transforms") else None
|
||||
|
||||
|
|
@ -760,15 +967,25 @@ class Model(nn.Module):
|
|||
"""
|
||||
Adds a callback function for a specified event.
|
||||
|
||||
This method allows the user to register a custom callback function that is triggered on a specific event during
|
||||
model training or inference.
|
||||
This method allows registering custom callback functions that are triggered on specific events during
|
||||
model operations such as training or inference. Callbacks provide a way to extend and customize the
|
||||
behavior of the model at various stages of its lifecycle.
|
||||
|
||||
Args:
|
||||
event (str): The name of the event to attach the callback to.
|
||||
func (callable): The callback function to be registered.
|
||||
event (str): The name of the event to attach the callback to. Must be a valid event name recognized
|
||||
by the Ultralytics framework.
|
||||
func (Callable): The callback function to be registered. This function will be called when the
|
||||
specified event occurs.
|
||||
|
||||
Raises:
|
||||
ValueError: If the event name is not recognized.
|
||||
ValueError: If the event name is not recognized or is invalid.
|
||||
|
||||
Examples:
|
||||
>>> def on_train_start(trainer):
|
||||
... print("Training is starting!")
|
||||
>>> model = YOLO('yolov8n.pt')
|
||||
>>> model.add_callback("on_train_start", on_train_start)
|
||||
>>> model.train(data='coco128.yaml', epochs=1)
|
||||
"""
|
||||
self.callbacks[event].append(func)
|
||||
|
||||
|
|
@ -777,12 +994,26 @@ class Model(nn.Module):
|
|||
Clears all callback functions registered for a specified event.
|
||||
|
||||
This method removes all custom and default callback functions associated with the given event.
|
||||
It resets the callback list for the specified event to an empty list, effectively removing all
|
||||
registered callbacks for that event.
|
||||
|
||||
Args:
|
||||
event (str): The name of the event for which to clear the callbacks.
|
||||
event (str): The name of the event for which to clear the callbacks. This should be a valid event name
|
||||
recognized by the Ultralytics callback system.
|
||||
|
||||
Raises:
|
||||
ValueError: If the event name is not recognized.
|
||||
Examples:
|
||||
>>> model = YOLO('yolov8n.pt')
|
||||
>>> model.add_callback('on_train_start', lambda: print('Training started'))
|
||||
>>> model.clear_callback('on_train_start')
|
||||
>>> # All callbacks for 'on_train_start' are now removed
|
||||
|
||||
Notes:
|
||||
- This method affects both custom callbacks added by the user and default callbacks
|
||||
provided by the Ultralytics framework.
|
||||
- After calling this method, no callbacks will be executed for the specified event
|
||||
until new ones are added.
|
||||
- Use with caution as it removes all callbacks, including essential ones that might
|
||||
be required for proper functioning of certain operations.
|
||||
"""
|
||||
self.callbacks[event] = []
|
||||
|
||||
|
|
@ -791,14 +1022,45 @@ class Model(nn.Module):
|
|||
Resets all callbacks to their default functions.
|
||||
|
||||
This method reinstates the default callback functions for all events, removing any custom callbacks that were
|
||||
added previously.
|
||||
previously added. It iterates through all default callback events and replaces the current callbacks with the
|
||||
default ones.
|
||||
|
||||
The default callbacks are defined in the 'callbacks.default_callbacks' dictionary, which contains predefined
|
||||
functions for various events in the model's lifecycle, such as on_train_start, on_epoch_end, etc.
|
||||
|
||||
This method is useful when you want to revert to the original set of callbacks after making custom modifications,
|
||||
ensuring consistent behavior across different runs or experiments.
|
||||
|
||||
Examples:
|
||||
>>> model = YOLO('yolov8n.pt')
|
||||
>>> model.add_callback('on_train_start', custom_function)
|
||||
>>> model.reset_callbacks()
|
||||
# All callbacks are now reset to their default functions
|
||||
"""
|
||||
for event in callbacks.default_callbacks.keys():
|
||||
self.callbacks[event] = [callbacks.default_callbacks[event][0]]
|
||||
|
||||
@staticmethod
|
||||
def _reset_ckpt_args(args: dict) -> dict:
|
||||
"""Reset arguments when loading a PyTorch model."""
|
||||
"""
|
||||
Resets specific arguments when loading a PyTorch model checkpoint.
|
||||
|
||||
This static method filters the input arguments dictionary to retain only a specific set of keys that are
|
||||
considered important for model loading. It's used to ensure that only relevant arguments are preserved
|
||||
when loading a model from a checkpoint, discarding any unnecessary or potentially conflicting settings.
|
||||
|
||||
Args:
|
||||
args (dict): A dictionary containing various model arguments and settings.
|
||||
|
||||
Returns:
|
||||
(dict): A new dictionary containing only the specified include keys from the input arguments.
|
||||
|
||||
Examples:
|
||||
>>> original_args = {'imgsz': 640, 'data': 'coco.yaml', 'task': 'detect', 'batch': 16, 'epochs': 100}
|
||||
>>> reset_args = Model._reset_ckpt_args(original_args)
|
||||
>>> print(reset_args)
|
||||
{'imgsz': 640, 'data': 'coco.yaml', 'task': 'detect'}
|
||||
"""
|
||||
include = {"imgsz", "data", "task", "single_cls"} # only remember these arguments when loading a PyTorch model
|
||||
return {k: v for k, v in args.items() if k in include}
|
||||
|
||||
|
|
@ -808,7 +1070,31 @@ class Model(nn.Module):
|
|||
# raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
|
||||
|
||||
def _smart_load(self, key: str):
|
||||
"""Load model/trainer/validator/predictor."""
|
||||
"""
|
||||
Loads the appropriate module based on the model task.
|
||||
|
||||
This method dynamically selects and returns the correct module (model, trainer, validator, or predictor)
|
||||
based on the current task of the model and the provided key. It uses the task_map attribute to determine
|
||||
the correct module to load.
|
||||
|
||||
Args:
|
||||
key (str): The type of module to load. Must be one of 'model', 'trainer', 'validator', or 'predictor'.
|
||||
|
||||
Returns:
|
||||
(object): The loaded module corresponding to the specified key and current task.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the specified key is not supported for the current task.
|
||||
|
||||
Examples:
|
||||
>>> model = Model(task='detect')
|
||||
>>> predictor = model._smart_load('predictor')
|
||||
>>> trainer = model._smart_load('trainer')
|
||||
|
||||
Notes:
|
||||
- This method is typically used internally by other methods of the Model class.
|
||||
- The task_map attribute should be properly initialized with the correct mappings for each task.
|
||||
"""
|
||||
try:
|
||||
return self.task_map[self.task][key]
|
||||
except Exception as e:
|
||||
|
|
@ -821,9 +1107,30 @@ class Model(nn.Module):
|
|||
@property
|
||||
def task_map(self) -> dict:
|
||||
"""
|
||||
Map head to model, trainer, validator, and predictor classes.
|
||||
Provides a mapping from model tasks to corresponding classes for different modes.
|
||||
|
||||
This property method returns a dictionary that maps each supported task (e.g., detect, segment, classify)
|
||||
to a nested dictionary. The nested dictionary contains mappings for different operational modes
|
||||
(model, trainer, validator, predictor) to their respective class implementations.
|
||||
|
||||
The mapping allows for dynamic loading of appropriate classes based on the model's task and the
|
||||
desired operational mode. This facilitates a flexible and extensible architecture for handling
|
||||
various tasks and modes within the Ultralytics framework.
|
||||
|
||||
Returns:
|
||||
task_map (dict): The map of model task to mode classes.
|
||||
(Dict[str, Dict[str, Any]]): A dictionary where keys are task names (str) and values are
|
||||
nested dictionaries. Each nested dictionary has keys 'model', 'trainer', 'validator', and
|
||||
'predictor', mapping to their respective class implementations.
|
||||
|
||||
Example:
|
||||
>>> model = Model()
|
||||
>>> task_map = model.task_map
|
||||
>>> detect_class_map = task_map['detect']
|
||||
>>> segment_class_map = task_map['segment']
|
||||
|
||||
Note:
|
||||
The actual implementation of this method may vary depending on the specific tasks and
|
||||
classes supported by the Ultralytics framework. The docstring provides a general
|
||||
description of the expected behavior and structure.
|
||||
"""
|
||||
raise NotImplementedError("Please provide task map for your model!")
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue