diff --git a/pyproject.toml b/pyproject.toml index e21c4b3e..6de999d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,7 +77,7 @@ dependencies = [ "py-cpuinfo", # display CPU info "pandas>=1.1.4", "seaborn>=0.11.0", # plotting - "ultralytics-thop>=0.2.4", # FLOPs computation https://github.com/ultralytics/thop + "ultralytics-thop>=0.2.5", # FLOPs computation https://github.com/ultralytics/thop ] # Optional dependencies ------------------------------------------------------------------------------------------------ diff --git a/ultralytics/nn/tasks.py b/ultralytics/nn/tasks.py index 3e46ded7..e9453a84 100644 --- a/ultralytics/nn/tasks.py +++ b/ultralytics/nn/tasks.py @@ -4,7 +4,6 @@ import contextlib from copy import deepcopy from pathlib import Path -import thop import torch import torch.nn as nn @@ -66,6 +65,11 @@ from ultralytics.utils.torch_utils import ( time_sync, ) +try: + import thop +except ImportError: + thop = None + class BaseModel(nn.Module): """The BaseModel class serves as a base class for all the models in the Ultralytics YOLO family.""" @@ -153,7 +157,7 @@ class BaseModel(nn.Module): None """ c = m == self.model[-1] and isinstance(x, list) # is final layer list, copy input as inplace fix - flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 # GFLOPs + flops = thop.profile(m, inputs=[x.copy() if c else x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs t = time_sync() for _ in range(10): m(x.copy() if c else x) diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py index 751e98ba..919fee07 100644 --- a/ultralytics/utils/torch_utils.py +++ b/ultralytics/utils/torch_utils.py @@ -11,7 +11,6 @@ from pathlib import Path from typing import Union import numpy as np -import thop import torch import torch.distributed as dist import torch.nn as nn @@ -28,6 +27,11 @@ from ultralytics.utils import ( ) from ultralytics.utils.checks import check_version +try: + import thop +except ImportError: + thop = None + # Version checks (all default to version>=min_version) TORCH_1_9 = check_version(torch.__version__, "1.9.0") TORCH_1_13 = check_version(torch.__version__, "1.13.0") @@ -304,6 +308,9 @@ def model_info_for_loggers(trainer): def get_flops(model, imgsz=640): """Return a YOLO model's FLOPs.""" + if not thop: + return 0.0 # if not installed return 0.0 GFLOPs + try: model = de_parallel(model) p = next(model.parameters()) @@ -564,7 +571,7 @@ def profile(input, ops, n=10, device=None): m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward try: - flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2 # GFLOPs + flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs except Exception: flops = 0