ultralytics 8.0.136 refactor and simplify package (#3748)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
8ebe94d1e9
commit
620f3eb218
383 changed files with 4213 additions and 4646 deletions
|
|
@ -11,12 +11,12 @@ from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottlenec
|
|||
Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, DWConvTranspose2d,
|
||||
Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv,
|
||||
RTDETRDecoder, Segment)
|
||||
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml
|
||||
from ultralytics.yolo.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss
|
||||
from ultralytics.yolo.utils.plotting import feature_visualization
|
||||
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights,
|
||||
intersect_dicts, make_divisible, model_info, scale_img, time_sync)
|
||||
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
|
||||
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
|
||||
from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss
|
||||
from ultralytics.utils.plotting import feature_visualization
|
||||
from ultralytics.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights, intersect_dicts,
|
||||
make_divisible, model_info, scale_img, time_sync)
|
||||
|
||||
try:
|
||||
import thop
|
||||
|
|
@ -412,7 +412,7 @@ class RTDETRDetectionModel(DetectionModel):
|
|||
|
||||
def init_criterion(self):
|
||||
"""Compute the classification loss between predictions and true labels."""
|
||||
from ultralytics.vit.utils.loss import RTDETRDetectionLoss
|
||||
from ultralytics.models.utils.loss import RTDETRDetectionLoss
|
||||
|
||||
return RTDETRDetectionLoss(nc=self.nc, use_vfl=True)
|
||||
|
||||
|
|
@ -498,6 +498,45 @@ class Ensemble(nn.ModuleList):
|
|||
# Functions ------------------------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def temporary_modules(modules=None):
|
||||
"""
|
||||
Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).
|
||||
|
||||
This function can be used to change the module paths during runtime. It's useful when refactoring code,
|
||||
where you've moved a module from one location to another, but you still want to support the old import
|
||||
paths for backwards compatibility.
|
||||
|
||||
Args:
|
||||
modules (dict, optional): A dictionary mapping old module paths to new module paths.
|
||||
|
||||
Example:
|
||||
with temporary_modules({'old.module.path': 'new.module.path'}):
|
||||
import old.module.path # this will now import new.module.path
|
||||
|
||||
Note:
|
||||
The changes are only in effect inside the context manager and are undone once the context manager exits.
|
||||
Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger
|
||||
applications or libraries. Use this function with caution.
|
||||
"""
|
||||
if not modules:
|
||||
modules = {}
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
try:
|
||||
# Set modules in sys.modules under their old name
|
||||
for old, new in modules.items():
|
||||
sys.modules[old] = importlib.import_module(new)
|
||||
|
||||
yield
|
||||
finally:
|
||||
# Remove the temporary module paths
|
||||
for old in modules:
|
||||
if old in sys.modules:
|
||||
del sys.modules[old]
|
||||
|
||||
|
||||
def torch_safe_load(weight):
|
||||
"""
|
||||
This function attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised,
|
||||
|
|
@ -510,12 +549,17 @@ def torch_safe_load(weight):
|
|||
Returns:
|
||||
(dict): The loaded PyTorch model.
|
||||
"""
|
||||
from ultralytics.yolo.utils.downloads import attempt_download_asset
|
||||
from ultralytics.utils.downloads import attempt_download_asset
|
||||
|
||||
check_suffix(file=weight, suffix='.pt')
|
||||
file = attempt_download_asset(weight) # search online if missing locally
|
||||
try:
|
||||
return torch.load(file, map_location='cpu'), file # load
|
||||
with temporary_modules({
|
||||
'ultralytics.yolo.utils': 'ultralytics.utils',
|
||||
'ultralytics.yolo.v8': 'ultralytics.models.yolo',
|
||||
'ultralytics.yolo.data': 'ultralytics.data'}): # for legacy 8.0 Classify and Pose models
|
||||
return torch.load(file, map_location='cpu'), file # load
|
||||
|
||||
except ModuleNotFoundError as e: # e.name is missing module name
|
||||
if e.name == 'models':
|
||||
raise TypeError(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue