Ruff Docstring formatting (#15793)

Signed-off-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-08-25 04:27:55 +08:00 committed by GitHub
parent d27664216b
commit 776ca86369
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
60 changed files with 241 additions and 309 deletions

View file

@ -89,13 +89,17 @@ class BaseModel(nn.Module):
def forward(self, x, *args, **kwargs):
"""
Forward pass of the model on a single scale. Wrapper for `_forward_once` method.
Perform forward pass of the model for either training or inference.
If x is a dict, calculates and returns the loss for training. Otherwise, returns predictions for inference.
Args:
x (torch.Tensor | dict): The input image tensor or a dict including image tensor and gt labels.
x (torch.Tensor | dict): Input tensor for inference, or dict with image tensor and labels for training.
*args (Any): Variable length argument list.
**kwargs (Any): Arbitrary keyword arguments.
Returns:
(torch.Tensor): The output of the network.
(torch.Tensor): Loss if x is a dict (training), or network predictions (inference).
"""
if isinstance(x, dict): # for cases of training and validating while training.
return self.loss(x, *args, **kwargs)
@ -723,7 +727,6 @@ def temporary_modules(modules=None, attributes=None):
Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger
applications or libraries. Use this function with caution.
"""
if modules is None:
modules = {}
if attributes is None:
@ -752,9 +755,9 @@ def temporary_modules(modules=None, attributes=None):
def torch_safe_load(weight):
"""
This function attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised,
it catches the error, logs a warning message, and attempts to install the missing module via the
check_requirements() function. After installation, the function again attempts to load the model using torch.load().
Attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised, it catches the
error, logs a warning message, and attempts to install the missing module via the check_requirements() function.
After installation, the function again attempts to load the model using torch.load().
Args:
weight (str): The file path of the PyTorch model.
@ -813,7 +816,6 @@ def torch_safe_load(weight):
def attempt_load_weights(weights, device=None, inplace=True, fuse=False):
"""Loads an ensemble of models weights=[a,b,c] or a single model weights=[a] or weights=a."""
ensemble = Ensemble()
for w in weights if isinstance(weights, list) else [weights]:
ckpt, w = torch_safe_load(w) # load ckpt