ultralytics 8.1.33 fix HUB model checks (#9153)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
fc6c66a4a4
commit
ec1d110689
7 changed files with 20 additions and 18 deletions
|
|
@ -119,30 +119,27 @@ class Model(nn.Module):
|
|||
self.metrics = None # validation/training metrics
|
||||
self.session = None # HUB session
|
||||
self.task = task # task type
|
||||
self.model_name = model = str(model).strip() # strip spaces
|
||||
model = str(model).strip()
|
||||
|
||||
# Check if Ultralytics HUB model from https://hub.ultralytics.com
|
||||
if self.is_hub_model(model):
|
||||
# Fetch model from HUB
|
||||
checks.check_requirements("hub-sdk>=0.0.5")
|
||||
checks.check_requirements("hub-sdk>=0.0.6")
|
||||
self.session = self._get_hub_session(model)
|
||||
model = self.session.model_file
|
||||
|
||||
# Check if Triton Server model
|
||||
elif self.is_triton_model(model):
|
||||
self.model = model
|
||||
self.model_name = self.model = model
|
||||
self.task = task
|
||||
return
|
||||
|
||||
# Load or create new YOLO model
|
||||
model = checks.check_model_file_from_stem(model) # add suffix, i.e. yolov8n -> yolov8n.pt
|
||||
if Path(model).suffix in (".yaml", ".yml"):
|
||||
self._new(model, task=task, verbose=verbose)
|
||||
else:
|
||||
self._load(model, task=task)
|
||||
|
||||
self.model_name = model
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
|
||||
|
|
@ -190,8 +187,8 @@ class Model(nn.Module):
|
|||
return any(
|
||||
(
|
||||
model.startswith(f"{HUB_WEB_ROOT}/models/"), # i.e. https://hub.ultralytics.com/models/MODEL_ID
|
||||
[len(x) for x in model.split("_")] == [42, 20], # APIKEY_MODELID
|
||||
len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"), # MODELID
|
||||
[len(x) for x in model.split("_")] == [42, 20], # APIKEY_MODEL
|
||||
len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"), # MODEL
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -215,6 +212,7 @@ class Model(nn.Module):
|
|||
# Below added to allow export from YAMLs
|
||||
self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # combine default and model args (prefer model args)
|
||||
self.model.task = self.task
|
||||
self.model_name = cfg
|
||||
|
||||
def _load(self, weights: str, task=None) -> None:
|
||||
"""
|
||||
|
|
@ -224,19 +222,23 @@ class Model(nn.Module):
|
|||
weights (str): model checkpoint to be loaded
|
||||
task (str | None): model task
|
||||
"""
|
||||
suffix = Path(weights).suffix
|
||||
if suffix == ".pt":
|
||||
if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):
|
||||
weights = checks.check_file(weights) # automatically download and return local filename
|
||||
weights = checks.check_model_file_from_stem(weights) # add suffix, i.e. yolov8n -> yolov8n.pt
|
||||
|
||||
if Path(weights).suffix == ".pt":
|
||||
self.model, self.ckpt = attempt_load_one_weight(weights)
|
||||
self.task = self.model.args["task"]
|
||||
self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
|
||||
self.ckpt_path = self.model.pt_path
|
||||
else:
|
||||
weights = checks.check_file(weights)
|
||||
weights = checks.check_file(weights) # runs in all cases, not redundant with above call
|
||||
self.model, self.ckpt = weights, None
|
||||
self.task = task or guess_model_task(weights)
|
||||
self.ckpt_path = weights
|
||||
self.overrides["model"] = weights
|
||||
self.overrides["task"] = self.task
|
||||
self.model_name = weights
|
||||
|
||||
def _check_is_pytorch_model(self) -> None:
|
||||
"""Raises TypeError is model is not a PyTorch model."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue