From 08263f5737a8ea0f65a1e261235cf4a2ffcf586d Mon Sep 17 00:00:00 2001 From: Laughing <61612323+Laughing-q@users.noreply.github.com> Date: Sun, 4 Aug 2024 02:03:33 +0800 Subject: [PATCH] Update MLP module for RTDETR backward compatibility (#14901) Co-authored-by: Glenn Jocher --- tests/test_cli.py | 3 +++ ultralytics/nn/modules/transformer.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/test_cli.py b/tests/test_cli.py index 2e53b222..2f3e39ee 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -8,6 +8,7 @@ from PIL import Image from tests import CUDA_DEVICE_COUNT, CUDA_IS_AVAILABLE from ultralytics.cfg import TASK2DATA, TASK2MODEL, TASKS from ultralytics.utils import ASSETS, WEIGHTS_DIR, checks +from ultralytics.utils.torch_utils import TORCH_1_9 # Constants TASK_MODEL_DATA = [(task, WEIGHTS_DIR / TASK2MODEL[task], TASK2DATA[task]) for task in TASKS] @@ -57,6 +58,8 @@ def test_rtdetr(task="detect", model="yolov8n-rtdetr.yaml", data="coco8.yaml"): # Warning: must use imgsz=640 (note also add coma, spaces, fraction=0.25 args to test single-image training) run(f"yolo train {task} model={model} data={data} --imgsz= 160 epochs =1, cache = disk fraction=0.25") run(f"yolo predict {task} model={model} source={ASSETS / 'bus.jpg'} imgsz=160 save save_crop save_txt") + if TORCH_1_9: + run(f"yolo predict {task} model='rtdetr-l.pt' source={ASSETS / 'bus.jpg'} imgsz=160 save save_crop save_txt") @pytest.mark.skipif(checks.IS_PYTHON_3_12, reason="MobileSAM with CLIP is not supported in Python 3.12") diff --git a/ultralytics/nn/modules/transformer.py b/ultralytics/nn/modules/transformer.py index 4c170aa3..def184d8 100644 --- a/ultralytics/nn/modules/transformer.py +++ b/ultralytics/nn/modules/transformer.py @@ -186,8 +186,8 @@ class MLP(nn.Module): def forward(self, x): """Forward pass for the entire MLP.""" for i, layer in enumerate(self.layers): - x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) - return x.sigmoid() if self.sigmoid else x + x = getattr(self, "act", nn.ReLU())(layer(x)) if i < self.num_layers - 1 else layer(x) + return x.sigmoid() if getattr(self, "sigmoid", False) else x class LayerNorm2d(nn.Module):