ultralytics 8.1.39 add YOLO-World training (#9268)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2024-03-31 22:30:17 +08:00 committed by GitHub
parent 18036908d4
commit e9187c1296
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
34 changed files with 2166 additions and 100 deletions

View file

@ -519,7 +519,8 @@ class ContrastiveHead(nn.Module):
def __init__(self):
"""
Initializes ContrastiveHead with specified region-text similarity parameters.

Registers two learnable parameters on the module: `bias` and `logit_scale`.
"""
super().__init__()
self.bias = nn.Parameter(torch.zeros([]))  # NOTE(review): overwritten by the -10.0 assignment right below
# NOTE: use -10.0 to keep the initial cls loss consistent with the other losses
self.bias = nn.Parameter(torch.tensor([-10.0]))
# Scalar parameter initialised to log(1 / 0.07)
self.logit_scale = nn.Parameter(torch.ones([]) * torch.tensor(1 / 0.07).log())
def forward(self, x, w):
@ -542,7 +543,8 @@ class BNContrastiveHead(nn.Module):
"""Initialize ContrastiveHead with region-text similarity parameters."""
super().__init__()
self.norm = nn.BatchNorm2d(embed_dims)
self.bias = nn.Parameter(torch.zeros([]))
# NOTE: use -10.0 to keep the initial cls loss consistent with the other losses
self.bias = nn.Parameter(torch.tensor([-10.0]))
# using -1.0 is more stable
self.logit_scale = nn.Parameter(-1.0 * torch.ones([]))

View file

@ -250,6 +250,15 @@ class WorldDetect(Detect):
y = torch.cat((dbox, cls.sigmoid()), 1)
return y if self.export else (y, x)
def bias_init(self):
    """Set the final box-branch bias of every detection level to 1.0; `stride` must already be populated."""
    head = self  # self.model[-1]  # Detect() module
    # Class-frequency based initialisation, kept for reference only:
    # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
    # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
    # All three sequences are zipped together, so the number of levels visited is bounded by the shortest one.
    levels = zip(head.cv2, head.cv3, head.stride)
    for box_branch, _cls_branch, _level_stride in levels:
        last_box_layer = box_branch[-1]
        last_box_layer.bias.data[:] = 1.0  # box
        # The cls-branch bias is deliberately left as constructed; the usual initialisation is disabled here:
        # b[-1].bias.data[:] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)
class RTDETRDecoder(nn.Module):
"""