AutoBatch improve cache clearing (#16744)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
f7808dc61b
commit
97521186a7
2 changed files with 8 additions and 4 deletions
|
|
@@ -69,7 +69,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
|
||||||
batch_sizes = [1, 2, 4, 8, 16]
|
batch_sizes = [1, 2, 4, 8, 16]
|
||||||
try:
|
try:
|
||||||
img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
|
img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
|
||||||
results = profile(img, model, n=3, device=device)
|
results = profile(img, model, n=1, device=device)
|
||||||
|
|
||||||
# Fit a solution
|
# Fit a solution
|
||||||
y = [x[2] for x in results if x] # memory [2]
|
y = [x[2] for x in results if x] # memory [2]
|
||||||
|
|
@@ -89,3 +89,5 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOGGER.warning(f"{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.")
|
LOGGER.warning(f"{prefix}WARNING ⚠️ error detected: {e}, using default batch-size {batch_size}.")
|
||||||
return batch_size
|
return batch_size
|
||||||
|
finally:
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
|
||||||
|
|
@@ -643,7 +643,8 @@ def profile(input, ops, n=10, device=None):
|
||||||
f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
|
f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
|
||||||
f"{'input':>24s}{'output':>24s}"
|
f"{'input':>24s}{'output':>24s}"
|
||||||
)
|
)
|
||||||
|
gc.collect() # attempt to free unused memory
|
||||||
|
torch.cuda.empty_cache()
|
||||||
for x in input if isinstance(input, list) else [input]:
|
for x in input if isinstance(input, list) else [input]:
|
||||||
x = x.to(device)
|
x = x.to(device)
|
||||||
x.requires_grad = True
|
x.requires_grad = True
|
||||||
|
|
@@ -677,8 +678,9 @@ def profile(input, ops, n=10, device=None):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
LOGGER.info(e)
|
LOGGER.info(e)
|
||||||
results.append(None)
|
results.append(None)
|
||||||
gc.collect() # attempt to free unused memory
|
finally:
|
||||||
torch.cuda.empty_cache()
|
gc.collect() # attempt to free unused memory
|
||||||
|
torch.cuda.empty_cache()
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue