ultralytics 8.3.71 requires explicit torch.nn usage (#19067)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: RizwanMunawar <chr043416@gmail.com>
Co-authored-by: Muhammad Rizwan Munawar <muhammadrizwanmunawar123@gmail.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2025-02-05 01:08:17 +09:00 committed by GitHub
parent 17450e9646
commit 5bca9341e8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 50 additions and 51 deletions

View file

@ -48,7 +48,8 @@ jobs:
python-version: "3.x" python-version: "3.x"
- uses: astral-sh/setup-uv@v5 - uses: astral-sh/setup-uv@v5
- name: Install Dependencies - name: Install Dependencies
run: uv pip install --system ruff black tqdm mkdocs-material "mkdocstrings[python]" mkdocs-redirects mkdocs-ultralytics-plugin mkdocs-macros-plugin # Note "beautifulsoup4<=4.12.3" required due to errors with >=4.13 in https://github.com/ultralytics/ultralytics/pull/19067
run: uv pip install --system "beautifulsoup4<=4.12.3" ruff black tqdm mkdocs-material "mkdocstrings[python]" mkdocs-redirects mkdocs-ultralytics-plugin mkdocs-macros-plugin
- name: Ruff fixes - name: Ruff fixes
continue-on-error: true continue-on-error: true
run: ruff check --fix --unsafe-fixes --select D --ignore=D100,D104,D203,D205,D212,D213,D401,D406,D407,D413 . run: ruff check --fix --unsafe-fixes --select D --ignore=D100,D104,D203,D205,D212,D213,D401,D406,D407,D413 .

View file

@ -113,7 +113,7 @@ def update_subdir_edit_links(subdir="", docs_url=""):
if str(subdir[0]) == "/": if str(subdir[0]) == "/":
subdir = str(subdir[0])[1:] subdir = str(subdir[0])[1:]
html_files = (SITE / subdir).rglob("*.html") html_files = (SITE / subdir).rglob("*.html")
for html_file in tqdm(html_files, desc="Processing subdir files"): for html_file in tqdm(html_files, desc="Processing subdir files", mininterval=1.0):
with html_file.open("r", encoding="utf-8") as file: with html_file.open("r", encoding="utf-8") as file:
soup = BeautifulSoup(file, "html.parser") soup = BeautifulSoup(file, "html.parser")
@ -178,7 +178,7 @@ def update_docs_html():
# Convert plaintext links to HTML hyperlinks # Convert plaintext links to HTML hyperlinks
files_modified = 0 files_modified = 0
for html_file in tqdm(SITE.rglob("*.html"), desc="Converting plaintext links"): for html_file in tqdm(SITE.rglob("*.html"), desc="Converting plaintext links", mininterval=1.0):
with open(html_file, encoding="utf-8") as file: with open(html_file, encoding="utf-8") as file:
content = file.read() content = file.read()
updated_content = convert_plaintext_links_to_html(content) updated_content = convert_plaintext_links_to_html(content)
@ -294,7 +294,7 @@ def minify_files(html=True, css=True, js=True):
}.items(): }.items():
stats[ext] = {"original": 0, "minified": 0} stats[ext] = {"original": 0, "minified": 0}
directory = "" # "stylesheets" if ext == css else "javascript" if ext == "js" else "" directory = "" # "stylesheets" if ext == css else "javascript" if ext == "js" else ""
for f in tqdm((SITE / directory).rglob(f"*.{ext}"), desc=f"Minifying {ext.upper()}"): for f in tqdm((SITE / directory).rglob(f"*.{ext}"), desc=f"Minifying {ext.upper()}", mininterval=1.0):
content = f.read_text(encoding="utf-8") content = f.read_text(encoding="utf-8")
minified = minifier(content) if minifier else remove_comments_and_empty_lines(content, ext) minified = minifier(content) if minifier else remove_comments_and_empty_lines(content, ext)
stats[ext]["original"] += len(content) stats[ext]["original"] += len(content)

View file

@ -87,6 +87,7 @@ dev = [
"pytest-cov", "pytest-cov",
"coverage[toml]", "coverage[toml]",
"mkdocs>=1.6.0", "mkdocs>=1.6.0",
"beautifulsoup4<=4.12.3", # For docs https://github.com/ultralytics/ultralytics/pull/19067
"mkdocs-material>=9.5.9", "mkdocs-material>=9.5.9",
"mkdocstrings[python]", "mkdocstrings[python]",
"mkdocs-redirects", # 301 redirects "mkdocs-redirects", # 301 redirects

View file

@ -1,6 +1,6 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
__version__ = "8.3.70" __version__ = "8.3.71"
import os import os

View file

@ -6,7 +6,7 @@
# Parameters # Parameters
nc: 80 # number of classes nc: 80 # number of classes
activation: nn.ReLU() # (optional) model default activation function activation: torch.nn.ReLU() # (optional) model default activation function
scales: # model compound scaling constants, i.e. 'model=yolov6n.yaml' will call yolov8.yaml with scale 'n' scales: # model compound scaling constants, i.e. 'model=yolov6n.yaml' will call yolov8.yaml with scale 'n'
# [depth, width, max_channels] # [depth, width, max_channels]
n: [0.33, 0.25, 1024] n: [0.33, 0.25, 1024]

View file

@ -11,7 +11,7 @@ from PIL import Image
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
from ultralytics.engine.results import Results from ultralytics.engine.results import Results
from ultralytics.hub import HUB_WEB_ROOT, HUBTrainingSession from ultralytics.hub import HUB_WEB_ROOT, HUBTrainingSession
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, yaml_model_load
from ultralytics.utils import ( from ultralytics.utils import (
ARGV, ARGV,
ASSETS, ASSETS,
@ -26,7 +26,7 @@ from ultralytics.utils import (
) )
class Model(nn.Module): class Model(torch.nn.Module):
""" """
A base class for implementing YOLO models, unifying APIs across different model types. A base class for implementing YOLO models, unifying APIs across different model types.
@ -37,7 +37,7 @@ class Model(nn.Module):
Attributes: Attributes:
callbacks (Dict): A dictionary of callback functions for various events during model operations. callbacks (Dict): A dictionary of callback functions for various events during model operations.
predictor (BasePredictor): The predictor object used for making predictions. predictor (BasePredictor): The predictor object used for making predictions.
model (nn.Module): The underlying PyTorch model. model (torch.nn.Module): The underlying PyTorch model.
trainer (BaseTrainer): The trainer object used for training the model. trainer (BaseTrainer): The trainer object used for training the model.
ckpt (Dict): The checkpoint data if the model is loaded from a *.pt file. ckpt (Dict): The checkpoint data if the model is loaded from a *.pt file.
cfg (str): The configuration of the model if loaded from a *.yaml file. cfg (str): The configuration of the model if loaded from a *.yaml file.
@ -317,7 +317,7 @@ class Model(nn.Module):
>>> model._check_is_pytorch_model() # Raises TypeError >>> model._check_is_pytorch_model() # Raises TypeError
""" """
pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt" pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt"
pt_module = isinstance(self.model, nn.Module) pt_module = isinstance(self.model, torch.nn.Module)
if not (pt_module or pt_str): if not (pt_module or pt_str):
raise TypeError( raise TypeError(
f"model='{self.model}' should be a *.pt PyTorch model to run this method, but is a different format. " f"model='{self.model}' should be a *.pt PyTorch model to run this method, but is a different format. "
@ -405,7 +405,7 @@ class Model(nn.Module):
from ultralytics import __version__ from ultralytics import __version__
updates = { updates = {
"model": deepcopy(self.model).half() if isinstance(self.model, nn.Module) else self.model, "model": deepcopy(self.model).half() if isinstance(self.model, torch.nn.Module) else self.model,
"date": datetime.now().isoformat(), "date": datetime.now().isoformat(),
"version": __version__, "version": __version__,
"license": "AGPL-3.0 License (https://ultralytics.com/license)", "license": "AGPL-3.0 License (https://ultralytics.com/license)",
@ -452,7 +452,7 @@ class Model(nn.Module):
performs both convolution and normalization in one step. performs both convolution and normalization in one step.
Raises: Raises:
TypeError: If the model is not a PyTorch nn.Module. TypeError: If the model is not a PyTorch torch.nn.Module.
Examples: Examples:
>>> model = Model("yolo11n.pt") >>> model = Model("yolo11n.pt")
@ -921,13 +921,13 @@ class Model(nn.Module):
Retrieves the device on which the model's parameters are allocated. Retrieves the device on which the model's parameters are allocated.
This property determines the device (CPU or GPU) where the model's parameters are currently stored. It is This property determines the device (CPU or GPU) where the model's parameters are currently stored. It is
applicable only to models that are instances of nn.Module. applicable only to models that are instances of torch.nn.Module.
Returns: Returns:
(torch.device): The device (CPU/GPU) of the model. (torch.device): The device (CPU/GPU) of the model.
Raises: Raises:
AttributeError: If the model is not a PyTorch nn.Module instance. AttributeError: If the model is not a torch.nn.Module instance.
Examples: Examples:
>>> model = YOLO("yolo11n.pt") >>> model = YOLO("yolo11n.pt")
@ -937,7 +937,7 @@ class Model(nn.Module):
>>> print(model.device) >>> print(model.device)
device(type='cpu') device(type='cpu')
""" """
return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None return next(self.model.parameters()).device if isinstance(self.model, torch.nn.Module) else None
@property @property
def transforms(self): def transforms(self):

View file

@ -426,8 +426,7 @@ class SAM2Model(torch.nn.Module):
high_res_masks: Tensor of shape (B, 1, H*16, W*16) with the best high-resolution mask. high_res_masks: Tensor of shape (B, 1, H*16, W*16) with the best high-resolution mask.
obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask. obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask.
object_score_logits: Tensor of shape (B) with object score logits. object_score_logits: Tensor of shape (B) with object score logits.
Where M is 3 if multimask_output=True, and 1 if multimask_output=False.
Where M is 3 if multimask_output=True, and 1 if multimask_output=False.
Examples: Examples:
>>> backbone_features = torch.rand(1, 256, 32, 32) >>> backbone_features = torch.rand(1, 256, 32, 32)

View file

@ -158,7 +158,7 @@ class PoseValidator(DetectionValidator):
gt_kpts (torch.Tensor | None): Optional tensor with shape (N, 51) representing ground truth keypoints. gt_kpts (torch.Tensor | None): Optional tensor with shape (N, 51) representing ground truth keypoints.
Returns: Returns:
torch.Tensor: A tensor with shape (N, 10) representing the correct prediction matrix for 10 IoU levels, (torch.Tensor): A tensor with shape (N, 10) representing the correct prediction matrix for 10 IoU levels,
where N is the number of detections. where N is the number of detections.
Example: Example:

View file

@ -780,7 +780,7 @@ class AutoBackend(nn.Module):
saved_model, pb, tflite, edgetpu, tfjs, ncnn or paddle. saved_model, pb, tflite, edgetpu, tfjs, ncnn or paddle.
Args: Args:
p: path to the model file. Defaults to path/to/model.pt p (str): path to the model file. Defaults to path/to/model.pt
Examples: Examples:
>>> model = AutoBackend(weights="path/to/model.onnx") >>> model = AutoBackend(weights="path/to/model.onnx")

View file

@ -9,7 +9,6 @@ from pathlib import Path
import thop import thop
import torch import torch
import torch.nn as nn
from ultralytics.nn.modules import ( from ultralytics.nn.modules import (
AIFI, AIFI,
@ -88,7 +87,7 @@ from ultralytics.utils.torch_utils import (
) )
class BaseModel(nn.Module): class BaseModel(torch.nn.Module):
"""The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family.""" """The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family."""
def forward(self, x, *args, **kwargs): def forward(self, x, *args, **kwargs):
@ -151,7 +150,7 @@ class BaseModel(nn.Module):
if visualize: if visualize:
feature_visualization(x, m.type, m.i, save_dir=visualize) feature_visualization(x, m.type, m.i, save_dir=visualize)
if embed and m.i in embed: if embed and m.i in embed:
embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
if m.i == max(embed): if m.i == max(embed):
return torch.unbind(torch.cat(embeddings, 1), dim=0) return torch.unbind(torch.cat(embeddings, 1), dim=0)
return x return x
@ -170,12 +169,9 @@ class BaseModel(nn.Module):
the provided list. the provided list.
Args: Args:
m (nn.Module): The layer to be profiled. m (torch.nn.Module): The layer to be profiled.
x (torch.Tensor): The input data to the layer. x (torch.Tensor): The input data to the layer.
dt (list): A list to store the computation time of the layer. dt (list): A list to store the computation time of the layer.
Returns:
None
""" """
c = m == self.model[-1] and isinstance(x, list) # is final layer list, copy input as inplace fix c = m == self.model[-1] and isinstance(x, list) # is final layer list, copy input as inplace fix
flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs
@ -195,7 +191,7 @@ class BaseModel(nn.Module):
computation efficiency. computation efficiency.
Returns: Returns:
(nn.Module): The fused model is returned. (torch.nn.Module): The fused model is returned.
""" """
if not self.is_fused(): if not self.is_fused():
for m in self.model.modules(): for m in self.model.modules():
@ -229,7 +225,7 @@ class BaseModel(nn.Module):
Returns: Returns:
(bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise. (bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
""" """
bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d() bn = tuple(v for k, v in torch.nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d()
return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
def info(self, detailed=False, verbose=True, imgsz=640): def info(self, detailed=False, verbose=True, imgsz=640):
@ -304,7 +300,7 @@ class DetectionModel(BaseModel):
self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
if self.yaml["backbone"][0][2] == "Silence": if self.yaml["backbone"][0][2] == "Silence":
LOGGER.warning( LOGGER.warning(
"WARNING ⚠️ YOLOv9 `Silence` module is deprecated in favor of nn.Identity. " "WARNING ⚠️ YOLOv9 `Silence` module is deprecated in favor of torch.nn.Identity. "
"Please delete local *.pt file and re-download the latest model checkpoint." "Please delete local *.pt file and re-download the latest model checkpoint."
) )
self.yaml["backbone"][0][2] = "nn.Identity" self.yaml["backbone"][0][2] = "nn.Identity"
@ -458,20 +454,22 @@ class ClassificationModel(BaseModel):
name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1] # last module name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1] # last module
if isinstance(m, Classify): # YOLO Classify() head if isinstance(m, Classify): # YOLO Classify() head
if m.linear.out_features != nc: if m.linear.out_features != nc:
m.linear = nn.Linear(m.linear.in_features, nc) m.linear = torch.nn.Linear(m.linear.in_features, nc)
elif isinstance(m, nn.Linear): # ResNet, EfficientNet elif isinstance(m, torch.nn.Linear): # ResNet, EfficientNet
if m.out_features != nc: if m.out_features != nc:
setattr(model, name, nn.Linear(m.in_features, nc)) setattr(model, name, torch.nn.Linear(m.in_features, nc))
elif isinstance(m, nn.Sequential): elif isinstance(m, torch.nn.Sequential):
types = [type(x) for x in m] types = [type(x) for x in m]
if nn.Linear in types: if torch.nn.Linear in types:
i = len(types) - 1 - types[::-1].index(nn.Linear) # last nn.Linear index i = len(types) - 1 - types[::-1].index(torch.nn.Linear) # last torch.nn.Linear index
if m[i].out_features != nc: if m[i].out_features != nc:
m[i] = nn.Linear(m[i].in_features, nc) m[i] = torch.nn.Linear(m[i].in_features, nc)
elif nn.Conv2d in types: elif torch.nn.Conv2d in types:
i = len(types) - 1 - types[::-1].index(nn.Conv2d) # last nn.Conv2d index i = len(types) - 1 - types[::-1].index(torch.nn.Conv2d) # last torch.nn.Conv2d index
if m[i].out_channels != nc: if m[i].out_channels != nc:
m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None) m[i] = torch.nn.Conv2d(
m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None
)
def init_criterion(self): def init_criterion(self):
"""Initialize the loss criterion for the ClassificationModel.""" """Initialize the loss criterion for the ClassificationModel."""
@ -587,7 +585,7 @@ class RTDETRDetectionModel(DetectionModel):
if visualize: if visualize:
feature_visualization(x, m.type, m.i, save_dir=visualize) feature_visualization(x, m.type, m.i, save_dir=visualize)
if embed and m.i in embed: if embed and m.i in embed:
embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
if m.i == max(embed): if m.i == max(embed):
return torch.unbind(torch.cat(embeddings, 1), dim=0) return torch.unbind(torch.cat(embeddings, 1), dim=0)
head = self.model[-1] head = self.model[-1]
@ -663,7 +661,7 @@ class WorldModel(DetectionModel):
if visualize: if visualize:
feature_visualization(x, m.type, m.i, save_dir=visualize) feature_visualization(x, m.type, m.i, save_dir=visualize)
if embed and m.i in embed: if embed and m.i in embed:
embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
if m.i == max(embed): if m.i == max(embed):
return torch.unbind(torch.cat(embeddings, 1), dim=0) return torch.unbind(torch.cat(embeddings, 1), dim=0)
return x return x
@ -684,7 +682,7 @@ class WorldModel(DetectionModel):
return self.criterion(preds, batch) return self.criterion(preds, batch)
class Ensemble(nn.ModuleList): class Ensemble(torch.nn.ModuleList):
"""Ensemble of models.""" """Ensemble of models."""
def __init__(self): def __init__(self):
@ -887,7 +885,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
for m in ensemble.modules(): for m in ensemble.modules():
if hasattr(m, "inplace"): if hasattr(m, "inplace"):
m.inplace = inplace m.inplace = inplace
elif isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"): elif isinstance(m, torch.nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
m.recompute_scale_factor = None # torch 1.11.0 compatibility m.recompute_scale_factor = None # torch 1.11.0 compatibility
# Return model # Return model
@ -922,7 +920,7 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
for m in model.modules(): for m in model.modules():
if hasattr(m, "inplace"): if hasattr(m, "inplace"):
m.inplace = inplace m.inplace = inplace
elif isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"): elif isinstance(m, torch.nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
m.recompute_scale_factor = None # torch 1.11.0 compatibility m.recompute_scale_factor = None # torch 1.11.0 compatibility
# Return model and ckpt # Return model and ckpt
@ -946,7 +944,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
depth, width, max_channels = scales[scale] depth, width, max_channels = scales[scale]
if act: if act:
Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU() Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = torch.nn.SiLU()
if verbose: if verbose:
LOGGER.info(f"{colorstr('activation:')} {act}") # print LOGGER.info(f"{colorstr('activation:')} {act}") # print
@ -982,7 +980,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
C3, C3,
C3TR, C3TR,
C3Ghost, C3Ghost,
nn.ConvTranspose2d, torch.nn.ConvTranspose2d,
DWConvTranspose2d, DWConvTranspose2d,
C3x, C3x,
RepC3, RepC3,
@ -1048,7 +1046,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
n = 1 n = 1
elif m is ResNetLayer: elif m is ResNetLayer:
c2 = args[1] if args[3] else args[1] * 4 c2 = args[1] if args[3] else args[1] * 4
elif m is nn.BatchNorm2d: elif m is torch.nn.BatchNorm2d:
args = [ch[f]] args = [ch[f]]
elif m is Concat: elif m is Concat:
c2 = sum(ch[x] for x in f) c2 = sum(ch[x] for x in f)
@ -1073,7 +1071,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
else: else:
c2 = ch[f] c2 = ch[f]
m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module m_ = torch.nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
t = str(m)[8:-2].replace("__main__.", "") # module type t = str(m)[8:-2].replace("__main__.", "") # module type
m_.np = sum(x.numel() for x in m_.parameters()) # number params m_.np = sum(x.numel() for x in m_.parameters()) # number params
m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type
@ -1084,7 +1082,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
if i == 0: if i == 0:
ch = [] ch = []
ch.append(c2) ch.append(c2)
return nn.Sequential(*layers), sorted(save) return torch.nn.Sequential(*layers), sorted(save)
def yaml_model_load(path): def yaml_model_load(path):
@ -1126,7 +1124,7 @@ def guess_model_task(model):
Guess the task of a PyTorch model from its architecture or configuration. Guess the task of a PyTorch model from its architecture or configuration.
Args: Args:
model (nn.Module | dict): PyTorch model or model configuration in YAML format. model (torch.nn.Module | dict): PyTorch model or model configuration in YAML format.
Returns: Returns:
(str): Task of the model ('detect', 'segment', 'classify', 'pose'). (str): Task of the model ('detect', 'segment', 'classify', 'pose').
@ -1154,7 +1152,7 @@ def guess_model_task(model):
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
return cfg2task(model) return cfg2task(model)
# Guess from PyTorch model # Guess from PyTorch model
if isinstance(model, nn.Module): # PyTorch model if isinstance(model, torch.nn.Module): # PyTorch model
for x in "model.args", "model.model.args", "model.model.model.args": for x in "model.args", "model.model.args", "model.model.model.args":
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
return eval(x)["task"] return eval(x)["task"]