Model coverage cleanup (#4585)
This commit is contained in:
parent
c635418a27
commit
deac7575b1
12 changed files with 132 additions and 175 deletions
|
|
@ -322,31 +322,10 @@ class PoseModel(DetectionModel):
|
|||
class ClassificationModel(BaseModel):
|
||||
"""YOLOv8 classification model."""
|
||||
|
||||
def __init__(self,
|
||||
cfg='yolov8n-cls.yaml',
|
||||
model=None,
|
||||
ch=3,
|
||||
nc=None,
|
||||
cutoff=10,
|
||||
verbose=True): # YAML, model, channels, number of classes, cutoff index, verbose flag
|
||||
def __init__(self, cfg='yolov8n-cls.yaml', ch=3, nc=None, verbose=True):
|
||||
"""Init ClassificationModel with YAML, channels, number of classes, verbose flag."""
|
||||
super().__init__()
|
||||
self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg, ch, nc, verbose)
|
||||
|
||||
def _from_detection_model(self, model, nc=1000, cutoff=10):
|
||||
"""Create a YOLOv5 classification model from a YOLOv5 detection model."""
|
||||
from ultralytics.nn.autobackend import AutoBackend
|
||||
if isinstance(model, AutoBackend):
|
||||
model = model.model # unwrap DetectMultiBackend
|
||||
model.model = model.model[:cutoff] # backbone
|
||||
m = model.model[-1] # last layer
|
||||
ch = m.conv.in_channels if hasattr(m, 'conv') else m.cv1.conv.in_channels # ch into module
|
||||
c = Classify(ch, nc) # Classify()
|
||||
c.i, c.f, c.type = m.i, m.f, 'models.common.Classify' # index, from, type
|
||||
model.model[-1] = c # replace
|
||||
self.model = model.model
|
||||
self.stride = model.stride
|
||||
self.save = []
|
||||
self.nc = nc
|
||||
self._from_yaml(cfg, ch, nc, verbose)
|
||||
|
||||
def _from_yaml(self, cfg, ch, nc, verbose):
|
||||
"""Set YOLOv8 model configurations and define the model architecture."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue