Add Classification model YAML support (#154)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
0e5a7ae623
commit
07eab49c3d
14 changed files with 199 additions and 71 deletions
|
|
@ -662,12 +662,10 @@ class Segment(Detect):
|
|||
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
|
||||
|
||||
def forward(self, x):
|
||||
p = self.proto(x[0])
|
||||
p = self.proto(x[0]) # mask protos
|
||||
bs = p.shape[0] # batch size
|
||||
|
||||
mc = [] # mask coefficient
|
||||
for i in range(self.nl):
|
||||
mc.append(self.cv4[i](x[i]))
|
||||
mc = torch.cat([mi.view(p.shape[0], self.nm, -1) for mi in mc], 2)
|
||||
mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
|
||||
x = self.detect(self, x)
|
||||
if self.training:
|
||||
return x, mc, p
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue