Integration of v8 segmentation (#107)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
384f0ef1c6
commit
8406b49b49
16 changed files with 422 additions and 224 deletions
|
|
@ -101,7 +101,7 @@ class DetectionModel(BaseModel):
|
|||
if isinstance(m, (Detect, Segment)):
|
||||
s = 256 # 2x min stride
|
||||
m.inplace = self.inplace
|
||||
forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Detect)) else self.forward(x)
|
||||
forward = lambda x: self.forward(x)[0] if isinstance(m, Segment) else self.forward(x)
|
||||
m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward
|
||||
self.stride = m.stride
|
||||
m.bias_init() # only run once
|
||||
|
|
@ -163,8 +163,8 @@ class DetectionModel(BaseModel):
|
|||
|
||||
class SegmentationModel(DetectionModel):
|
||||
# YOLOv5 segmentation model
|
||||
def __init__(self, cfg='yolov5s-seg.yaml', ch=3, nc=None):
|
||||
super().__init__(cfg, ch, nc)
|
||||
def __init__(self, cfg='yolov5s-seg.yaml', ch=3, nc=None, verbose=True):
|
||||
super().__init__(cfg, ch, nc, verbose)
|
||||
|
||||
|
||||
class ClassificationModel(BaseModel):
|
||||
|
|
@ -300,7 +300,7 @@ def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)
|
|||
elif m in {Detect, Segment}:
|
||||
args.append([ch[x] for x in f])
|
||||
if m is Segment:
|
||||
args[3] = make_divisible(args[3] * gw, 8)
|
||||
args[2] = make_divisible(args[2] * gw, 8)
|
||||
else:
|
||||
c2 = ch[f]
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue