ultralytics 8.0.136 refactor and simplify package (#3748)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Laughing 2023-07-16 23:47:45 +08:00 committed by GitHub
parent 8ebe94d1e9
commit 620f3eb218
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
383 changed files with 4213 additions and 4646 deletions

View file

@ -11,12 +11,12 @@ from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottlenec
Classify, Concat, Conv, Conv2, ConvTranspose, Detect, DWConv, DWConvTranspose2d,
Focus, GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv,
RTDETRDecoder, Segment)
from ultralytics.yolo.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
from ultralytics.yolo.utils.checks import check_requirements, check_suffix, check_yaml
from ultralytics.yolo.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss
from ultralytics.yolo.utils.plotting import feature_visualization
from ultralytics.yolo.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights,
intersect_dicts, make_divisible, model_info, scale_img, time_sync)
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, colorstr, emojis, yaml_load
from ultralytics.utils.checks import check_requirements, check_suffix, check_yaml
from ultralytics.utils.loss import v8ClassificationLoss, v8DetectionLoss, v8PoseLoss, v8SegmentationLoss
from ultralytics.utils.plotting import feature_visualization
from ultralytics.utils.torch_utils import (fuse_conv_and_bn, fuse_deconv_and_bn, initialize_weights, intersect_dicts,
make_divisible, model_info, scale_img, time_sync)
try:
import thop
@ -412,7 +412,7 @@ class RTDETRDetectionModel(DetectionModel):
def init_criterion(self):
"""Compute the classification loss between predictions and true labels."""
from ultralytics.vit.utils.loss import RTDETRDetectionLoss
from ultralytics.models.utils.loss import RTDETRDetectionLoss
return RTDETRDetectionLoss(nc=self.nc, use_vfl=True)
@ -498,6 +498,45 @@ class Ensemble(nn.ModuleList):
# Functions ------------------------------------------------------------------------------------------------------------
@contextlib.contextmanager
def temporary_modules(modules=None):
"""
Context manager for temporarily adding or modifying modules in Python's module cache (`sys.modules`).
This function can be used to change the module paths during runtime. It's useful when refactoring code,
where you've moved a module from one location to another, but you still want to support the old import
paths for backwards compatibility.
Args:
modules (dict, optional): A dictionary mapping old module paths to new module paths.
Example:
with temporary_modules({'old.module.path': 'new.module.path'}):
import old.module.path # this will now import new.module.path
Note:
The changes are only in effect inside the context manager and are undone once the context manager exits.
Be aware that directly manipulating `sys.modules` can lead to unpredictable results, especially in larger
applications or libraries. Use this function with caution.
"""
if not modules:
modules = {}
import importlib
import sys
try:
# Set modules in sys.modules under their old name
for old, new in modules.items():
sys.modules[old] = importlib.import_module(new)
yield
finally:
# Remove the temporary module paths
for old in modules:
if old in sys.modules:
del sys.modules[old]
def torch_safe_load(weight):
"""
This function attempts to load a PyTorch model with the torch.load() function. If a ModuleNotFoundError is raised,
@ -510,12 +549,17 @@ def torch_safe_load(weight):
Returns:
(dict): The loaded PyTorch model.
"""
from ultralytics.yolo.utils.downloads import attempt_download_asset
from ultralytics.utils.downloads import attempt_download_asset
check_suffix(file=weight, suffix='.pt')
file = attempt_download_asset(weight) # search online if missing locally
try:
return torch.load(file, map_location='cpu'), file # load
with temporary_modules({
'ultralytics.yolo.utils': 'ultralytics.utils',
'ultralytics.yolo.v8': 'ultralytics.models.yolo',
'ultralytics.yolo.data': 'ultralytics.data'}): # for legacy 8.0 Classify and Pose models
return torch.load(file, map_location='cpu'), file # load
except ModuleNotFoundError as e: # e.name is missing module name
if e.name == 'models':
raise TypeError(