Simplify thop imports (#18717)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2025-01-16 22:27:27 +01:00 committed by GitHub
parent c6dd277493
commit daaebba220
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 6 additions and 17 deletions

View file

@ -567,7 +567,7 @@ class HUBDatasetStats:
# Get dataset statistics # Get dataset statistics
if self.task == "classify": if self.task == "classify":
from torchvision.datasets import ImageFolder from torchvision.datasets import ImageFolder # scope for faster 'import ultralytics'
dataset = ImageFolder(self.data[split]) dataset = ImageFolder(self.data[split])

View file

@ -839,7 +839,7 @@ class Results(SimpleClass):
>>> df_result = results[0].to_df() >>> df_result = results[0].to_df()
>>> print(df_result) >>> print(df_result)
""" """
import pandas as pd import pandas as pd # scope for faster 'import ultralytics'
return pd.DataFrame(self.summary(normalize=normalize, decimals=decimals)) return pd.DataFrame(self.summary(normalize=normalize, decimals=decimals))

View file

@ -1131,7 +1131,7 @@ class TorchVision(nn.Module):
def __init__(self, c1, c2, model, weights="DEFAULT", unwrap=True, truncate=2, split=False): def __init__(self, c1, c2, model, weights="DEFAULT", unwrap=True, truncate=2, split=False):
"""Load the model and weights from torchvision.""" """Load the model and weights from torchvision."""
import torchvision import torchvision # scope for faster 'import ultralytics'
super().__init__() super().__init__()
if hasattr(torchvision.models, "get_model"): if hasattr(torchvision.models, "get_model"):

View file

@ -7,6 +7,7 @@ import types
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
import thop
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -86,11 +87,6 @@ from ultralytics.utils.torch_utils import (
time_sync, time_sync,
) )
try:
import thop
except ImportError:
thop = None
class BaseModel(nn.Module): class BaseModel(nn.Module):
"""The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family.""" """The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family."""

View file

@ -12,6 +12,7 @@ from pathlib import Path
from typing import Union from typing import Union
import numpy as np import numpy as np
import thop
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
@ -30,11 +31,6 @@ from ultralytics.utils import (
) )
from ultralytics.utils.checks import check_version from ultralytics.utils.checks import check_version
try:
import thop
except ImportError:
thop = None
# Version checks (all default to version>=min_version) # Version checks (all default to version>=min_version)
TORCH_1_9 = check_version(torch.__version__, "1.9.0") TORCH_1_9 = check_version(torch.__version__, "1.9.0")
TORCH_1_13 = check_version(torch.__version__, "1.13.0") TORCH_1_13 = check_version(torch.__version__, "1.13.0")
@ -367,9 +363,6 @@ def model_info_for_loggers(trainer):
def get_flops(model, imgsz=640): def get_flops(model, imgsz=640):
"""Return a YOLO model's FLOPs.""" """Return a YOLO model's FLOPs."""
if not thop:
return 0.0 # if not installed return 0.0 GFLOPs
try: try:
model = de_parallel(model) model = de_parallel(model)
p = next(model.parameters()) p = next(model.parameters())
@ -674,7 +667,7 @@ def profile(input, ops, n=10, device=None, max_num_obj=0):
m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward
try: try:
flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2 # GFLOPs
except Exception: except Exception:
flops = 0 flops = 0