Ruff Docstring formatting (#15793)
Signed-off-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
d27664216b
commit
776ca86369
60 changed files with 241 additions and 309 deletions
|
|
@ -89,13 +89,17 @@ class BaseModel(nn.Module):
|
|||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
"""
|
||||
Forward pass of the model on a single scale. Wrapper for `_forward_once` method.
|
||||
Perform forward pass of the model for either training or inference.
|
||||
|
||||
If x is a dict, calculates and returns the loss for training. Otherwise, returns predictions for inference.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor | dict): The input image tensor or a dict including image tensor and gt labels.
|
||||
x (torch.Tensor | dict): Input tensor for inference, or dict with image tensor and labels for training.
|
||||
*args (Any): Variable length argument list.
|
||||
**kwargs (Any): Arbitrary keyword arguments.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): The output of the network.
|
||||
(torch.Tensor): Loss if x is a dict (training), or network predictions (inference).
|
||||
"""
|
||||
if isinstance(x, dict): # for cases of training and validating while training.
|
||||
return self.loss(x, *args, **kwargs)
|
||||
|
|
@ -723,7 +727,6 @@ def temporary_modules(modules=None, attributes=None):
|
|||
Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger
|
||||
applications or libraries. Use this function with caution.
|
||||
"""
|
||||
|
||||
if modules is None:
|
||||
modules = {}
|
||||
if attributes is None:
|
||||
|
|
@ -752,9 +755,9 @@ def temporary_modules(modules=None, attributes=None):
|
|||
|
||||
def torch_safe_load(weight):
|
||||
"""
|
||||
This function attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised,
|
||||
it catches the error, logs a warning message, and attempts to install the missing module via the
|
||||
check_requirements() function. After installation, the function again attempts to load the model using torch.load().
|
||||
Attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the
|
||||
error, logs a warning message, and attempts to install the missing module via the check_requirements() function.
|
||||
After installation, the function again attempts to load the model using torch.load().
|
||||
|
||||
Args:
|
||||
weight (str): The file path of the PyTorch model.
|
||||
|
|
@ -813,7 +816,6 @@ def torch_safe_load(weight):
|
|||
|
||||
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
|
||||
"""Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a."""
|
||||
|
||||
ensemble = Ensemble()
|
||||
for w in weights if isinstance(weights, list) else [weights]:
|
||||
ckpt, w = torch_safe_load(w) # load ckpt
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue