Model coverage cleanup (#4585)

This commit is contained in:
Glenn Jocher 2023-08-27 04:19:41 +02:00 committed by GitHub
parent c635418a27
commit deac7575b1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 132 additions and 175 deletions

View file

@ -322,31 +322,10 @@ class PoseModel(DetectionModel):
class ClassificationModel(BaseModel):
"""YOLOv8 classification model."""
def __init__(self,
cfg='yolov8n-cls.yaml',
model=None,
ch=3,
nc=None,
cutoff=10,
verbose=True): # YAML, model, channels, number of classes, cutoff index, verbose flag
def __init__(self, cfg='yolov8n-cls.yaml', ch=3, nc=None, verbose=True):
"""Init ClassificationModel with YAML, channels, number of classes, verbose flag."""
super().__init__()
self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg, ch, nc, verbose)
def _from_detection_model(self, model, nc=1000, cutoff=10):
"""Create a YOLOv5 classification model from a YOLOv5 detection model."""
from ultralytics.nn.autobackend import AutoBackend
if isinstance(model, AutoBackend):
model = model.model # unwrap DetectMultiBackend
model.model = model.model[:cutoff] # backbone
m = model.model[-1] # last layer
ch = m.conv.in_channels if hasattr(m, 'conv') else m.cv1.conv.in_channels # ch into module
c = Classify(ch, nc) # Classify()
c.i, c.f, c.type = m.i, m.f, 'models.common.Classify' # index, from, type
model.model[-1] = c # replace
self.model = model.model
self.stride = model.stride
self.save = []
self.nc = nc
self._from_yaml(cfg, ch, nc, verbose)
def _from_yaml(self, cfg, ch, nc, verbose):
"""Set YOLOv8 model configurations and define the model architecture."""