`ultralytics 8.3.31` add `max_num_obj` factor for AutoBatch (#17514)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
e100484422
commit
4453ddab93
6 changed files with 38 additions and 12 deletions
|
|
@ -623,7 +623,7 @@ def convert_optimizer_state_dict_to_fp16(state_dict):
|
|||
return state_dict
|
||||
|
||||
|
||||
def profile(input, ops, n=10, device=None):
|
||||
def profile(input, ops, n=10, device=None, max_num_obj=0):
|
||||
"""
|
||||
Ultralytics speed, memory and FLOPs profiler.
|
||||
|
||||
|
|
@ -671,6 +671,14 @@ def profile(input, ops, n=10, device=None):
|
|||
t[2] = float("nan")
|
||||
tf += (t[1] - t[0]) * 1000 / n # ms per op forward
|
||||
tb += (t[2] - t[1]) * 1000 / n # ms per op backward
|
||||
if max_num_obj: # simulate training with predictions per image grid (for AutoBatch)
|
||||
torch.randn(
|
||||
x.shape[0],
|
||||
max_num_obj,
|
||||
int(sum([(x.shape[-1] / s) * (x.shape[-2] / s) for s in m.stride.tolist()])),
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
mem = torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0 # (GB)
|
||||
s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y)) # shapes
|
||||
p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue