From 03225fce9ad885e90e4b677fb6839c77efa546fc Mon Sep 17 00:00:00 2001 From: Laughing <61612323+Laughing-q@users.noreply.github.com> Date: Wed, 24 Jul 2024 03:44:02 +0800 Subject: [PATCH] Attempt to fix NAS models inference (#14630) Signed-off-by: Glenn Jocher Co-authored-by: Glenn Jocher Co-authored-by: UltralyticsAssistant Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com> --- ultralytics/models/nas/model.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/ultralytics/models/nas/model.py b/ultralytics/models/nas/model.py index fd444f13..90446c58 100644 --- a/ultralytics/models/nas/model.py +++ b/ultralytics/models/nas/model.py @@ -17,7 +17,7 @@ import torch from ultralytics.engine.model import Model from ultralytics.utils.downloads import attempt_download_asset -from ultralytics.utils.torch_utils import model_info, smart_inference_mode +from ultralytics.utils.torch_utils import model_info from .predict import NASPredictor from .val import NASValidator @@ -50,16 +50,25 @@ class NAS(Model): assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models." super().__init__(model, task="detect") - @smart_inference_mode() - def _load(self, weights: str, task: str): + def _load(self, weights: str, task=None) -> None: """Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided.""" import super_gradients suffix = Path(weights).suffix if suffix == ".pt": self.model = torch.load(attempt_download_asset(weights)) + elif suffix == "": self.model = super_gradients.training.models.get(weights, pretrained_weights="coco") + + # Override the forward method to ignore additional arguments + def new_forward(x, *args, **kwargs): + """Ignore additional __call__ arguments.""" + return self.model._original_forward(x) + + self.model._original_forward = self.model.forward + self.model.forward = new_forward + # Standardize model self.model.fuse = lambda verbose=True: self.model self.model.stride = torch.tensor([32])