ultralytics 8.1.4 RTDETR TensorBoard graph visualization fix (#7725)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-01-21 22:28:24 +01:00 committed by GitHub
parent 6535bcde2b
commit 7a0c27c7d7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 65 additions and 26 deletions

View file

@ -376,7 +376,7 @@ class RTDETRDecoder(nn.Module):
def _get_decoder_input(self, feats, shapes, dn_embed=None, dn_bbox=None):
"""Generates and prepares the input required for the decoder from the provided features and shapes."""
bs = len(feats)
bs = feats.shape[0]
# Prepare input for decoder
anchors, valid_mask = self._generate_anchors(shapes, dtype=feats.dtype, device=feats.device)
features = self.enc_output(valid_mask * feats) # bs, h*w, 256