Add Classification model YAML support (#154)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Ayush Chaurasia 2023-01-08 00:34:34 +05:30 committed by GitHub
parent 0e5a7ae623
commit 07eab49c3d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 199 additions and 71 deletions

View file

@ -662,12 +662,10 @@ class Segment(Detect):
self.cv4 = nn.ModuleList(nn.Sequential(Conv(x, c4, 3), Conv(c4, c4, 3), nn.Conv2d(c4, self.nm, 1)) for x in ch)
def forward(self, x):
p = self.proto(x[0])
p = self.proto(x[0]) # mask protos
bs = p.shape[0] # batch size
mc = [] # mask coefficient
for i in range(self.nl):
mc.append(self.cv4[i](x[i]))
mc = torch.cat([mi.view(p.shape[0], self.nm, -1) for mi in mc], 2)
mc = torch.cat([self.cv4[i](x[i]).view(bs, self.nm, -1) for i in range(self.nl)], 2) # mask coefficients
x = self.detect(self, x)
if self.training:
return x, mc, p