diff --git a/ultralytics/utils/autobatch.py b/ultralytics/utils/autobatch.py
index 2d09c5d8..86619998 100644
--- a/ultralytics/utils/autobatch.py
+++ b/ultralytics/utils/autobatch.py
@@ -69,7 +69,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
     batch_sizes = [1, 2, 4, 8, 16]
     try:
         img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
-        results = profile(img, model, n=3, device=device)
+        results = profile(img, model, n=1, device=device)
 
         # Fit a solution
         y = [x[2] for x in results if x]  # memory [2]
@@ -89,3 +89,5 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
     except Exception as e:
         LOGGER.warning(f"{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.")
         return batch_size
+    finally:
+        torch.cuda.empty_cache()
diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py
index db84ed69..15ccc621 100644
--- a/ultralytics/utils/torch_utils.py
+++ b/ultralytics/utils/torch_utils.py
@@ -643,7 +643,8 @@ def profile(input, ops, n=10, device=None):
         f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
         f"{'input':>24s}{'output':>24s}"
     )
-
+    gc.collect()  # attempt to free unused memory
+    torch.cuda.empty_cache()
     for x in input if isinstance(input, list) else [input]:
         x = x.to(device)
         x.requires_grad = True
@@ -677,8 +678,9 @@ def profile(input, ops, n=10, device=None):
                 LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:>14.4g}{tb:>14.4g}{str(s_in):>24s}{str(s_out):>24s}")
                 results.append([p, flops, mem, tf, tb, s_in, s_out])
             except Exception as e:
                 LOGGER.info(e)
                 results.append(None)
-            gc.collect()  # attempt to free unused memory
-            torch.cuda.empty_cache()
+            finally:
+                gc.collect()  # attempt to free unused memory
+                torch.cuda.empty_cache()
     return results