AutoBatch improve cache clearing (#16744)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
Ultralytics Assistant 2024-10-07 20:11:10 +02:00 committed by GitHub
parent f7808dc61b
commit 97521186a7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 8 additions and 4 deletions

View file

@ -69,7 +69,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
batch_sizes = [1, 2, 4, 8, 16]
try:
img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
results = profile(img, model, n=3, device=device)
results = profile(img, model, n=1, device=device)
# Fit a solution
y = [x[2] for x in results if x] # memory [2]
@ -89,3 +89,5 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
except Exception as e:
LOGGER.warning(f"{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.")
return batch_size
finally:
torch.cuda.empty_cache()