From daaebba220b5849c213a980026fdf0b5af9fa64b Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Thu, 16 Jan 2025 22:27:27 +0100 Subject: [PATCH] Simplify `thop` imports (#18717) Signed-off-by: Glenn Jocher Co-authored-by: UltralyticsAssistant --- ultralytics/data/utils.py | 2 +- ultralytics/engine/results.py | 2 +- ultralytics/nn/modules/block.py | 2 +- ultralytics/nn/tasks.py | 6 +----- ultralytics/utils/torch_utils.py | 11 ++--------- 5 files changed, 6 insertions(+), 17 deletions(-) diff --git a/ultralytics/data/utils.py b/ultralytics/data/utils.py index d5244454..50b597d8 100644 --- a/ultralytics/data/utils.py +++ b/ultralytics/data/utils.py @@ -567,7 +567,7 @@ class HUBDatasetStats: # Get dataset statistics if self.task == "classify": - from torchvision.datasets import ImageFolder + from torchvision.datasets import ImageFolder # scope for faster 'import ultralytics' dataset = ImageFolder(self.data[split]) diff --git a/ultralytics/engine/results.py b/ultralytics/engine/results.py index 911ff406..5ef0b778 100644 --- a/ultralytics/engine/results.py +++ b/ultralytics/engine/results.py @@ -839,7 +839,7 @@ class Results(SimpleClass): >>> df_result = results[0].to_df() >>> print(df_result) """ - import pandas as pd + import pandas as pd # scope for faster 'import ultralytics' return pd.DataFrame(self.summary(normalize=normalize, decimals=decimals)) diff --git a/ultralytics/nn/modules/block.py b/ultralytics/nn/modules/block.py index 634c835f..a961ae45 100644 --- a/ultralytics/nn/modules/block.py +++ b/ultralytics/nn/modules/block.py @@ -1131,7 +1131,7 @@ class TorchVision(nn.Module): def __init__(self, c1, c2, model, weights="DEFAULT", unwrap=True, truncate=2, split=False): """Load the model and weights from torchvision.""" - import torchvision + import torchvision # scope for faster 'import ultralytics' super().__init__() if hasattr(torchvision.models, "get_model"): diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index cdb8fc3d..c5bc8162 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -7,6 +7,7 @@ import types from copy import deepcopy from pathlib import Path +import thop import torch import torch.nn as nn @@ -86,11 +87,6 @@ from ultralytics.utils.torch_utils import ( time_sync, ) -try: - import thop -except ImportError: - thop = None - class BaseModel(nn.Module): """The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family.""" diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py index 85f511e6..e1cd9de8 100644 --- a/ultralytics/utils/torch_utils.py +++ b/ultralytics/utils/torch_utils.py @@ -12,6 +12,7 @@ from pathlib import Path from typing import Union import numpy as np +import thop import torch import torch.distributed as dist import torch.nn as nn @@ -30,11 +31,6 @@ from ultralytics.utils import ( ) from ultralytics.utils.checks import check_version -try: - import thop -except ImportError: - thop = None - # Version checks (all default to version>=min_version) TORCH_1_9 = check_version(torch.__version__, "1.9.0") TORCH_1_13 = check_version(torch.__version__, "1.13.0") @@ -367,9 +363,6 @@ def model_info_for_loggers(trainer): def get_flops(model, imgsz=640): """Return a YOLO model's FLOPs.""" - if not thop: - return 0.0 # if not installed return 0.0 GFLOPs - try: model = de_parallel(model) p = next(model.parameters()) @@ -674,7 +667,7 @@ def profile(input, ops, n=10, device=None, max_num_obj=0): m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward try: - flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs + flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2 # GFLOPs except Exception: flops = 0