Attempt to fix NAS models inference (#14630)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
This commit is contained in:
parent
fb20867262
commit
03225fce9a
1 changed files with 12 additions and 3 deletions
|
|
@ -17,7 +17,7 @@ import torch
|
|||
|
||||
from ultralytics.engine.model import Model
|
||||
from ultralytics.utils.downloads import attempt_download_asset
|
||||
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
|
||||
from ultralytics.utils.torch_utils import model_info
|
||||
|
||||
from .predict import NASPredictor
|
||||
from .val import NASValidator
|
||||
|
|
@ -50,16 +50,25 @@ class NAS(Model):
|
|||
assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models."
|
||||
super().__init__(model, task="detect")
|
||||
|
||||
@smart_inference_mode()
|
||||
def _load(self, weights: str, task: str):
|
||||
def _load(self, weights: str, task=None) -> None:
|
||||
"""Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided."""
|
||||
import super_gradients
|
||||
|
||||
suffix = Path(weights).suffix
|
||||
if suffix == ".pt":
|
||||
self.model = torch.load(attempt_download_asset(weights))
|
||||
|
||||
elif suffix == "":
|
||||
self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")
|
||||
|
||||
# Override the forward method to ignore additional arguments
|
||||
def new_forward(x, *args, **kwargs):
|
||||
"""Ignore additional __call__ arguments."""
|
||||
return self.model._original_forward(x)
|
||||
|
||||
self.model._original_forward = self.model.forward
|
||||
self.model.forward = new_forward
|
||||
|
||||
# Standardize model
|
||||
self.model.fuse = lambda verbose=True: self.model
|
||||
self.model.stride = torch.tensor([32])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue