Integration of v8 segmentation (#107)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2022-12-28 23:01:38 +08:00 committed by GitHub
parent 384f0ef1c6
commit 8406b49b49
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 422 additions and 224 deletions

View file

@ -101,7 +101,7 @@ class DetectionModel(BaseModel):
if isinstance(m, (Detect, Segment)):
s = 256 # 2x min stride
m.inplace = self.inplace
forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Detect)) else self.forward(x)
forward = lambda x: self.forward(x)[0] if isinstance(m, Segment) else self.forward(x)
m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
self.stride = m.stride
m.bias_init() # only run once
@ -163,8 +163,8 @@ class DetectionModel(BaseModel):
class SegmentationModel(DetectionModel):
# YOLOv5 segmentation model
def __init__(self, cfg='yolov5s-seg.yaml', ch=3, nc=None):
super().__init__(cfg, ch, nc)
def __init__(self, cfg='yolov5s-seg.yaml', ch=3, nc=None, verbose=True):
super().__init__(cfg, ch, nc, verbose)
class ClassificationModel(BaseModel):
@ -300,7 +300,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
elif m in {Detect, Segment}:
args.append([ch[x] for x in f])
if m is Segment:
args[3] = make_divisible(args[3] * gw, 8)
args[2] = make_divisible(args[2] * gw, 8)
else:
c2 = ch[f]