ultralytics 8.3.71 require explicit torch.nn usage (#19067)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: RizwanMunawar <chr043416@gmail.com> Co-authored-by: Muhammad Rizwan Munawar <muhammadrizwanmunawar123@gmail.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
17450e9646
commit
5bca9341e8
10 changed files with 50 additions and 51 deletions
|
|
@ -11,7 +11,7 @@ from PIL import Image
|
|||
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
|
||||
from ultralytics.engine.results import Results
|
||||
from ultralytics.hub import HUB_WEB_ROOT, HUBTrainingSession
|
||||
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
|
||||
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, yaml_model_load
|
||||
from ultralytics.utils import (
|
||||
ARGV,
|
||||
ASSETS,
|
||||
|
|
@ -26,7 +26,7 @@ from ultralytics.utils import (
|
|||
)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
class Model(torch.nn.Module):
|
||||
"""
|
||||
A base class for implementing YOLO models, unifying APIs across different model types.
|
||||
|
||||
|
|
@ -37,7 +37,7 @@ class Model(nn.Module):
|
|||
Attributes:
|
||||
callbacks (Dict): A dictionary of callback functions for various events during model operations.
|
||||
predictor (BasePredictor): The predictor object used for making predictions.
|
||||
model (nn.Module): The underlying PyTorch model.
|
||||
model (torch.nn.Module): The underlying PyTorch model.
|
||||
trainer (BaseTrainer): The trainer object used for training the model.
|
||||
ckpt (Dict): The checkpoint data if the model is loaded from a *.pt file.
|
||||
cfg (str): The configuration of the model if loaded from a *.yaml file.
|
||||
|
|
@ -317,7 +317,7 @@ class Model(nn.Module):
|
|||
>>> model._check_is_pytorch_model() # Raises TypeError
|
||||
"""
|
||||
pt_str = isinstance(self.model, (str, Path)) and Path(self.model).suffix == ".pt"
|
||||
pt_module = isinstance(self.model, nn.Module)
|
||||
pt_module = isinstance(self.model, torch.nn.Module)
|
||||
if not (pt_module or pt_str):
|
||||
raise TypeError(
|
||||
f"model='{self.model}' should be a *.pt PyTorch model to run this method, but is a different format. "
|
||||
|
|
@ -405,7 +405,7 @@ class Model(nn.Module):
|
|||
from ultralytics import __version__
|
||||
|
||||
updates = {
|
||||
"model": deepcopy(self.model).half() if isinstance(self.model, nn.Module) else self.model,
|
||||
"model": deepcopy(self.model).half() if isinstance(self.model, torch.nn.Module) else self.model,
|
||||
"date": datetime.now().isoformat(),
|
||||
"version": __version__,
|
||||
"license": "AGPL-3.0 License (https://ultralytics.com/license)",
|
||||
|
|
@ -452,7 +452,7 @@ class Model(nn.Module):
|
|||
performs both convolution and normalization in one step.
|
||||
|
||||
Raises:
|
||||
TypeError: If the model is not a PyTorch nn.Module.
|
||||
TypeError: If the model is not a PyTorch torch.nn.Module.
|
||||
|
||||
Examples:
|
||||
>>> model = Model("yolo11n.pt")
|
||||
|
|
@ -921,13 +921,13 @@ class Model(nn.Module):
|
|||
Retrieves the device on which the model's parameters are allocated.
|
||||
|
||||
This property determines the device (CPU or GPU) where the model's parameters are currently stored. It is
|
||||
applicable only to models that are instances of nn.Module.
|
||||
applicable only to models that are instances of torch.nn.Module.
|
||||
|
||||
Returns:
|
||||
(torch.device): The device (CPU/GPU) of the model.
|
||||
|
||||
Raises:
|
||||
AttributeError: If the model is not a PyTorch nn.Module instance.
|
||||
AttributeError: If the model is not a torch.nn.Module instance.
|
||||
|
||||
Examples:
|
||||
>>> model = YOLO("yolo11n.pt")
|
||||
|
|
@ -937,7 +937,7 @@ class Model(nn.Module):
|
|||
>>> print(model.device)
|
||||
device(type='cpu')
|
||||
"""
|
||||
return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None
|
||||
return next(self.model.parameters()).device if isinstance(self.model, torch.nn.Module) else None
|
||||
|
||||
@property
|
||||
def transforms(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue