Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Ultralytics Assistant 2024-09-06 03:54:35 +08:00 committed by GitHub
parent 95d54828bb
commit ac2c2be8f3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 45 additions and 62 deletions

View file

@ -736,7 +736,7 @@ class PositionEmbeddingSine(nn.Module):
self.num_pos_feats = num_pos_feats // 2
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
if scale is not None and not normalize:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
@ -763,8 +763,7 @@ class PositionEmbeddingSine(nn.Module):
def encode_boxes(self, x, y, w, h):
"""Encodes box coordinates and dimensions into positional embeddings for detection."""
pos_x, pos_y = self._encode_xy(x, y)
pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
return pos
return torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
encode = encode_boxes # Backwards compatibility
@ -775,8 +774,7 @@ class PositionEmbeddingSine(nn.Module):
assert bx == by and nx == ny and bx == bl and nx == nl
pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
return pos
return torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
@torch.no_grad()
def forward(self, x: torch.Tensor):