ultralytics 8.1.33 fix HUB model checks (#9153)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: Laughing <61612323+Laughing-q@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Kalen Michael 2024-03-24 06:14:15 +01:00 committed by GitHub
parent fc6c66a4a4
commit ec1d110689
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 20 additions and 18 deletions

View file

@ -119,30 +119,27 @@ class Model(nn.Module):
self.metrics = None # validation/training metrics
self.session = None # HUB session
self.task = task # task type
self.model_name = model = str(model).strip() # strip spaces
model = str(model).strip()
# Check if Ultralytics HUB model from https://hub.ultralytics.com
if self.is_hub_model(model):
# Fetch model from HUB
checks.check_requirements("hub-sdk>=0.0.5")
checks.check_requirements("hub-sdk>=0.0.6")
self.session = self._get_hub_session(model)
model = self.session.model_file
# Check if Triton Server model
elif self.is_triton_model(model):
self.model = model
self.model_name = self.model = model
self.task = task
return
# Load or create new YOLO model
model = checks.check_model_file_from_stem(model) # add suffix, i.e. yolov8n -> yolov8n.pt
if Path(model).suffix in (".yaml", ".yml"):
self._new(model, task=task, verbose=verbose)
else:
self._load(model, task=task)
self.model_name = model
def __call__(
self,
source: Union[str, Path, int, list, tuple, np.ndarray, torch.Tensor] = None,
@ -190,8 +187,8 @@ class Model(nn.Module):
return any(
(
model.startswith(f"{HUB_WEB_ROOT}/models/"), # i.e. https://hub.ultralytics.com/models/MODEL_ID
[len(x) for x in model.split("_")] == [42, 20], # APIKEY_MODELID
len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"), # MODELID
[len(x) for x in model.split("_")] == [42, 20], # APIKEY_MODEL
len(model) == 20 and not Path(model).exists() and all(x not in model for x in "./\\"), # MODEL
)
)
@ -215,6 +212,7 @@ class Model(nn.Module):
# Below added to allow export from YAMLs
self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # combine default and model args (prefer model args)
self.model.task = self.task
self.model_name = cfg
def _load(self, weights: str, task=None) -> None:
"""
@ -224,19 +222,23 @@ class Model(nn.Module):
weights (str): model checkpoint to be loaded
task (str | None): model task
"""
suffix = Path(weights).suffix
if suffix == ".pt":
if weights.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):
weights = checks.check_file(weights) # automatically download and return local filename
weights = checks.check_model_file_from_stem(weights) # add suffix, i.e. yolov8n -> yolov8n.pt
if Path(weights).suffix == ".pt":
self.model, self.ckpt = attempt_load_one_weight(weights)
self.task = self.model.args["task"]
self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
self.ckpt_path = self.model.pt_path
else:
weights = checks.check_file(weights)
weights = checks.check_file(weights) # runs in all cases, not redundant with above call
self.model, self.ckpt = weights, None
self.task = task or guess_model_task(weights)
self.ckpt_path = weights
self.overrides["model"] = weights
self.overrides["task"] = self.task
self.model_name = weights
def _check_is_pytorch_model(self) -> None:
"""Raises TypeError is model is not a PyTorch model."""