diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py
index 64bd6d6f..f859cd5f 100644
--- a/ultralytics/utils/torch_utils.py
+++ b/ultralytics/utils/torch_utils.py
@@ -331,19 +331,28 @@ def get_flops(model, imgsz=640):
 
 
 def get_flops_with_torch_profiler(model, imgsz=640):
-    """Compute model FLOPs (thop alternative)."""
-    if TORCH_2_0:
-        model = de_parallel(model)
-        p = next(model.parameters())
+    """Compute model FLOPs (thop package alternative, but 2-10x slower unfortunately)."""
+    if not TORCH_2_0:  # torch profiler implemented in torch>=2.0
+        return 0.0
+    model = de_parallel(model)
+    p = next(model.parameters())
+    if not isinstance(imgsz, list):
+        imgsz = [imgsz, imgsz]  # expand if int/float
+    try:
+        # Use stride size for input tensor
         stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2  # max stride
-        im = torch.zeros((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
+        im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
         with torch.profiler.profile(with_flops=True) as prof:
             model(im)
         flops = sum(x.flops for x in prof.key_averages()) / 1e9
-        imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz]  # expand if int/float
         flops = flops * imgsz[0] / stride * imgsz[1] / stride  # 640x640 GFLOPs
-        return flops
-    return 0
+    except Exception:
+        # Use actual image size for input tensor (i.e. required for RTDETR models)
+        im = torch.empty((1, p.shape[1], *imgsz), device=p.device)  # input image in BCHW format
+        with torch.profiler.profile(with_flops=True) as prof:
+            model(im)
+        flops = sum(x.flops for x in prof.key_averages()) / 1e9
+    return flops
 
 
 def initialize_weights(model):
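
Note on the logic above: the fast path profiles a small stride x stride probe tensor and rescales the measured FLOPs by imgsz[0] / stride * imgsz[1] / stride, which holds for fully convolutional models whose FLOPs grow linearly with each spatial dimension. Models that need a fixed input size (e.g. RTDETR, per the comment in the diff) raise on the small probe, so the except branch re-profiles at the actual image size without rescaling. A minimal usage sketch, assuming an installed ultralytics package; "yolov8n.pt" is an illustrative checkpoint name:

    from ultralytics import YOLO
    from ultralytics.utils.torch_utils import get_flops_with_torch_profiler

    model = YOLO("yolov8n.pt").model  # unwrap the underlying nn.Module
    gflops = get_flops_with_torch_profiler(model, imgsz=640)  # returns 0.0 on torch<2.0
    print(f"{gflops:.1f} GFLOPs at 640x640")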