__getattr__ support to access YOLO attributes via Model class (#17805)

Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Void 2024-11-29 05:51:33 +08:00 committed by GitHub
parent e668de6e02
commit 29826241a0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -144,6 +144,9 @@ class Model(nn.Module):
else:
self._load(model, task=task)
# Delete super().training for accessing self.model.training
del self.training
def __call__(
self,
source: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor] = None,
@ -1143,3 +1146,29 @@ class Model(nn.Module):
"""
self.model.eval()
return self
def __getattr__(self, name):
"""
Enables accessing model attributes directly through the Model class.
This method provides a way to access attributes of the underlying model directly through the Model class
instance. It first checks if the requested attribute is 'model', in which case it returns the model from
the module dictionary. Otherwise, it delegates the attribute lookup to the underlying model.
Args:
name (str): The name of the attribute to retrieve.
Returns:
(Any): The requested attribute value.
Raises:
AttributeError: If the requested attribute does not exist in the model.
Examples:
>>> model = YOLO("yolo11n.pt")
>>> print(model.stride)
>>> print(model.task)
"""
if name == "model":
return self._modules["model"]
return getattr(self.model, name)