ultralytics 8.3.52 AutoBatch CUDA computation improvements (#18291)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Authored by Laughing on 2024-12-20 20:51:21 +08:00, committed by GitHub
parent 00e239a5da
commit e3c46920e7
3 changed files with 52 additions and 18 deletions


@@ -127,6 +127,10 @@ keywords: Ultralytics, torch utils, model optimization, device selection, infere
 <br><br><hr><br>
 
+## ::: ultralytics.utils.torch_utils.cuda_memory_usage
+
+<br><br><hr><br>
+
 ## ::: ultralytics.utils.torch_utils.profile
 
 <br><br>


@@ -1,6 +1,6 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license
 
-__version__ = "8.3.51"
+__version__ = "8.3.52"
 
 import os


@@ -617,6 +617,32 @@ def convert_optimizer_state_dict_to_fp16(state_dict):
     return state_dict
 
 
+@contextmanager
+def cuda_memory_usage(device=None):
+    """
+    Monitor and manage CUDA memory usage.
+
+    This function checks if CUDA is available and, if so, empties the CUDA cache to free up unused memory.
+    It then yields a dictionary containing memory usage information, which can be updated by the caller.
+    Finally, it updates the dictionary with the amount of memory reserved by CUDA on the specified device.
+
+    Args:
+        device (torch.device, optional): The CUDA device to query memory usage for. Defaults to None.
+
+    Yields:
+        (dict): A dictionary with a key 'memory' initialized to 0, which will be updated with the reserved memory.
+    """
+    cuda_info = dict(memory=0)
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        try:
+            yield cuda_info
+        finally:
+            cuda_info["memory"] = torch.cuda.memory_reserved(device)
+    else:
+        yield cuda_info
+
+
 def profile(input, ops, n=10, device=None, max_num_obj=0):
     """
     Ultralytics speed, memory and FLOPs profiler.
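
The new cuda_memory_usage context manager can also be exercised on its own. Below is a minimal usage sketch, assuming only torch and the ultralytics package from this release are installed; the tensor shapes are illustrative, and on a CPU-only machine the reported figure simply stays at 0.

# Minimal usage sketch (assumption: ultralytics>=8.3.52 and torch installed).
import torch

from ultralytics.utils.torch_utils import cuda_memory_usage

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

with cuda_memory_usage(device) as cuda_info:
    # Allocations made inside the block are reflected in the reserved-memory
    # figure the context manager writes back into cuda_info on exit.
    x = torch.randn(16, 3, 640, 640, device=device)  # illustrative workload
    y = (x * 2.0).sum()

print(f"CUDA memory reserved after block: {cuda_info['memory'] / 1e9:.3f} GB")  # stays 0.000 without CUDA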
@@ -653,27 +679,31 @@ def profile(input, ops, n=10, device=None, max_num_obj=0):
                 flops = 0
 
             try:
+                mem = 0
                 for _ in range(n):
-                    t[0] = time_sync()
-                    y = m(x)
-                    t[1] = time_sync()
-                    try:
-                        (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
-                        t[2] = time_sync()
-                    except Exception:  # no backward method
-                        # print(e)  # for debug
-                        t[2] = float("nan")
+                    with cuda_memory_usage(device) as cuda_info:
+                        t[0] = time_sync()
+                        y = m(x)
+                        t[1] = time_sync()
+                        try:
+                            (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
+                            t[2] = time_sync()
+                        except Exception:  # no backward method
+                            # print(e)  # for debug
+                            t[2] = float("nan")
+                    mem += cuda_info["memory"] / 1e9  # (GB)
                     tf += (t[1] - t[0]) * 1000 / n  # ms per op forward
                     tb += (t[2] - t[1]) * 1000 / n  # ms per op backward
                     if max_num_obj:  # simulate training with predictions per image grid (for AutoBatch)
-                        torch.randn(
-                            x.shape[0],
-                            max_num_obj,
-                            int(sum((x.shape[-1] / s) * (x.shape[-2] / s) for s in m.stride.tolist())),
-                            device=device,
-                            dtype=torch.float32,
-                        )
-                mem = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0  # (GB)
+                        with cuda_memory_usage(device) as cuda_info:
+                            torch.randn(
+                                x.shape[0],
+                                max_num_obj,
+                                int(sum((x.shape[-1] / s) * (x.shape[-2] / s) for s in m.stride.tolist())),
+                                device=device,
+                                dtype=torch.float32,
+                            )
+                        mem += cuda_info["memory"] / 1e9  # (GB)
                 s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y))  # shapes
                 p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0  # parameters
                 LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}")
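
For context, profile() now accumulates the reserved-memory figure per iteration via cuda_memory_usage instead of reading torch.cuda.memory_reserved() once at the end, and this is the estimate AutoBatch uses when simulating training. A hedged sketch of calling profile with max_num_obj follows; the "yolo11n.pt" weights name and max_num_obj=30 value are illustrative assumptions, not taken from this diff.

# Hedged sketch: profiling a detection model roughly the way AutoBatch does
# (assumptions: "yolo11n.pt" resolves to valid weights; works on CPU too, with mem=0).
import torch

from ultralytics import YOLO
from ultralytics.utils.torch_utils import profile

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = YOLO("yolo11n.pt").model.to(device)  # underlying nn.Module with a .stride tensor
im = torch.randn(1, 3, 640, 640, device=device)

# max_num_obj allocates a dummy per-image prediction tensor inside profile() so the
# reserved-memory estimate better reflects training-time usage.
results = profile(im, model, n=3, device=device, max_num_obj=30)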