ultralytics 8.3.71 require explicit torch.nn usage (#19067)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: RizwanMunawar <chr043416@gmail.com>
Co-authored-by: Muhammad Rizwan Munawar <muhammadrizwanmunawar123@gmail.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2025-02-05 01:08:17 +09:00 committed by GitHub
parent 17450e9646
commit 5bca9341e8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 50 additions and 51 deletions

View file

@ -11,7 +11,7 @@ from PIL import Image
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
from ultralytics.engine.results import Results
from ultralytics.hub import HUB_WEB_ROOT, HUBTrainingSession
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, yaml_model_load
from ultralytics.utils import (
ARGV,
ASSETS,
@ -26,7 +26,7 @@ from ultralytics.utils import (
)
class Model(nn.Module):
class Model(torch.nn.Module):
"""
A base class for implementing YOLO models, unifying APIs across different model types.
@ -37,7 +37,7 @@ class Model(nn.Module):
Attributes:
callbacks (Dict): A dictionary of callback functions for various events during model operations.
predictor (BasePredictor): The predictor object used for making predictions.
model (nn.Module): The underlying PyTorch model.
model (torch.nn.Module): The underlying PyTorch model.
trainer (BaseTrainer): The trainer object used for training the model.
ckpt (Dict): The checkpoint data if the model is loaded from a *.pt file.
cfg (str): The configuration of the model if loaded from a *.yaml file.
@ -317,7 +317,7 @@ class Model(nn.Module):
>>> model._check_is_pytorch_model() # Raises TypeError
"""
pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt"
pt_module = isinstance(self.model, nn.Module)
pt_module = isinstance(self.model, torch.nn.Module)
if not (pt_module or pt_str):
raise TypeError(
f"model='{self.model}' should be a *.pt PyTorch model to run this method, but is a different format. "
@ -405,7 +405,7 @@ class Model(nn.Module):
from ultralytics import __version__
updates = {
"model": deepcopy(self.model).half() if isinstance(self.model, nn.Module) else self.model,
"model": deepcopy(self.model).half() if isinstance(self.model, torch.nn.Module) else self.model,
"date": datetime.now().isoformat(),
"version": __version__,
"license": "AGPL-3.0 License (https://ultralytics.com/license)",
@ -452,7 +452,7 @@ class Model(nn.Module):
performs both convolution and normalization in one step.
Raises:
TypeError: If the model is not a PyTorch nn.Module.
TypeError: If the model is not a PyTorch torch.nn.Module.
Examples:
>>> model = Model("yolo11n.pt")
@ -921,13 +921,13 @@ class Model(nn.Module):
Retrieves the device on which the model's parameters are allocated.
This property determines the device (CPU or GPU) where the model's parameters are currently stored. It is
applicable only to models that are instances of nn.Module.
applicable only to models that are instances of torch.nn.Module.
Returns:
(torch.device): The device (CPU/GPU) of the model.
Raises:
AttributeError: If the model is not a PyTorch nn.Module instance.
AttributeError: If the model is not a torch.nn.Module instance.
Examples:
>>> model = YOLO("yolo11n.pt")
@ -937,7 +937,7 @@ class Model(nn.Module):
>>> print(model.device)
device(type='cpu')
"""
return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None
return next(self.model.parameters()).device if isinstance(self.model, torch.nn.Module) else None
@property
def transforms(self):