Update get_flops_with_torch_profiler() (#12635)
This commit is contained in:
parent 693abf5958
commit 5ac164940e
1 changed file with 17 additions and 8 deletions
@@ -331,19 +331,28 @@ def get_flops(model, imgsz=640):
 
 
 def get_flops_with_torch_profiler(model, imgsz=640):
-    """Compute model FLOPs (thop alternative)."""
-    if TORCH_2_0:
-        model = de_parallel(model)
-        p = next(model.parameters())
+    """Compute model FLOPs (thop package alternative, but 2-10x slower unfortunately)."""
+    if not TORCH_2_0:  # torch profiler implemented in torch>=2.0
+        return 0.0
+    model = de_parallel(model)
+    p = next(model.parameters())
+    if not isinstance(imgsz, list):
+        imgsz = [imgsz, imgsz]  # expand if int/float
+    try:
+        # Use stride size for input tensor
         stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2  # max stride
-        im = torch.zeros((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
+        im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # input image in BCHW format
         with torch.profiler.profile(with_flops=True) as prof:
             model(im)
         flops = sum(x.flops for x in prof.key_averages()) / 1e9
-        imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz]  # expand if int/float
         flops = flops * imgsz[0] / stride * imgsz[1] / stride  # 640x640 GFLOPs
-        return flops
-    return 0
+    except Exception:
+        # Use actual image size for input tensor (i.e. required for RTDETR models)
+        im = torch.empty((1, p.shape[1], *imgsz), device=p.device)  # input image in BCHW format
+        with torch.profiler.profile(with_flops=True) as prof:
+            model(im)
+        flops = sum(x.flops for x in prof.key_averages()) / 1e9
+    return flops
 
 
 def initialize_weights(model):
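Note (not part of the diff): both branches count FLOPs the same way, via torch.profiler.profile(with_flops=True), summing the per-operator flops reported by prof.key_averages() and dividing by 1e9 to get GFLOPs. A minimal standalone sketch of that mechanism follows; the toy nn.Sequential model and the 64x64 input are assumptions for illustration only, not code from this repository.

# Minimal sketch of the profiler-based FLOPs count used above.
# The toy model and input size are illustrative stand-ins, not Ultralytics code.
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 32, 3, padding=1),
).eval()
im = torch.empty(1, 3, 64, 64)  # BCHW dummy input, analogous to the stride-sized tensor in the try branch

with torch.profiler.profile(with_flops=True) as prof:  # with_flops needs a recent torch, hence the TORCH_2_0 guard
    model(im)

gflops = sum(x.flops for x in prof.key_averages()) / 1e9  # sum per-operator FLOPs, convert to GFLOPs
print(f"{gflops:.4f} GFLOPs at 64x64")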
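Design note: the try branch keeps the existing speed trick of profiling a small stride-sized dummy input and rescaling the result; for example, with stride = 64 and imgsz = [640, 640], the measured FLOPs are multiplied by (640 / 64) * (640 / 64) = 100 to report 640x640 GFLOPs. The except branch instead profiles an input of the actual imgsz, which the added comment marks as required for RTDETR models, presumably because their forward pass does not accept the small stride-sized tensor; this is slower per call but lets the FLOPs estimate still be computed for such models.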