Simplify thop imports (#18717)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2025-01-16 22:27:27 +01:00 committed by GitHub
parent c6dd277493
commit daaebba220
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 6 additions and 17 deletions

View file

@ -567,7 +567,7 @@ class HUBDatasetStats:
# Get dataset statistics
if self.task == "classify":
from torchvision.datasets import ImageFolder
from torchvision.datasets import ImageFolder # scope for faster 'import ultralytics'
dataset = ImageFolder(self.data[split])

View file

@ -839,7 +839,7 @@ class Results(SimpleClass):
>>> df_result = results[0].to_df()
>>> print(df_result)
"""
import pandas as pd
import pandas as pd # scope for faster 'import ultralytics'
return pd.DataFrame(self.summary(normalize=normalize, decimals=decimals))

View file

@ -1131,7 +1131,7 @@ class TorchVision(nn.Module):
def __init__(self, c1, c2, model, weights="DEFAULT", unwrap=True, truncate=2, split=False):
"""Load the model and weights from torchvision."""
import torchvision
import torchvision # scope for faster 'import ultralytics'
super().__init__()
if hasattr(torchvision.models, "get_model"):

View file

@ -7,6 +7,7 @@ import types
from copy import deepcopy
from pathlib import Path
import thop
import torch
import torch.nn as nn
@ -86,11 +87,6 @@ from ultralytics.utils.torch_utils import (
time_sync,
)
try:
import thop
except ImportError:
thop = None
class BaseModel(nn.Module):
"""The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family."""

View file

@ -12,6 +12,7 @@ from pathlib import Path
from typing import Union
import numpy as np
import thop
import torch
import torch.distributed as dist
import torch.nn as nn
@ -30,11 +31,6 @@ from ultralytics.utils import (
)
from ultralytics.utils.checks import check_version
try:
import thop
except ImportError:
thop = None
# Version checks (all default to version>=min_version)
TORCH_1_9 = check_version(torch.__version__, "1.9.0")
TORCH_1_13 = check_version(torch.__version__, "1.13.0")
@ -367,9 +363,6 @@ def model_info_for_loggers(trainer):
def get_flops(model, imgsz=640):
"""Return a YOLO model's FLOPs."""
if not thop:
return 0.0 # if not installed return 0.0 GFLOPs
try:
model = de_parallel(model)
p = next(model.parameters())
@ -674,7 +667,7 @@ def profile(input, ops, n=10, device=None, max_num_obj=0):
m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward
try:
flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs
flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2 # GFLOPs
except Exception:
flops = 0