ultralytics 8.3.71 require explicit torch.nn usage (#19067)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: RizwanMunawar <chr043416@gmail.com> Co-authored-by: Muhammad Rizwan Munawar <muhammadrizwanmunawar123@gmail.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
17450e9646
commit
5bca9341e8
10 changed files with 50 additions and 51 deletions
|
|
@ -9,7 +9,6 @@ from pathlib import Path
|
|||
|
||||
import thop
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ultralytics.nn.modules import (
|
||||
AIFI,
|
||||
|
|
@ -88,7 +87,7 @@ from ultralytics.utils.torch_utils import (
|
|||
)
|
||||
|
||||
|
||||
class BaseModel(nn.Module):
|
||||
class BaseModel(torch.nn.Module):
|
||||
"""The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family."""
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
|
|
@ -151,7 +150,7 @@ class BaseModel(nn.Module):
|
|||
if visualize:
|
||||
feature_visualization(x, m.type, m.i, save_dir=visualize)
|
||||
if embed and m.i in embed:
|
||||
embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
|
||||
embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
|
||||
if m.i == max(embed):
|
||||
return torch.unbind(torch.cat(embeddings, 1), dim=0)
|
||||
return x
|
||||
|
|
@ -170,12 +169,9 @@ class BaseModel(nn.Module):
|
|||
the provided list.
|
||||
|
||||
Args:
|
||||
m (nn.Module): The layer to be profiled.
|
||||
m (torch.nn.Module): The layer to be profiled.
|
||||
x (torch.Tensor): The input data to the layer.
|
||||
dt (list): A list to store the computation time of the layer.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
c = m == self.model[-1] and isinstance(x, list) # is final layer list, copy input as inplace fix
|
||||
flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs
|
||||
|
|
@ -195,7 +191,7 @@ class BaseModel(nn.Module):
|
|||
computation efficiency.
|
||||
|
||||
Returns:
|
||||
(nn.Module): The fused model is returned.
|
||||
(torch.nn.Module): The fused model is returned.
|
||||
"""
|
||||
if not self.is_fused():
|
||||
for m in self.model.modules():
|
||||
|
|
@ -229,7 +225,7 @@ class BaseModel(nn.Module):
|
|||
Returns:
|
||||
(bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
|
||||
"""
|
||||
bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d()
|
||||
bn = tuple(v for k, v in torch.nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d()
|
||||
return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
|
||||
|
||||
def info(self, detailed=False, verbose=True, imgsz=640):
|
||||
|
|
@ -304,7 +300,7 @@ class DetectionModel(BaseModel):
|
|||
self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
|
||||
if self.yaml["backbone"][0][2] == "Silence":
|
||||
LOGGER.warning(
|
||||
"WARNING ⚠️ YOLOv9 `Silence` module is deprecated in favor of nn.Identity. "
|
||||
"WARNING ⚠️ YOLOv9 `Silence` module is deprecated in favor of torch.nn.Identity. "
|
||||
"Please delete local *.pt file and re-download the latest model checkpoint."
|
||||
)
|
||||
self.yaml["backbone"][0][2] = "nn.Identity"
|
||||
|
|
@ -458,20 +454,22 @@ class ClassificationModel(BaseModel):
|
|||
name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1] # last module
|
||||
if isinstance(m, Classify): # YOLO Classify() head
|
||||
if m.linear.out_features != nc:
|
||||
m.linear = nn.Linear(m.linear.in_features, nc)
|
||||
elif isinstance(m, nn.Linear): # ResNet, EfficientNet
|
||||
m.linear = torch.nn.Linear(m.linear.in_features, nc)
|
||||
elif isinstance(m, torch.nn.Linear): # ResNet, EfficientNet
|
||||
if m.out_features != nc:
|
||||
setattr(model, name, nn.Linear(m.in_features, nc))
|
||||
elif isinstance(m, nn.Sequential):
|
||||
setattr(model, name, torch.nn.Linear(m.in_features, nc))
|
||||
elif isinstance(m, torch.nn.Sequential):
|
||||
types = [type(x) for x in m]
|
||||
if nn.Linear in types:
|
||||
i = len(types) - 1 - types[::-1].index(nn.Linear) # last nn.Linear index
|
||||
if torch.nn.Linear in types:
|
||||
i = len(types) - 1 - types[::-1].index(torch.nn.Linear) # last torch.nn.Linear index
|
||||
if m[i].out_features != nc:
|
||||
m[i] = nn.Linear(m[i].in_features, nc)
|
||||
elif nn.Conv2d in types:
|
||||
i = len(types) - 1 - types[::-1].index(nn.Conv2d) # last nn.Conv2d index
|
||||
m[i] = torch.nn.Linear(m[i].in_features, nc)
|
||||
elif torch.nn.Conv2d in types:
|
||||
i = len(types) - 1 - types[::-1].index(torch.nn.Conv2d) # last torch.nn.Conv2d index
|
||||
if m[i].out_channels != nc:
|
||||
m[i] = nn.Conv2d(m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None)
|
||||
m[i] = torch.nn.Conv2d(
|
||||
m[i].in_channels, nc, m[i].kernel_size, m[i].stride, bias=m[i].bias is not None
|
||||
)
|
||||
|
||||
def init_criterion(self):
|
||||
"""Initialize the loss criterion for the ClassificationModel."""
|
||||
|
|
@ -587,7 +585,7 @@ class RTDETRDetectionModel(DetectionModel):
|
|||
if visualize:
|
||||
feature_visualization(x, m.type, m.i, save_dir=visualize)
|
||||
if embed and m.i in embed:
|
||||
embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
|
||||
embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
|
||||
if m.i == max(embed):
|
||||
return torch.unbind(torch.cat(embeddings, 1), dim=0)
|
||||
head = self.model[-1]
|
||||
|
|
@ -663,7 +661,7 @@ class WorldModel(DetectionModel):
|
|||
if visualize:
|
||||
feature_visualization(x, m.type, m.i, save_dir=visualize)
|
||||
if embed and m.i in embed:
|
||||
embeddings.append(nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
|
||||
embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
|
||||
if m.i == max(embed):
|
||||
return torch.unbind(torch.cat(embeddings, 1), dim=0)
|
||||
return x
|
||||
|
|
@ -684,7 +682,7 @@ class WorldModel(DetectionModel):
|
|||
return self.criterion(preds, batch)
|
||||
|
||||
|
||||
class Ensemble(nn.ModuleList):
|
||||
class Ensemble(torch.nn.ModuleList):
|
||||
"""Ensemble of models."""
|
||||
|
||||
def __init__(self):
|
||||
|
|
@ -887,7 +885,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
|||
for m in ensemble.modules():
|
||||
if hasattr(m, "inplace"):
|
||||
m.inplace = inplace
|
||||
elif isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
|
||||
elif isinstance(m, torch.nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
|
||||
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
||||
|
||||
# Return model
|
||||
|
|
@ -922,7 +920,7 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
|
|||
for m in model.modules():
|
||||
if hasattr(m, "inplace"):
|
||||
m.inplace = inplace
|
||||
elif isinstance(m, nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
|
||||
elif isinstance(m, torch.nn.Upsample) and not hasattr(m, "recompute_scale_factor"):
|
||||
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
||||
|
||||
# Return model and ckpt
|
||||
|
|
@ -946,7 +944,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
depth, width, max_channels = scales[scale]
|
||||
|
||||
if act:
|
||||
Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()
|
||||
Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = torch.nn.SiLU()
|
||||
if verbose:
|
||||
LOGGER.info(f"{colorstr('activation:')} {act}") # print
|
||||
|
||||
|
|
@ -982,7 +980,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
C3,
|
||||
C3TR,
|
||||
C3Ghost,
|
||||
nn.ConvTranspose2d,
|
||||
torch.nn.ConvTranspose2d,
|
||||
DWConvTranspose2d,
|
||||
C3x,
|
||||
RepC3,
|
||||
|
|
@ -1048,7 +1046,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
n = 1
|
||||
elif m is ResNetLayer:
|
||||
c2 = args[1] if args[3] else args[1] * 4
|
||||
elif m is nn.BatchNorm2d:
|
||||
elif m is torch.nn.BatchNorm2d:
|
||||
args = [ch[f]]
|
||||
elif m is Concat:
|
||||
c2 = sum(ch[x] for x in f)
|
||||
|
|
@ -1073,7 +1071,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
else:
|
||||
c2 = ch[f]
|
||||
|
||||
m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
|
||||
m_ = torch.nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
|
||||
t = str(m)[8:-2].replace("__main__.", "") # module type
|
||||
m_.np = sum(x.numel() for x in m_.parameters()) # number params
|
||||
m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type
|
||||
|
|
@ -1084,7 +1082,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
if i == 0:
|
||||
ch = []
|
||||
ch.append(c2)
|
||||
return nn.Sequential(*layers), sorted(save)
|
||||
return torch.nn.Sequential(*layers), sorted(save)
|
||||
|
||||
|
||||
def yaml_model_load(path):
|
||||
|
|
@ -1126,7 +1124,7 @@ def guess_model_task(model):
|
|||
Guess the task of a PyTorch model from its architecture or configuration.
|
||||
|
||||
Args:
|
||||
model (nn.Module | dict): PyTorch model or model configuration in YAML format.
|
||||
model (torch.nn.Module | dict): PyTorch model or model configuration in YAML format.
|
||||
|
||||
Returns:
|
||||
(str): Task of the model ('detect', 'segment', 'classify', 'pose').
|
||||
|
|
@ -1154,7 +1152,7 @@ def guess_model_task(model):
|
|||
with contextlib.suppress(Exception):
|
||||
return cfg2task(model)
|
||||
# Guess from PyTorch model
|
||||
if isinstance(model, nn.Module): # PyTorch model
|
||||
if isinstance(model, torch.nn.Module): # PyTorch model
|
||||
for x in "model.args", "model.model.args", "model.model.model.args":
|
||||
with contextlib.suppress(Exception):
|
||||
return eval(x)["task"]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue