ultralytics 8.0.239 Ultralytics Actions and hub-sdk adoption (#7431)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com> Co-authored-by: Kayzwer <68285002+Kayzwer@users.noreply.github.com>
This commit is contained in:
parent
e795277391
commit
fe27db2f6e
139 changed files with 6870 additions and 5125 deletions
|
|
@ -7,16 +7,54 @@ from pathlib import Path
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, OBB, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost,
|
||||
C3x, Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv,
|
||||
DWConvTranspose2d, Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3,
|
||||
RepConv, ResNetLayer, RTDETRDecoder, Segment)
|
||||
from ultralytics.nn.modules import (
|
||||
AIFI,
|
||||
C1,
|
||||
C2,
|
||||
C3,
|
||||
C3TR,
|
||||
OBB,
|
||||
SPP,
|
||||
SPPF,
|
||||
Bottleneck,
|
||||
BottleneckCSP,
|
||||
C2f,
|
||||
C3Ghost,
|
||||
C3x,
|
||||
Classify,
|
||||
Concat,
|
||||
Conv,
|
||||
Conv2,
|
||||
ConvTranspose,
|
||||
Detect,
|
||||
DWConv,
|
||||
DWConvTranspose2d,
|
||||
Focus,
|
||||
GhostBottleneck,
|
||||
GhostConv,
|
||||
HGBlock,
|
||||
HGStem,
|
||||
Pose,
|
||||
RepC3,
|
||||
RepConv,
|
||||
ResNetLayer,
|
||||
RTDETRDecoder,
|
||||
Segment,
|
||||
)
|
||||
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
|
||||
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
|
||||
from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8OBBLoss, v8PoseLoss, v8SegmentationLoss
|
||||
from ultralytics.utils.plotting import feature_visualization
|
||||
from ultralytics.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights, intersect_dicts,
|
||||
make_divisible, model_info, scale_img, time_sync)
|
||||
from ultralytics.utils.torch_utils import (
|
||||
fuse_conv_and_bn,
|
||||
fuse_deconv_and_bn,
|
||||
initialize_weights,
|
||||
intersect_dicts,
|
||||
make_divisible,
|
||||
model_info,
|
||||
scale_img,
|
||||
time_sync,
|
||||
)
|
||||
|
||||
try:
|
||||
import thop
|
||||
|
|
@ -90,8 +128,10 @@ class BaseModel(nn.Module):
|
|||
|
||||
def _predict_augment(self, x):
|
||||
"""Perform augmentations on input image x and return augmented inference."""
|
||||
LOGGER.warning(f'WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. '
|
||||
f'Reverting to single-scale inference instead.')
|
||||
LOGGER.warning(
|
||||
f"WARNING ⚠️ {self.__class__.__name__} does not support augmented inference yet. "
|
||||
f"Reverting to single-scale inference instead."
|
||||
)
|
||||
return self._predict_once(x)
|
||||
|
||||
def _profile_one_layer(self, m, x, dt):
|
||||
|
|
@ -108,14 +148,14 @@ class BaseModel(nn.Module):
|
|||
None
|
||||
"""
|
||||
c = m == self.model[-1] and isinstance(x, list) # is final layer list, copy input as inplace fix
|
||||
flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
|
||||
flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0 # FLOPs
|
||||
t = time_sync()
|
||||
for _ in range(10):
|
||||
m(x.copy() if c else x)
|
||||
dt.append((time_sync() - t) * 100)
|
||||
if m == self.model[0]:
|
||||
LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
|
||||
LOGGER.info(f'{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f} {m.type}')
|
||||
LOGGER.info(f"{dt[-1]:10.2f} {flops:10.2f} {m.np:10.0f} {m.type}")
|
||||
if c:
|
||||
LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
|
||||
|
||||
|
|
@ -129,15 +169,15 @@ class BaseModel(nn.Module):
|
|||
"""
|
||||
if not self.is_fused():
|
||||
for m in self.model.modules():
|
||||
if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, 'bn'):
|
||||
if isinstance(m, (Conv, Conv2, DWConv)) and hasattr(m, "bn"):
|
||||
if isinstance(m, Conv2):
|
||||
m.fuse_convs()
|
||||
m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
|
||||
delattr(m, 'bn') # remove batchnorm
|
||||
delattr(m, "bn") # remove batchnorm
|
||||
m.forward = m.forward_fuse # update forward
|
||||
if isinstance(m, ConvTranspose) and hasattr(m, 'bn'):
|
||||
if isinstance(m, ConvTranspose) and hasattr(m, "bn"):
|
||||
m.conv_transpose = fuse_deconv_and_bn(m.conv_transpose, m.bn)
|
||||
delattr(m, 'bn') # remove batchnorm
|
||||
delattr(m, "bn") # remove batchnorm
|
||||
m.forward = m.forward_fuse # update forward
|
||||
if isinstance(m, RepConv):
|
||||
m.fuse_convs()
|
||||
|
|
@ -156,7 +196,7 @@ class BaseModel(nn.Module):
|
|||
Returns:
|
||||
(bool): True if the number of BatchNorm layers in the model is less than the threshold, False otherwise.
|
||||
"""
|
||||
bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k) # normalization layers, i.e. BatchNorm2d()
|
||||
bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d()
|
||||
return sum(isinstance(v, bn) for v in self.modules()) < thresh # True if < 'thresh' BatchNorm layers in model
|
||||
|
||||
def info(self, detailed=False, verbose=True, imgsz=640):
|
||||
|
|
@ -196,12 +236,12 @@ class BaseModel(nn.Module):
|
|||
weights (dict | torch.nn.Module): The pre-trained weights to be loaded.
|
||||
verbose (bool, optional): Whether to log the transfer progress. Defaults to True.
|
||||
"""
|
||||
model = weights['model'] if isinstance(weights, dict) else weights # torchvision models are not dicts
|
||||
model = weights["model"] if isinstance(weights, dict) else weights # torchvision models are not dicts
|
||||
csd = model.float().state_dict() # checkpoint state_dict as FP32
|
||||
csd = intersect_dicts(csd, self.state_dict()) # intersect
|
||||
self.load_state_dict(csd, strict=False) # load
|
||||
if verbose:
|
||||
LOGGER.info(f'Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights')
|
||||
LOGGER.info(f"Transferred {len(csd)}/{len(self.model.state_dict())} items from pretrained weights")
|
||||
|
||||
def loss(self, batch, preds=None):
|
||||
"""
|
||||
|
|
@ -211,33 +251,33 @@ class BaseModel(nn.Module):
|
|||
batch (dict): Batch to compute loss on
|
||||
preds (torch.Tensor | List[torch.Tensor]): Predictions.
|
||||
"""
|
||||
if not hasattr(self, 'criterion'):
|
||||
if not hasattr(self, "criterion"):
|
||||
self.criterion = self.init_criterion()
|
||||
|
||||
preds = self.forward(batch['img']) if preds is None else preds
|
||||
preds = self.forward(batch["img"]) if preds is None else preds
|
||||
return self.criterion(preds, batch)
|
||||
|
||||
def init_criterion(self):
|
||||
"""Initialize the loss criterion for the BaseModel."""
|
||||
raise NotImplementedError('compute_loss() needs to be implemented by task heads')
|
||||
raise NotImplementedError("compute_loss() needs to be implemented by task heads")
|
||||
|
||||
|
||||
class DetectionModel(BaseModel):
|
||||
"""YOLOv8 detection model."""
|
||||
|
||||
def __init__(self, cfg='yolov8n.yaml', ch=3, nc=None, verbose=True): # model, input channels, number of classes
|
||||
def __init__(self, cfg="yolov8n.yaml", ch=3, nc=None, verbose=True): # model, input channels, number of classes
|
||||
"""Initialize the YOLOv8 detection model with the given config and parameters."""
|
||||
super().__init__()
|
||||
self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
|
||||
|
||||
# Define model
|
||||
ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
|
||||
if nc and nc != self.yaml['nc']:
|
||||
ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels
|
||||
if nc and nc != self.yaml["nc"]:
|
||||
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
|
||||
self.yaml['nc'] = nc # override YAML value
|
||||
self.yaml["nc"] = nc # override YAML value
|
||||
self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
|
||||
self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
|
||||
self.inplace = self.yaml.get('inplace', True)
|
||||
self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict
|
||||
self.inplace = self.yaml.get("inplace", True)
|
||||
|
||||
# Build strides
|
||||
m = self.model[-1] # Detect()
|
||||
|
|
@ -255,7 +295,7 @@ class DetectionModel(BaseModel):
|
|||
initialize_weights(self)
|
||||
if verbose:
|
||||
self.info()
|
||||
LOGGER.info('')
|
||||
LOGGER.info("")
|
||||
|
||||
def _predict_augment(self, x):
|
||||
"""Perform augmentations on input image x and return augmented inference and train outputs."""
|
||||
|
|
@ -285,9 +325,9 @@ class DetectionModel(BaseModel):
|
|||
def _clip_augmented(self, y):
|
||||
"""Clip YOLO augmented inference tails."""
|
||||
nl = self.model[-1].nl # number of detection layers (P3-P5)
|
||||
g = sum(4 ** x for x in range(nl)) # grid points
|
||||
g = sum(4**x for x in range(nl)) # grid points
|
||||
e = 1 # exclude layer count
|
||||
i = (y[0].shape[-1] // g) * sum(4 ** x for x in range(e)) # indices
|
||||
i = (y[0].shape[-1] // g) * sum(4**x for x in range(e)) # indices
|
||||
y[0] = y[0][..., :-i] # large
|
||||
i = (y[-1].shape[-1] // g) * sum(4 ** (nl - 1 - x) for x in range(e)) # indices
|
||||
y[-1] = y[-1][..., i:] # small
|
||||
|
|
@ -301,7 +341,7 @@ class DetectionModel(BaseModel):
|
|||
class OBBModel(DetectionModel):
|
||||
""""YOLOv8 Oriented Bounding Box (OBB) model."""
|
||||
|
||||
def __init__(self, cfg='yolov8n-obb.yaml', ch=3, nc=None, verbose=True):
|
||||
def __init__(self, cfg="yolov8n-obb.yaml", ch=3, nc=None, verbose=True):
|
||||
"""Initialize YOLOv8 OBB model with given config and parameters."""
|
||||
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
||||
|
||||
|
|
@ -313,7 +353,7 @@ class OBBModel(DetectionModel):
|
|||
class SegmentationModel(DetectionModel):
|
||||
"""YOLOv8 segmentation model."""
|
||||
|
||||
def __init__(self, cfg='yolov8n-seg.yaml', ch=3, nc=None, verbose=True):
|
||||
def __init__(self, cfg="yolov8n-seg.yaml", ch=3, nc=None, verbose=True):
|
||||
"""Initialize YOLOv8 segmentation model with given config and parameters."""
|
||||
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
||||
|
||||
|
|
@ -325,13 +365,13 @@ class SegmentationModel(DetectionModel):
|
|||
class PoseModel(DetectionModel):
|
||||
"""YOLOv8 pose model."""
|
||||
|
||||
def __init__(self, cfg='yolov8n-pose.yaml', ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
|
||||
def __init__(self, cfg="yolov8n-pose.yaml", ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
|
||||
"""Initialize YOLOv8 Pose model."""
|
||||
if not isinstance(cfg, dict):
|
||||
cfg = yaml_model_load(cfg) # load model YAML
|
||||
if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg['kpt_shape']):
|
||||
if any(data_kpt_shape) and list(data_kpt_shape) != list(cfg["kpt_shape"]):
|
||||
LOGGER.info(f"Overriding model.yaml kpt_shape={cfg['kpt_shape']} with kpt_shape={data_kpt_shape}")
|
||||
cfg['kpt_shape'] = data_kpt_shape
|
||||
cfg["kpt_shape"] = data_kpt_shape
|
||||
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
||||
|
||||
def init_criterion(self):
|
||||
|
|
@ -342,7 +382,7 @@ class PoseModel(DetectionModel):
|
|||
class ClassificationModel(BaseModel):
|
||||
"""YOLOv8 classification model."""
|
||||
|
||||
def __init__(self, cfg='yolov8n-cls.yaml', ch=3, nc=None, verbose=True):
|
||||
def __init__(self, cfg="yolov8n-cls.yaml", ch=3, nc=None, verbose=True):
|
||||
"""Init ClassificationModel with YAML, channels, number of classes, verbose flag."""
|
||||
super().__init__()
|
||||
self._from_yaml(cfg, ch, nc, verbose)
|
||||
|
|
@ -352,21 +392,21 @@ class ClassificationModel(BaseModel):
|
|||
self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
|
||||
|
||||
# Define model
|
||||
ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
|
||||
if nc and nc != self.yaml['nc']:
|
||||
ch = self.yaml["ch"] = self.yaml.get("ch", ch) # input channels
|
||||
if nc and nc != self.yaml["nc"]:
|
||||
LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
|
||||
self.yaml['nc'] = nc # override YAML value
|
||||
elif not nc and not self.yaml.get('nc', None):
|
||||
raise ValueError('nc not specified. Must specify nc in model.yaml or function arguments.')
|
||||
self.yaml["nc"] = nc # override YAML value
|
||||
elif not nc and not self.yaml.get("nc", None):
|
||||
raise ValueError("nc not specified. Must specify nc in model.yaml or function arguments.")
|
||||
self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
|
||||
self.stride = torch.Tensor([1]) # no stride constraints
|
||||
self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
|
||||
self.names = {i: f"{i}" for i in range(self.yaml["nc"])} # default names dict
|
||||
self.info()
|
||||
|
||||
@staticmethod
|
||||
def reshape_outputs(model, nc):
|
||||
"""Update a TorchVision classification model to class count 'n' if required."""
|
||||
name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1] # last module
|
||||
name, m = list((model.model if hasattr(model, "model") else model).named_children())[-1] # last module
|
||||
if isinstance(m, Classify): # YOLO Classify() head
|
||||
if m.linear.out_features != nc:
|
||||
m.linear = nn.Linear(m.linear.in_features, nc)
|
||||
|
|
@ -409,7 +449,7 @@ class RTDETRDetectionModel(DetectionModel):
|
|||
predict: Performs a forward pass through the network and returns the output.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg='rtdetr-l.yaml', ch=3, nc=None, verbose=True):
|
||||
def __init__(self, cfg="rtdetr-l.yaml", ch=3, nc=None, verbose=True):
|
||||
"""
|
||||
Initialize the RTDETRDetectionModel.
|
||||
|
||||
|
|
@ -438,39 +478,39 @@ class RTDETRDetectionModel(DetectionModel):
|
|||
Returns:
|
||||
(tuple): A tuple containing the total loss and main three losses in a tensor.
|
||||
"""
|
||||
if not hasattr(self, 'criterion'):
|
||||
if not hasattr(self, "criterion"):
|
||||
self.criterion = self.init_criterion()
|
||||
|
||||
img = batch['img']
|
||||
img = batch["img"]
|
||||
# NOTE: preprocess gt_bbox and gt_labels to list.
|
||||
bs = len(img)
|
||||
batch_idx = batch['batch_idx']
|
||||
batch_idx = batch["batch_idx"]
|
||||
gt_groups = [(batch_idx == i).sum().item() for i in range(bs)]
|
||||
targets = {
|
||||
'cls': batch['cls'].to(img.device, dtype=torch.long).view(-1),
|
||||
'bboxes': batch['bboxes'].to(device=img.device),
|
||||
'batch_idx': batch_idx.to(img.device, dtype=torch.long).view(-1),
|
||||
'gt_groups': gt_groups}
|
||||
"cls": batch["cls"].to(img.device, dtype=torch.long).view(-1),
|
||||
"bboxes": batch["bboxes"].to(device=img.device),
|
||||
"batch_idx": batch_idx.to(img.device, dtype=torch.long).view(-1),
|
||||
"gt_groups": gt_groups,
|
||||
}
|
||||
|
||||
preds = self.predict(img, batch=targets) if preds is None else preds
|
||||
dec_bboxes, dec_scores, enc_bboxes, enc_scores, dn_meta = preds if self.training else preds[1]
|
||||
if dn_meta is None:
|
||||
dn_bboxes, dn_scores = None, None
|
||||
else:
|
||||
dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta['dn_num_split'], dim=2)
|
||||
dn_scores, dec_scores = torch.split(dec_scores, dn_meta['dn_num_split'], dim=2)
|
||||
dn_bboxes, dec_bboxes = torch.split(dec_bboxes, dn_meta["dn_num_split"], dim=2)
|
||||
dn_scores, dec_scores = torch.split(dec_scores, dn_meta["dn_num_split"], dim=2)
|
||||
|
||||
dec_bboxes = torch.cat([enc_bboxes.unsqueeze(0), dec_bboxes]) # (7, bs, 300, 4)
|
||||
dec_scores = torch.cat([enc_scores.unsqueeze(0), dec_scores])
|
||||
|
||||
loss = self.criterion((dec_bboxes, dec_scores),
|
||||
targets,
|
||||
dn_bboxes=dn_bboxes,
|
||||
dn_scores=dn_scores,
|
||||
dn_meta=dn_meta)
|
||||
loss = self.criterion(
|
||||
(dec_bboxes, dec_scores), targets, dn_bboxes=dn_bboxes, dn_scores=dn_scores, dn_meta=dn_meta
|
||||
)
|
||||
# NOTE: There are like 12 losses in RTDETR, backward with all losses but only show the main three losses.
|
||||
return sum(loss.values()), torch.as_tensor([loss[k].detach() for k in ['loss_giou', 'loss_class', 'loss_bbox']],
|
||||
device=img.device)
|
||||
return sum(loss.values()), torch.as_tensor(
|
||||
[loss[k].detach() for k in ["loss_giou", "loss_class", "loss_bbox"]], device=img.device
|
||||
)
|
||||
|
||||
def predict(self, x, profile=False, visualize=False, batch=None, augment=False, embed=None):
|
||||
"""
|
||||
|
|
@ -553,6 +593,7 @@ def temporary_modules(modules=None):
|
|||
|
||||
import importlib
|
||||
import sys
|
||||
|
||||
try:
|
||||
# Set modules in sys.modules under their old name
|
||||
for old, new in modules.items():
|
||||
|
|
@ -580,30 +621,38 @@ def torch_safe_load(weight):
|
|||
"""
|
||||
from ultralytics.utils.downloads import attempt_download_asset
|
||||
|
||||
check_suffix(file=weight, suffix='.pt')
|
||||
check_suffix(file=weight, suffix=".pt")
|
||||
file = attempt_download_asset(weight) # search online if missing locally
|
||||
try:
|
||||
with temporary_modules({
|
||||
'ultralytics.yolo.utils': 'ultralytics.utils',
|
||||
'ultralytics.yolo.v8': 'ultralytics.models.yolo',
|
||||
'ultralytics.yolo.data': 'ultralytics.data'}): # for legacy 8.0 Classify and Pose models
|
||||
return torch.load(file, map_location='cpu'), file # load
|
||||
with temporary_modules(
|
||||
{
|
||||
"ultralytics.yolo.utils": "ultralytics.utils",
|
||||
"ultralytics.yolo.v8": "ultralytics.models.yolo",
|
||||
"ultralytics.yolo.data": "ultralytics.data",
|
||||
}
|
||||
): # for legacy 8.0 Classify and Pose models
|
||||
return torch.load(file, map_location="cpu"), file # load
|
||||
|
||||
except ModuleNotFoundError as e: # e.name is missing module name
|
||||
if e.name == 'models':
|
||||
if e.name == "models":
|
||||
raise TypeError(
|
||||
emojis(f'ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained '
|
||||
f'with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with '
|
||||
f'YOLOv8 at https://github.com/ultralytics/ultralytics.'
|
||||
f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
|
||||
f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")) from e
|
||||
LOGGER.warning(f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements."
|
||||
f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
|
||||
f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
|
||||
f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'")
|
||||
emojis(
|
||||
f"ERROR ❌️ {weight} appears to be an Ultralytics YOLOv5 model originally trained "
|
||||
f"with https://github.com/ultralytics/yolov5.\nThis model is NOT forwards compatible with "
|
||||
f"YOLOv8 at https://github.com/ultralytics/ultralytics."
|
||||
f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
|
||||
f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'"
|
||||
)
|
||||
) from e
|
||||
LOGGER.warning(
|
||||
f"WARNING ⚠️ {weight} appears to require '{e.name}', which is not in ultralytics requirements."
|
||||
f"\nAutoInstall will run now for '{e.name}' but this feature will be removed in the future."
|
||||
f"\nRecommend fixes are to train a new model using the latest 'ultralytics' package or to "
|
||||
f"run a command with an official YOLOv8 model, i.e. 'yolo predict model=yolov8n.pt'"
|
||||
)
|
||||
check_requirements(e.name) # install missing module
|
||||
|
||||
return torch.load(file, map_location='cpu'), file # load
|
||||
return torch.load(file, map_location="cpu"), file # load
|
||||
|
||||
|
||||
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
||||
|
|
@ -612,25 +661,25 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
|||
ensemble = Ensemble()
|
||||
for w in weights if isinstance(weights, list) else [weights]:
|
||||
ckpt, w = torch_safe_load(w) # load ckpt
|
||||
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} if 'train_args' in ckpt else None # combined args
|
||||
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
||||
args = {**DEFAULT_CFG_DICT, **ckpt["train_args"]} if "train_args" in ckpt else None # combined args
|
||||
model = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model
|
||||
|
||||
# Model compatibility updates
|
||||
model.args = args # attach args to model
|
||||
model.pt_path = w # attach *.pt file path to model
|
||||
model.task = guess_model_task(model)
|
||||
if not hasattr(model, 'stride'):
|
||||
model.stride = torch.tensor([32.])
|
||||
if not hasattr(model, "stride"):
|
||||
model.stride = torch.tensor([32.0])
|
||||
|
||||
# Append
|
||||
ensemble.append(model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval()) # model in eval mode
|
||||
ensemble.append(model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval()) # model in eval mode
|
||||
|
||||
# Module updates
|
||||
for m in ensemble.modules():
|
||||
t = type(m)
|
||||
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment, Pose, OBB):
|
||||
m.inplace = inplace
|
||||
elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
|
||||
elif t is nn.Upsample and not hasattr(m, "recompute_scale_factor"):
|
||||
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
||||
|
||||
# Return model
|
||||
|
|
@ -638,35 +687,35 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
|||
return ensemble[-1]
|
||||
|
||||
# Return ensemble
|
||||
LOGGER.info(f'Ensemble created with {weights}\n')
|
||||
for k in 'names', 'nc', 'yaml':
|
||||
LOGGER.info(f"Ensemble created with {weights}\n")
|
||||
for k in "names", "nc", "yaml":
|
||||
setattr(ensemble, k, getattr(ensemble[0], k))
|
||||
ensemble.stride = ensemble[torch.argmax(torch.tensor([m.stride.max() for m in ensemble])).int()].stride
|
||||
assert all(ensemble[0].nc == m.nc for m in ensemble), f'Models differ in class counts {[m.nc for m in ensemble]}'
|
||||
assert all(ensemble[0].nc == m.nc for m in ensemble), f"Models differ in class counts {[m.nc for m in ensemble]}"
|
||||
return ensemble
|
||||
|
||||
|
||||
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
|
||||
"""Loads a single model weights."""
|
||||
ckpt, weight = torch_safe_load(weight) # load ckpt
|
||||
args = {**DEFAULT_CFG_DICT, **(ckpt.get('train_args', {}))} # combine model and default args, preferring model args
|
||||
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
||||
args = {**DEFAULT_CFG_DICT, **(ckpt.get("train_args", {}))} # combine model and default args, preferring model args
|
||||
model = (ckpt.get("ema") or ckpt["model"]).to(device).float() # FP32 model
|
||||
|
||||
# Model compatibility updates
|
||||
model.args = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # attach args to model
|
||||
model.pt_path = weight # attach *.pt file path to model
|
||||
model.task = guess_model_task(model)
|
||||
if not hasattr(model, 'stride'):
|
||||
model.stride = torch.tensor([32.])
|
||||
if not hasattr(model, "stride"):
|
||||
model.stride = torch.tensor([32.0])
|
||||
|
||||
model = model.fuse().eval() if fuse and hasattr(model, 'fuse') else model.eval() # model in eval mode
|
||||
model = model.fuse().eval() if fuse and hasattr(model, "fuse") else model.eval() # model in eval mode
|
||||
|
||||
# Module updates
|
||||
for m in model.modules():
|
||||
t = type(m)
|
||||
if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Segment, Pose, OBB):
|
||||
m.inplace = inplace
|
||||
elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
|
||||
elif t is nn.Upsample and not hasattr(m, "recompute_scale_factor"):
|
||||
m.recompute_scale_factor = None # torch 1.11.0 compatibility
|
||||
|
||||
# Return model and ckpt
|
||||
|
|
@ -678,11 +727,11 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
import ast
|
||||
|
||||
# Args
|
||||
max_channels = float('inf')
|
||||
nc, act, scales = (d.get(x) for x in ('nc', 'activation', 'scales'))
|
||||
depth, width, kpt_shape = (d.get(x, 1.0) for x in ('depth_multiple', 'width_multiple', 'kpt_shape'))
|
||||
max_channels = float("inf")
|
||||
nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
|
||||
depth, width, kpt_shape = (d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape"))
|
||||
if scales:
|
||||
scale = d.get('scale')
|
||||
scale = d.get("scale")
|
||||
if not scale:
|
||||
scale = tuple(scales.keys())[0]
|
||||
LOGGER.warning(f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'.")
|
||||
|
|
@ -697,16 +746,37 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
|
||||
ch = [ch]
|
||||
layers, save, c2 = [], [], ch[-1] # layers, savelist, ch out
|
||||
for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, args
|
||||
m = getattr(torch.nn, m[3:]) if 'nn.' in m else globals()[m] # get module
|
||||
for i, (f, n, m, args) in enumerate(d["backbone"] + d["head"]): # from, number, module, args
|
||||
m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m] # get module
|
||||
for j, a in enumerate(args):
|
||||
if isinstance(a, str):
|
||||
with contextlib.suppress(ValueError):
|
||||
args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
|
||||
|
||||
n = n_ = max(round(n * depth), 1) if n > 1 else n # depth gain
|
||||
if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
|
||||
BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3):
|
||||
if m in (
|
||||
Classify,
|
||||
Conv,
|
||||
ConvTranspose,
|
||||
GhostConv,
|
||||
Bottleneck,
|
||||
GhostBottleneck,
|
||||
SPP,
|
||||
SPPF,
|
||||
DWConv,
|
||||
Focus,
|
||||
BottleneckCSP,
|
||||
C1,
|
||||
C2,
|
||||
C2f,
|
||||
C3,
|
||||
C3TR,
|
||||
C3Ghost,
|
||||
nn.ConvTranspose2d,
|
||||
DWConvTranspose2d,
|
||||
C3x,
|
||||
RepC3,
|
||||
):
|
||||
c1, c2 = ch[f], args[0]
|
||||
if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
|
||||
c2 = make_divisible(min(c2, max_channels) * width, 8)
|
||||
|
|
@ -739,11 +809,11 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
c2 = ch[f]
|
||||
|
||||
m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # module
|
||||
t = str(m)[8:-2].replace('__main__.', '') # module type
|
||||
t = str(m)[8:-2].replace("__main__.", "") # module type
|
||||
m.np = sum(x.numel() for x in m_.parameters()) # number params
|
||||
m_.i, m_.f, m_.type = i, f, t # attach index, 'from' index, type
|
||||
if verbose:
|
||||
LOGGER.info(f'{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f} {t:<45}{str(args):<30}') # print
|
||||
LOGGER.info(f"{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f} {t:<45}{str(args):<30}") # print
|
||||
save.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelist
|
||||
layers.append(m_)
|
||||
if i == 0:
|
||||
|
|
@ -757,16 +827,16 @@ def yaml_model_load(path):
|
|||
import re
|
||||
|
||||
path = Path(path)
|
||||
if path.stem in (f'yolov{d}{x}6' for x in 'nsmlx' for d in (5, 8)):
|
||||
new_stem = re.sub(r'(\d+)([nslmx])6(.+)?$', r'\1\2-p6\3', path.stem)
|
||||
LOGGER.warning(f'WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.')
|
||||
if path.stem in (f"yolov{d}{x}6" for x in "nsmlx" for d in (5, 8)):
|
||||
new_stem = re.sub(r"(\d+)([nslmx])6(.+)?$", r"\1\2-p6\3", path.stem)
|
||||
LOGGER.warning(f"WARNING ⚠️ Ultralytics YOLO P6 models now use -p6 suffix. Renaming {path.stem} to {new_stem}.")
|
||||
path = path.with_name(new_stem + path.suffix)
|
||||
|
||||
unified_path = re.sub(r'(\d+)([nslmx])(.+)?$', r'\1\3', str(path)) # i.e. yolov8x.yaml -> yolov8.yaml
|
||||
unified_path = re.sub(r"(\d+)([nslmx])(.+)?$", r"\1\3", str(path)) # i.e. yolov8x.yaml -> yolov8.yaml
|
||||
yaml_file = check_yaml(unified_path, hard=False) or check_yaml(path)
|
||||
d = yaml_load(yaml_file) # model dict
|
||||
d['scale'] = guess_model_scale(path)
|
||||
d['yaml_file'] = str(path)
|
||||
d["scale"] = guess_model_scale(path)
|
||||
d["yaml_file"] = str(path)
|
||||
return d
|
||||
|
||||
|
||||
|
|
@ -784,8 +854,9 @@ def guess_model_scale(model_path):
|
|||
"""
|
||||
with contextlib.suppress(AttributeError):
|
||||
import re
|
||||
return re.search(r'yolov\d+([nslmx])', Path(model_path).stem).group(1) # n, s, m, l, or x
|
||||
return ''
|
||||
|
||||
return re.search(r"yolov\d+([nslmx])", Path(model_path).stem).group(1) # n, s, m, l, or x
|
||||
return ""
|
||||
|
||||
|
||||
def guess_model_task(model):
|
||||
|
|
@ -804,17 +875,17 @@ def guess_model_task(model):
|
|||
|
||||
def cfg2task(cfg):
|
||||
"""Guess from YAML dictionary."""
|
||||
m = cfg['head'][-1][-2].lower() # output module name
|
||||
if m in ('classify', 'classifier', 'cls', 'fc'):
|
||||
return 'classify'
|
||||
if m == 'detect':
|
||||
return 'detect'
|
||||
if m == 'segment':
|
||||
return 'segment'
|
||||
if m == 'pose':
|
||||
return 'pose'
|
||||
if m == 'obb':
|
||||
return 'obb'
|
||||
m = cfg["head"][-1][-2].lower() # output module name
|
||||
if m in ("classify", "classifier", "cls", "fc"):
|
||||
return "classify"
|
||||
if m == "detect":
|
||||
return "detect"
|
||||
if m == "segment":
|
||||
return "segment"
|
||||
if m == "pose":
|
||||
return "pose"
|
||||
if m == "obb":
|
||||
return "obb"
|
||||
|
||||
# Guess from model cfg
|
||||
if isinstance(model, dict):
|
||||
|
|
@ -823,40 +894,42 @@ def guess_model_task(model):
|
|||
|
||||
# Guess from PyTorch model
|
||||
if isinstance(model, nn.Module): # PyTorch model
|
||||
for x in 'model.args', 'model.model.args', 'model.model.model.args':
|
||||
for x in "model.args", "model.model.args", "model.model.model.args":
|
||||
with contextlib.suppress(Exception):
|
||||
return eval(x)['task']
|
||||
for x in 'model.yaml', 'model.model.yaml', 'model.model.model.yaml':
|
||||
return eval(x)["task"]
|
||||
for x in "model.yaml", "model.model.yaml", "model.model.model.yaml":
|
||||
with contextlib.suppress(Exception):
|
||||
return cfg2task(eval(x))
|
||||
|
||||
for m in model.modules():
|
||||
if isinstance(m, Detect):
|
||||
return 'detect'
|
||||
return "detect"
|
||||
elif isinstance(m, Segment):
|
||||
return 'segment'
|
||||
return "segment"
|
||||
elif isinstance(m, Classify):
|
||||
return 'classify'
|
||||
return "classify"
|
||||
elif isinstance(m, Pose):
|
||||
return 'pose'
|
||||
return "pose"
|
||||
elif isinstance(m, OBB):
|
||||
return 'obb'
|
||||
return "obb"
|
||||
|
||||
# Guess from model filename
|
||||
if isinstance(model, (str, Path)):
|
||||
model = Path(model)
|
||||
if '-seg' in model.stem or 'segment' in model.parts:
|
||||
return 'segment'
|
||||
elif '-cls' in model.stem or 'classify' in model.parts:
|
||||
return 'classify'
|
||||
elif '-pose' in model.stem or 'pose' in model.parts:
|
||||
return 'pose'
|
||||
elif '-obb' in model.stem or 'obb' in model.parts:
|
||||
return 'obb'
|
||||
elif 'detect' in model.parts:
|
||||
return 'detect'
|
||||
if "-seg" in model.stem or "segment" in model.parts:
|
||||
return "segment"
|
||||
elif "-cls" in model.stem or "classify" in model.parts:
|
||||
return "classify"
|
||||
elif "-pose" in model.stem or "pose" in model.parts:
|
||||
return "pose"
|
||||
elif "-obb" in model.stem or "obb" in model.parts:
|
||||
return "obb"
|
||||
elif "detect" in model.parts:
|
||||
return "detect"
|
||||
|
||||
# Unable to determine task from model
|
||||
LOGGER.warning("WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
|
||||
"Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'.")
|
||||
return 'detect' # assume detect
|
||||
LOGGER.warning(
|
||||
"WARNING ⚠️ Unable to automatically guess model task, assuming 'task=detect'. "
|
||||
"Explicitly define task for your model, i.e. 'task=detect', 'segment', 'classify','pose' or 'obb'."
|
||||
)
|
||||
return "detect" # assume detect
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue