ultralytics 8.0.80 single-line docstring fixes (#2060)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
31db8ed163
commit
5bce1c3021
48 changed files with 418 additions and 420 deletions
|
|
@ -167,7 +167,8 @@ class BaseModel(nn.Module):
|
|||
|
||||
|
||||
class DetectionModel(BaseModel):
|
||||
# YOLOv8 detection model
|
||||
"""YOLOv8 detection model."""
|
||||
|
||||
def __init__(self, cfg='yolov8n.yaml', ch=3, nc=None, verbose=True): # model, input channels, number of classes
|
||||
super().__init__()
|
||||
self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
|
||||
|
|
@ -218,7 +219,7 @@ class DetectionModel(BaseModel):
|
|||
|
||||
@staticmethod
|
||||
def _descale_pred(p, flips, scale, img_size, dim=1):
|
||||
# de-scale predictions following augmented inference (inverse operation)
|
||||
"""De-scale predictions following augmented inference (inverse operation)."""
|
||||
p[:, :4] /= scale # de-scale
|
||||
x, y, wh, cls = p.split((1, 1, 2, p.shape[dim] - 4), dim)
|
||||
if flips == 2:
|
||||
|
|
@ -228,7 +229,7 @@ class DetectionModel(BaseModel):
|
|||
return torch.cat((x, y, wh, cls), dim)
|
||||
|
||||
def _clip_augmented(self, y):
|
||||
# Clip YOLOv5 augmented inference tails
|
||||
"""Clip YOLOv5 augmented inference tails."""
|
||||
nl = self.model[-1].nl # number of detection layers (P3-P5)
|
||||
g = sum(4 ** x for x in range(nl)) # grid points
|
||||
e = 1 # exclude layer count
|
||||
|
|
@ -240,7 +241,8 @@ class DetectionModel(BaseModel):
|
|||
|
||||
|
||||
class SegmentationModel(DetectionModel):
|
||||
# YOLOv8 segmentation model
|
||||
"""YOLOv8 segmentation model."""
|
||||
|
||||
def __init__(self, cfg='yolov8n-seg.yaml', ch=3, nc=None, verbose=True):
|
||||
super().__init__(cfg=cfg, ch=ch, nc=nc, verbose=verbose)
|
||||
|
||||
|
|
@ -249,7 +251,8 @@ class SegmentationModel(DetectionModel):
|
|||
|
||||
|
||||
class PoseModel(DetectionModel):
|
||||
# YOLOv8 pose model
|
||||
"""YOLOv8 pose model."""
|
||||
|
||||
def __init__(self, cfg='yolov8n-pose.yaml', ch=3, nc=None, data_kpt_shape=(None, None), verbose=True):
|
||||
if not isinstance(cfg, dict):
|
||||
cfg = yaml_model_load(cfg) # load model YAML
|
||||
|
|
@ -260,7 +263,8 @@ class PoseModel(DetectionModel):
|
|||
|
||||
|
||||
class ClassificationModel(BaseModel):
|
||||
# YOLOv8 classification model
|
||||
"""YOLOv8 classification model."""
|
||||
|
||||
def __init__(self,
|
||||
cfg=None,
|
||||
model=None,
|
||||
|
|
@ -272,7 +276,7 @@ class ClassificationModel(BaseModel):
|
|||
self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg, ch, nc, verbose)
|
||||
|
||||
def _from_detection_model(self, model, nc=1000, cutoff=10):
|
||||
# Create a YOLOv5 classification model from a YOLOv5 detection model
|
||||
"""Create a YOLOv5 classification model from a YOLOv5 detection model."""
|
||||
from ultralytics.nn.autobackend import AutoBackend
|
||||
if isinstance(model, AutoBackend):
|
||||
model = model.model # unwrap DetectMultiBackend
|
||||
|
|
@ -304,7 +308,7 @@ class ClassificationModel(BaseModel):
|
|||
|
||||
@staticmethod
|
||||
def reshape_outputs(model, nc):
|
||||
# Update a TorchVision classification model to class count 'n' if required
|
||||
"""Update a TorchVision classification model to class count 'n' if required."""
|
||||
name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1] # last module
|
||||
if isinstance(m, Classify): # YOLO Classify() head
|
||||
if m.linear.out_features != nc:
|
||||
|
|
@ -363,7 +367,7 @@ def torch_safe_load(weight):
|
|||
|
||||
|
||||
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
||||
# Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a
|
||||
"""Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a."""
|
||||
|
||||
ensemble = Ensemble()
|
||||
for w in weights if isinstance(weights, list) else [weights]:
|
||||
|
|
@ -403,7 +407,7 @@ def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
|||
|
||||
|
||||
def attempt_load_one_weight(weight, device=None, inplace=True, fuse=False):
|
||||
# Loads a single model weights
|
||||
"""Loads a single model weights."""
|
||||
ckpt, weight = torch_safe_load(weight) # load ckpt
|
||||
args = {**DEFAULT_CFG_DICT, **ckpt['train_args']} # combine model and default args, preferring model args
|
||||
model = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
|
||||
|
|
@ -546,7 +550,7 @@ def guess_model_task(model):
|
|||
"""
|
||||
|
||||
def cfg2task(cfg):
|
||||
# Guess from YAML dictionary
|
||||
"""Guess from YAML dictionary."""
|
||||
m = cfg['head'][-1][-2].lower() # output module name
|
||||
if m in ('classify', 'classifier', 'cls', 'fc'):
|
||||
return 'classify'
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue