Add Classification model YAML support (#154)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
parent 0e5a7ae623
commit 07eab49c3d
14 changed files with 199 additions and 71 deletions
ultralytics/nn/tasks.py
@@ -1,11 +1,9 @@
-import contextlib
 from copy import deepcopy
 from pathlib import Path
 
 import thop
 import torch
 import torch.nn as nn
-import torchvision
 
 from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify,
                                     Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,
@@ -226,9 +224,15 @@ class SegmentationModel(DetectionModel):
 
 class ClassificationModel(BaseModel):
     # YOLOv5 classification model
-    def __init__(self, cfg=None, model=None, nc=1000, cutoff=10):  # yaml, model, number of classes, cutoff index
+    def __init__(self,
+                 cfg=None,
+                 model=None,
+                 ch=3,
+                 nc=1000,
+                 cutoff=10,
+                 verbose=True):  # yaml, model, number of classes, cutoff index
         super().__init__()
-        self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg)
+        self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg, ch, nc, verbose)
 
     def _from_detection_model(self, model, nc=1000, cutoff=10):
         # Create a YOLOv5 classification model from a YOLOv5 detection model
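The revised signature keeps both ways of building a classification model: converting an existing detection model (model=..., cutoff=...) or, with this commit, parsing a model YAML (cfg=...). A minimal sketch of the two entry points; the config filename and the det_model variable below are placeholders, not artifacts shipped with this commit:

# Sketch only: 'my-cls.yaml' and det_model are hypothetical placeholders.
cls_from_yaml = ClassificationModel(cfg='my-cls.yaml', ch=3, nc=100)      # build from a model YAML
cls_from_det = ClassificationModel(model=det_model, nc=100, cutoff=10)    # truncate a detection model at layer cutoff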
@@ -246,9 +250,15 @@ class ClassificationModel(BaseModel):
         self.save = []
         self.nc = nc
 
-    def _from_yaml(self, cfg):
-        # TODO: Create a YOLOv5 classification model from a *.yaml file
-        self.model = None
+    def _from_yaml(self, cfg, ch, nc, verbose):
+        self.yaml = cfg if isinstance(cfg, dict) else yaml_load(check_yaml(cfg), append_filename=True)  # cfg dict
+        # Define model
+        ch = self.yaml['ch'] = self.yaml.get('ch', ch)  # input channels
+        if nc and nc != self.yaml['nc']:
+            LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
+            self.yaml['nc'] = nc  # override yaml value
+        self.model, self.save = parse_model(deepcopy(self.yaml), ch=[ch], verbose=verbose)  # model, savelist
+        self.names = {i: f'{i}' for i in range(self.yaml['nc'])}  # default names dict
 
     def load(self, weights):
         model = weights["model"] if isinstance(weights, dict) else weights  # torchvision models are not dicts
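Because _from_yaml accepts either a dict or a path (cfg if isinstance(cfg, dict) else yaml_load(...)), the new path can be sketched with an inline config. The keys mirror what parse_model reads (nc, depth_multiple, width_multiple, backbone, head); the layer list itself is illustrative, not a config from this commit:

# Illustrative config dict; layer entries are [from, repeats, module, args] as consumed by parse_model.
cfg = {
    'nc': 10,
    'depth_multiple': 0.33,
    'width_multiple': 0.25,
    'backbone': [
        [-1, 1, 'Conv', [64, 3, 2]],
        [-1, 3, 'C2f', [128, True]],
    ],
    'head': [
        [-1, 1, 'Classify', [10]],  # output channels equal nc, so width scaling is skipped for this layer
    ],
}
model = ClassificationModel(cfg=cfg, ch=3, nc=10)  # names default to {0: '0', 1: '1', ...}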
@@ -351,7 +361,7 @@ def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
 
 
 def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
-    # Parse a YOLOv5 model.yaml dictionary
+    # Parse a YOLO model.yaml dictionary
     if verbose:
         LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10} {'module':<45}{'arguments':<30}")
     nc, gd, gw, act = d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
@@ -359,7 +369,6 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
         Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
         if verbose:
             LOGGER.info(f"{colorstr('activation:')} {act}")  # print
-    no = nc + 4  # number of outputs = classes + box
 
     layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
     for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
@@ -370,10 +379,10 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
 
         n = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gain
         if m in {
-                Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus, BottleneckCSP,
-                C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}:
+                Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
+                BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x}:
             c1, c2 = ch[f], args[0]
-            if c2 != no:  # if not output
+            if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)
                 c2 = make_divisible(c2 * gw, 8)
 
             args = [c1, c2, *args[1:]]
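Dropping the detection-only output count (no) and comparing against nc means width scaling now applies to every matched module except one whose declared output channels equal the class count, i.e. the Classify head. A small worked example of the scaling rule, using a local stand-in for make_divisible (assumed here to round up to the nearest multiple of the divisor):

import math

def make_divisible(x, divisor=8):
    # Local stand-in for the ultralytics helper, assumed to round up to a multiple of divisor.
    return math.ceil(x / divisor) * divisor

gw, nc = 0.25, 1000                        # width_multiple and class count (illustrative values)
for c2 in (64, 256, nc):
    c2_out = c2 if c2 == nc else make_divisible(c2 * gw, 8)
    print(c2, '->', c2_out)                # 64 -> 16, 256 -> 64, 1000 -> 1000 (Classify output untouched)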
@@ -384,7 +393,6 @@ def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
             args = [ch[f]]
         elif m is Concat:
             c2 = sum(ch[x] for x in f)
-        # TODO: channel, gw, gd
         elif m in {Detect, Segment}:
             args.append([ch[x] for x in f])
             if m is Segment: