ultralytics 8.3.71 require explicit torch.nn usage (#19067)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: RizwanMunawar <chr043416@gmail.com> Co-authored-by: Muhammad Rizwan Munawar <muhammadrizwanmunawar123@gmail.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
17450e9646
commit
5bca9341e8
10 changed files with 50 additions and 51 deletions
3
.github/workflows/docs.yml
vendored
3
.github/workflows/docs.yml
vendored
|
|
@ -48,7 +48,8 @@ jobs:
|
||||||
python-version: "3.x"
|
python-version: "3.x"
|
||||||
- uses: astral-sh/setup-uv@v5
|
- uses: astral-sh/setup-uv@v5
|
||||||
- name: Install Dependencies
|
- name: Install Dependencies
|
||||||
run: uv pip install --system ruff black tqdm mkdocs-material "mkdocstrings[python]" mkdocs-redirects mkdocs-ultralytics-plugin mkdocs-macros-plugin
|
# Note "beautifulsoup4<=4.12.3" required due to errors errors with >=4.13 in https://github.com/ultralytics/ultralytics/pull/19067
|
||||||
|
run: uv pip install --system "beautifulsoup4<=4.12.3" ruff black tqdm mkdocs-material "mkdocstrings[python]" mkdocs-redirects mkdocs-ultralytics-plugin mkdocs-macros-plugin
|
||||||
- name: Ruff fixes
|
- name: Ruff fixes
|
||||||
continue-on-error: true
|
continue-on-error: true
|
||||||
run: ruff check --fix --unsafe-fixes --select D --ignore=D100,D104,D203,D205,D212,D213,D401,D406,D407,D413 .
|
run: ruff check --fix --unsafe-fixes --select D --ignore=D100,D104,D203,D205,D212,D213,D401,D406,D407,D413 .
|
||||||
|
|
|
||||||
|
|
@ -113,7 +113,7 @@ def update_subdir_edit_links(subdir="", docs_url=""):
|
||||||
if str(subdir[0]) == "/":
|
if str(subdir[0]) == "/":
|
||||||
subdir = str(subdir[0])[1:]
|
subdir = str(subdir[0])[1:]
|
||||||
html_files = (SITE / subdir).rglob("*.html")
|
html_files = (SITE / subdir).rglob("*.html")
|
||||||
for html_file in tqdm(html_files, desc="Processing subdir files"):
|
for html_file in tqdm(html_files, desc="Processing subdir files", mininterval=1.0):
|
||||||
with html_file.open("r", encoding="utf-8") as file:
|
with html_file.open("r", encoding="utf-8") as file:
|
||||||
soup = BeautifulSoup(file, "html.parser")
|
soup = BeautifulSoup(file, "html.parser")
|
||||||
|
|
||||||
|
|
@ -178,7 +178,7 @@ def update_docs_html():
|
||||||
|
|
||||||
# Convert plaintext links to HTML hyperlinks
|
# Convert plaintext links to HTML hyperlinks
|
||||||
files_modified = 0
|
files_modified = 0
|
||||||
for html_file in tqdm(SITE.rglob("*.html"), desc="Converting plaintext links"):
|
for html_file in tqdm(SITE.rglob("*.html"), desc="Converting plaintext links", mininterval=1.0):
|
||||||
with open(html_file, encoding="utf-8") as file:
|
with open(html_file, encoding="utf-8") as file:
|
||||||
content = file.read()
|
content = file.read()
|
||||||
updated_content = convert_plaintext_links_to_html(content)
|
updated_content = convert_plaintext_links_to_html(content)
|
||||||
|
|
@ -294,7 +294,7 @@ def minify_files(html=True, css=True, js=True):
|
||||||
}.items():
|
}.items():
|
||||||
stats[ext] = {"original": 0, "minified": 0}
|
stats[ext] = {"original": 0, "minified": 0}
|
||||||
directory = "" # "stylesheets" if ext == css else "javascript" if ext == "js" else ""
|
directory = "" # "stylesheets" if ext == css else "javascript" if ext == "js" else ""
|
||||||
for f in tqdm((SITE / directory).rglob(f"*.{ext}"), desc=f"Minifying {ext.upper()}"):
|
for f in tqdm((SITE / directory).rglob(f"*.{ext}"), desc=f"Minifying {ext.upper()}", mininterval=1.0):
|
||||||
content = f.read_text(encoding="utf-8")
|
content = f.read_text(encoding="utf-8")
|
||||||
minified = minifier(content) if minifier else remove_comments_and_empty_lines(content, ext)
|
minified = minifier(content) if minifier else remove_comments_and_empty_lines(content, ext)
|
||||||
stats[ext]["original"] += len(content)
|
stats[ext]["original"] += len(content)
|
||||||
|
|
|
||||||
|
|
@ -87,6 +87,7 @@ dev = [
|
||||||
"pytest-cov",
|
"pytest-cov",
|
||||||
"coverage[toml]",
|
"coverage[toml]",
|
||||||
"mkdocs>=1.6.0",
|
"mkdocs>=1.6.0",
|
||||||
|
"beautifulsoup4<=4.12.3", # For docs https://github.com/ultralytics/ultralytics/pull/19067
|
||||||
"mkdocs-material>=9.5.9",
|
"mkdocs-material>=9.5.9",
|
||||||
"mkdocstrings[python]",
|
"mkdocstrings[python]",
|
||||||
"mkdocs-redirects", # 301 redirects
|
"mkdocs-redirects", # 301 redirects
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||||
|
|
||||||
__version__ = "8.3.70"
|
__version__ = "8.3.71"
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
# Parameters
|
# Parameters
|
||||||
nc: 80 # number of classes
|
nc: 80 # number of classes
|
||||||
activation: nn.ReLU() # (optional) model default activation function
|
activation: torch.nn.ReLU() # (optional) model default activation function
|
||||||
scales: # model compound scaling constants, i.e. 'model=yolov6n.yaml' will call yolov8.yaml with scale 'n'
|
scales: # model compound scaling constants, i.e. 'model=yolov6n.yaml' will call yolov8.yaml with scale 'n'
|
||||||
# [depth, width, max_channels]
|
# [depth, width, max_channels]
|
||||||
n: [0.33, 0.25, 1024]
|
n: [0.33, 0.25, 1024]
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from PIL import Image
|
||||||
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
|
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
|
||||||
from ultralytics.engine.results import Results
|
from ultralytics.engine.results import Results
|
||||||
from ultralytics.hub import HUB_WEB_ROOT, HUBTrainingSession
|
from ultralytics.hub import HUB_WEB_ROOT, HUBTrainingSession
|
||||||
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
|
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, yaml_model_load
|
||||||
from ultralytics.utils import (
|
from ultralytics.utils import (
|
||||||
ARGV,
|
ARGV,
|
||||||
ASSETS,
|
ASSETS,
|
||||||
|
|
@ -26,7 +26,7 @@ from ultralytics.utils import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Model(nn.Module):
|
class Model(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
A base class for implementing YOLO models, unifying APIs across different model types.
|
A base class for implementing YOLO models, unifying APIs across different model types.
|
||||||
|
|
||||||
|
|
@ -37,7 +37,7 @@ class Model(nn.Module):
|
||||||
Attributes:
|
Attributes:
|
||||||
callbacks (Dict): A dictionary of callback functions for various events during model operations.
|
callbacks (Dict): A dictionary of callback functions for various events during model operations.
|
||||||
predictor (BasePredictor): The predictor object used for making predictions.
|
predictor (BasePredictor): The predictor object used for making predictions.
|
||||||
model (nn.Module): The underlying PyTorch model.
|
model (torch.nn.Module): The underlying PyTorch model.
|
||||||
trainer (BaseTrainer): The trainer object used for training the model.
|
trainer (BaseTrainer): The trainer object used for training the model.
|
||||||
ckpt (Dict): The checkpoint data if the model is loaded from a *.pt file.
|
ckpt (Dict): The checkpoint data if the model is loaded from a *.pt file.
|
||||||
cfg (str): The configuration of the model if loaded from a *.yaml file.
|
cfg (str): The configuration of the model if loaded from a *.yaml file.
|
||||||
|
|
@ -317,7 +317,7 @@ class Model(nn.Module):
|
||||||
>>> model._check_is_pytorch_model() # Raises TypeError
|
>>> model._check_is_pytorch_model() # Raises TypeError
|
||||||
"""
|
"""
|
||||||
pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt"
|
pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt"
|
||||||
pt_module = isinstance(self.model, nn.Module)
|
pt_module = isinstance(self.model, torch.nn.Module)
|
||||||
if not (pt_module or pt_str):
|
if not (pt_module or pt_str):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"model='{self.model}' should be a *.pt PyTorch model to run this method, but is a different format. "
|
f"model='{self.model}' should be a *.pt PyTorch model to run this method, but is a different format. "
|
||||||
|
|
@ -405,7 +405,7 @@ class Model(nn.Module):
|
||||||
from ultralytics import __version__
|
from ultralytics import __version__
|
||||||
|
|
||||||
updates = {
|
updates = {
|
||||||
"model": deepcopy(self.model).half() if isinstance(self.model, nn.Module) else self.model,
|
"model": deepcopy(self.model).half() if isinstance(self.model, torch.nn.Module) else self.model,
|
||||||
"date": datetime.now().isoformat(),
|
"date": datetime.now().isoformat(),
|
||||||
"version": __version__,
|
"version": __version__,
|
||||||
"license": "AGPL-3.0 License (https://ultralytics.com/license)",
|
"license": "AGPL-3.0 License (https://ultralytics.com/license)",
|
||||||
|
|
@ -452,7 +452,7 @@ class Model(nn.Module):
|
||||||
performs both convolution and normalization in one step.
|
performs both convolution and normalization in one step.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: If the model is not a PyTorch nn.Module.
|
TypeError: If the model is not a PyTorch torch.nn.Module.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> model = Model("yolo11n.pt")
|
>>> model = Model("yolo11n.pt")
|
||||||
|
|
@ -921,13 +921,13 @@ class Model(nn.Module):
|
||||||
Retrieves the device on which the model's parameters are allocated.
|
Retrieves the device on which the model's parameters are allocated.
|
||||||
|
|
||||||
This property determines the device (CPU or GPU) where the model's parameters are currently stored. It is
|
This property determines the device (CPU or GPU) where the model's parameters are currently stored. It is
|
||||||
applicable only to models that are instances of nn.Module.
|
applicable only to models that are instances of torch.nn.Module.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(torch.device): The device (CPU/GPU) of the model.
|
(torch.device): The device (CPU/GPU) of the model.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
AttributeError: If the model is not a PyTorch nn.Module instance.
|
AttributeError: If the model is not a torch.nn.Module instance.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> model = YOLO("yolo11n.pt")
|
>>> model = YOLO("yolo11n.pt")
|
||||||
|
|
@ -937,7 +937,7 @@ class Model(nn.Module):
|
||||||
>>> print(model.device)
|
>>> print(model.device)
|
||||||
device(type='cpu')
|
device(type='cpu')
|
||||||
"""
|
"""
|
||||||
return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None
|
return next(self.model.parameters()).device if isinstance(self.model, torch.nn.Module) else None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def transforms(self):
|
def transforms(self):
|
||||||
|
|
|
||||||
|
|
@ -426,8 +426,7 @@ class SAM2Model(torch.nn.Module):
|
||||||
high_res_masks: Tensor of shape (B, 1, H*16, W*16) with the best high-resolution mask.
|
high_res_masks: Tensor of shape (B, 1, H*16, W*16) with the best high-resolution mask.
|
||||||
obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask.
|
obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask.
|
||||||
object_score_logits: Tensor of shape (B) with object score logits.
|
object_score_logits: Tensor of shape (B) with object score logits.
|
||||||
|
Where M is 3 if multimask_output=True, and 1 if multimask_output=False.
|
||||||
Where M is 3 if multimask_output=True, and 1 if multimask_output=False.
|
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> backbone_features = torch.rand(1, 256, 32, 32)
|
>>> backbone_features = torch.rand(1, 256, 32, 32)
|
||||||
|
|
|
||||||
|
|
@ -158,7 +158,7 @@ class PoseValidator(DetectionValidator):
|
||||||
gt_kpts (torch.Tensor | None): Optional tensor with shape (N, 51) representing ground truth keypoints.
|
gt_kpts (torch.Tensor | None): Optional tensor with shape (N, 51) representing ground truth keypoints.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: A tensor with shape (N, 10) representing the correct prediction matrix for 10 IoU levels,
|
(torch.Tensor): A tensor with shape (N, 10) representing the correct prediction matrix for 10 IoU levels,
|
||||||
where N is the number of detections.
|
where N is the number of detections.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
|
||||||
|
|
@ -780,7 +780,7 @@ class AutoBackend(nn.Module):
|
||||||
saved_model, pb, tflite, edgetpu, tfjs, ncnn or paddle.
|
saved_model, pb, tflite, edgetpu, tfjs, ncnn or paddle.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
p: path to the model file. Defaults to path/to/model.pt
|
p (str): path to the model file. Defaults to path/to/model.pt
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> model = AutoBackend(weights="path/to/model.onnx")
|
>>> model = AutoBackend(weights="path/to/model.onnx")
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,6 @@ from pathlib import Path
|
||||||
|
|
||||||
import thop
|
import thop
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from ultralytics.nn.modules import (
|
from ultralytics.nn.modules import (
|
||||||
AIFI,
|
AIFI,
|
||||||
|
|
@ -88,7 +87,7 @@ from ultralytics.utils.torch_utils import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class BaseModel(nn.Module):
|
class BaseModel(torch.nn.Module):
|
||||||
"""The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family."""
|
"""The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family."""
|
||||||
|
|
||||||
def forward(self, x, *args, **kwargs):
|
def forward(self, x, *args, **kwargs):
|
||||||
|
|
@ -151,7 +150,7 @@ class BaseModel(nn.Module):
|
||||||
if visualize:
|
if visualize:
|
||||||
feature_visualization(x, m.type, m.i, save_dir=visualize)
|
feature_visualization(x, m.type, m.i, save_dir=visualize)
|
||||||
if embed and m.i in embed:
|
if embed and m.i in embed:
|
||||||
embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
|
embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
|
||||||
if m.i == max(embed):
|
if m.i == max(embed):
|
||||||
return torch.unbind(torch.cat(embeddings, 1), dim=0)
|
return torch.unbind(torch.cat(embeddings, 1), dim=0)
|
||||||
return x
|
return x
|
||||||
|
|
@ -170,12 +169,9 @@ class BaseModel(nn.Module):
|
||||||
the provided list.
|
the provided list.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
m (nn.Module): The layer to be profiled.
|
m (torch.nn.Module): The layer to be profiled.
|
||||||
x (torch.Tensor): The input data to the layer.
|
x (torch.Tensor): The input data to the layer.
|
||||||
dt (list): A list to store the computation time of the layer.
|
dt (list): A list to store the computation time of the layer.
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
"""
|
||||||
c = m == self.model[-1] and isinstance(x, list) # is final layer list, copy input as inplace fix
|
c = m == self.model[-1] and isinstance(x, list) # is final layer list, copy input as inplace fix
|
||||||
flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs
|
flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs
|
||||||
|
|
@ -195,7 +191,7 @@ class BaseModel(nn.Module):
|
||||||
computation efficiency.
|
computation efficiency.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(nn.Module): The fused model is returned.
|
(torch.nn.Module): The fused model is returned.
|
||||||
"""
|
"""
|
||||||
if not self.is_fused():
|
if not self.is_fused():
|
||||||
for m in self.model.modules():
|
for m in self.model.modules():
|
||||||
|
|
@ -229,7 +225,7 @@ class BaseModel(nn.Module):
|
||||||
Returns:
|
Returns:
|
||||||
(bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
|
(bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
|
||||||
"""
|
"""
|
||||||
bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d()
|
bn = tuple(v for k, v in torch.nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d()
|
||||||
return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
|
return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
|
||||||
|
|
||||||
def info(self, detailed=False, verbose=True, imgsz=640):
|
def info(self, detailed=False, verbose=True, imgsz=640):
|
||||||
|
|
@ -304,7 +300,7 @@ class DetectionModel(BaseModel):
|
||||||
self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
|
self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
|
||||||
if self.yaml["backbone"][0][2] == "Silence":
|
if self.yaml["backbone"][0][2] == "Silence":
|
||||||
LOGGER.warning(
|
LOGGER.warning(
|
||||||
"WARNING ⚠️ YOLOv9 `Silence` module is deprecated in favor of nn.Identity. "
|
"WARNING ⚠️ YOLOv9 `Silence` module is deprecated in favor of torch.nn.Identity. "
|
||||||
"Please delete local *.pt file and re-download the latest model checkpoint."
|
"Please delete local *.pt file and re-download the latest model checkpoint."
|
||||||
)
|
)
|
||||||
self.yaml["backbone"][0][2] = "nn.Identity"
|
self.yaml["backbone"][0][2] = "nn.Identity"
|
||||||
|
|
@ -458,20 +454,22 @@ class ClassificationModel(BaseModel):
|
||||||
name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1] # last module
|
name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1] # last module
|
||||||
if isinstance(m, Classify): # YOLO Classify() head
|
if isinstance(m, Classify): # YOLO Classify() head
|
||||||
if m.linear.out_features != nc:
|
if m.linear.out_features != nc:
|
||||||
m.linear = nn.Linear(m.linear.in_features, nc)
|
m.linear = torch.nn.Linear(m.linear.in_features, nc)
|
||||||
elif isinstance(m, nn.Linear): # ResNet, EfficientNet
|
elif isinstance(m, torch.nn.Linear): # ResNet, EfficientNet
|
||||||
if m.out_features != nc:
|
if m.out_features != nc:
|
||||||
setattr(model, name, nn.Linear(m.in_features, nc))
|
setattr(model, name, torch.nn.Linear(m.in_features, nc))
|
||||||
elif isinstance(m, nn.Sequential):
|
elif isinstance(m, torch.nn.Sequential):
|
||||||
types = [type(x) for x in m]
|
types = [type(x) for x in m]
|
||||||
if nn.Linear in types:
|
if torch.nn.Linear in types:
|
||||||
i = len(types) - 1 - types[::-1].index(nn.Linear) # last nn.Linear index
|
i = len(types) - 1 - types[::-1].index(torch.nn.Linear) # last torch.nn.Linear index
|
||||||
if m[i].out_features != nc:
|
if m[i].out_features != nc:
|
||||||
m[i] = nn.Linear(m[i].in_features, nc)
|
m[i] = torch.nn.Linear(m[i].in_features, nc)
|
||||||
elif nn.Conv2d in types:
|
elif torch.nn.Conv2d in types:
|
||||||
i = len(types) - 1 - types[::-1].index(nn.Conv2d) # last nn.Conv2d index
|
i = len(types) - 1 - types[::-1].index(torch.nn.Conv2d) # last torch.nn.Conv2d index
|
||||||
if m[i].out_channels != nc:
|
if m[i].out_channels != nc:
|
||||||
m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
|
m[i] = torch.nn.Conv2d(
|
||||||
|
m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None
|
||||||
|
)
|
||||||
|
|
||||||
def init_criterion(self):
|
def init_criterion(self):
|
||||||
"""Initialize the loss criterion for the ClassificationModel."""
|
"""Initialize the loss criterion for the ClassificationModel."""
|
||||||
|
|
@ -587,7 +585,7 @@ class RTDETRDetectionModel(DetectionModel):
|
||||||
if visualize:
|
if visualize:
|
||||||
feature_visualization(x, m.type, m.i, save_dir=visualize)
|
feature_visualization(x, m.type, m.i, save_dir=visualize)
|
||||||
if embed and m.i in embed:
|
if embed and m.i in embed:
|
||||||
embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
|
embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
|
||||||
if m.i == max(embed):
|
if m.i == max(embed):
|
||||||
return torch.unbind(torch.cat(embeddings, 1), dim=0)
|
return torch.unbind(torch.cat(embeddings, 1), dim=0)
|
||||||
head = self.model[-1]
|
head = self.model[-1]
|
||||||
|
|
@ -663,7 +661,7 @@ class WorldModel(DetectionModel):
|
||||||
if visualize:
|
if visualize:
|
||||||
feature_visualization(x, m.type, m.i, save_dir=visualize)
|
feature_visualization(x, m.type, m.i, save_dir=visualize)
|
||||||
if embed and m.i in embed:
|
if embed and m.i in embed:
|
||||||
embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
|
embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
|
||||||
if m.i == max(embed):
|
if m.i == max(embed):
|
||||||
return torch.unbind(torch.cat(embeddings, 1), dim=0)
|
return torch.unbind(torch.cat(embeddings, 1), dim=0)
|
||||||
return x
|
return x
|
||||||
|
|
@ -684,7 +682,7 @@ class WorldModel(DetectionModel):
|
||||||
return self.criterion(preds, batch)
|
return self.criterion(preds, batch)
|
||||||
|
|
||||||
|
|
||||||
class Ensemble(nn.ModuleList):
|
class Ensemble(torch.nn.ModuleList):
|
||||||
"""Ensemble of models."""
|
"""Ensemble of models."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
@ -887,7 +885,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
||||||
for m in ensemble.modules():
|
for m in ensemble.modules():
|
||||||
if hasattr(m, "inplace"):
|
if hasattr(m, "inplace"):
|
||||||
m.inplace = inplace
|
m.inplace = inplace
|
||||||
elif isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
|
elif isinstance(m, torch.nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
|
||||||
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
||||||
|
|
||||||
# Return model
|
# Return model
|
||||||
|
|
@ -922,7 +920,7 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
if hasattr(m, "inplace"):
|
if hasattr(m, "inplace"):
|
||||||
m.inplace = inplace
|
m.inplace = inplace
|
||||||
elif isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
|
elif isinstance(m, torch.nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
|
||||||
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
||||||
|
|
||||||
# Return model and ckpt
|
# Return model and ckpt
|
||||||
|
|
@ -946,7 +944,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||||
depth, width, max_channels = scales[scale]
|
depth, width, max_channels = scales[scale]
|
||||||
|
|
||||||
if act:
|
if act:
|
||||||
Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()
|
Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = torch.nn.SiLU()
|
||||||
if verbose:
|
if verbose:
|
||||||
LOGGER.info(f"{colorstr('activation:')} {act}") # print
|
LOGGER.info(f"{colorstr('activation:')} {act}") # print
|
||||||
|
|
||||||
|
|
@ -982,7 +980,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||||
C3,
|
C3,
|
||||||
C3TR,
|
C3TR,
|
||||||
C3Ghost,
|
C3Ghost,
|
||||||
nn.ConvTranspose2d,
|
torch.nn.ConvTranspose2d,
|
||||||
DWConvTranspose2d,
|
DWConvTranspose2d,
|
||||||
C3x,
|
C3x,
|
||||||
RepC3,
|
RepC3,
|
||||||
|
|
@ -1048,7 +1046,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||||
n = 1
|
n = 1
|
||||||
elif m is ResNetLayer:
|
elif m is ResNetLayer:
|
||||||
c2 = args[1] if args[3] else args[1] * 4
|
c2 = args[1] if args[3] else args[1] * 4
|
||||||
elif m is nn.BatchNorm2d:
|
elif m is torch.nn.BatchNorm2d:
|
||||||
args = [ch[f]]
|
args = [ch[f]]
|
||||||
elif m is Concat:
|
elif m is Concat:
|
||||||
c2 = sum(ch[x] for x in f)
|
c2 = sum(ch[x] for x in f)
|
||||||
|
|
@ -1073,7 +1071,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||||
else:
|
else:
|
||||||
c2 = ch[f]
|
c2 = ch[f]
|
||||||
|
|
||||||
m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
|
m_ = torch.nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
|
||||||
t = str(m)[8:-2].replace("__main__.", "") # module type
|
t = str(m)[8:-2].replace("__main__.", "") # module type
|
||||||
m_.np = sum(x.numel() for x in m_.parameters()) # number params
|
m_.np = sum(x.numel() for x in m_.parameters()) # number params
|
||||||
m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type
|
m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type
|
||||||
|
|
@ -1084,7 +1082,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
||||||
if i == 0:
|
if i == 0:
|
||||||
ch = []
|
ch = []
|
||||||
ch.append(c2)
|
ch.append(c2)
|
||||||
return nn.Sequential(*layers), sorted(save)
|
return torch.nn.Sequential(*layers), sorted(save)
|
||||||
|
|
||||||
|
|
||||||
def yaml_model_load(path):
|
def yaml_model_load(path):
|
||||||
|
|
@ -1126,7 +1124,7 @@ def guess_model_task(model):
|
||||||
Guess the task of a PyTorch model from its architecture or configuration.
|
Guess the task of a PyTorch model from its architecture or configuration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (nn.Module | dict): PyTorch model or model configuration in YAML format.
|
model (torch.nn.Module | dict): PyTorch model or model configuration in YAML format.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(str): Task of the model ('detect', 'segment', 'classify', 'pose').
|
(str): Task of the model ('detect', 'segment', 'classify', 'pose').
|
||||||
|
|
@ -1154,7 +1152,7 @@ def guess_model_task(model):
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
return cfg2task(model)
|
return cfg2task(model)
|
||||||
# Guess from PyTorch model
|
# Guess from PyTorch model
|
||||||
if isinstance(model, nn.Module): # PyTorch model
|
if isinstance(model, torch.nn.Module): # PyTorch model
|
||||||
for x in "model.args", "model.model.args", "model.model.model.args":
|
for x in "model.args", "model.model.args", "model.model.model.args":
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
return eval(x)["task"]
|
return eval(x)["task"]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue