Add NAS autodownload (#14627)
This commit is contained in:
parent
f94c82da31
commit
82956dc77b
1 changed files with 2 additions and 1 deletions
|
|
@ -16,6 +16,7 @@ from pathlib import Path
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ultralytics.engine.model import Model
|
from ultralytics.engine.model import Model
|
||||||
|
from ultralytics.utils.downloads import attempt_download_asset
|
||||||
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
|
from ultralytics.utils.torch_utils import model_info, smart_inference_mode
|
||||||
|
|
||||||
from .predict import NASPredictor
|
from .predict import NASPredictor
|
||||||
|
|
@ -56,7 +57,7 @@ class NAS(Model):
|
||||||
|
|
||||||
suffix = Path(weights).suffix
|
suffix = Path(weights).suffix
|
||||||
if suffix == ".pt":
|
if suffix == ".pt":
|
||||||
self.model = torch.load(weights)
|
self.model = torch.load(attempt_download_asset(weights))
|
||||||
elif suffix == "":
|
elif suffix == "":
|
||||||
self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")
|
self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")
|
||||||
# Standardize model
|
# Standardize model
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue