Update MLP module for RTDETR backward compatibility (#14901)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2024-08-04 02:03:33 +08:00 committed by GitHub
parent 121f2242e1
commit 08263f5737
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 5 additions and 2 deletions

View file

@ -186,8 +186,8 @@ class MLP(nn.Module):
def forward(self, x):
"""Forward pass for the entire MLP."""
for i, layer in enumerate(self.layers):
x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
return x.sigmoid() if self.sigmoid else x
x = getattr(self, "act", nn.ReLU())(layer(x)) if i < self.num_layers - 1 else layer(x)
return x.sigmoid() if getattr(self, "sigmoid", False) else x
class LayerNorm2d(nn.Module):