Fix AutoBatch when working with RT-DETR models (#18912)

Authored by Laughing on 2025-01-27 18:04:03 +08:00, committed by GitHub
parent 305a298ae2
commit 30a2de164b
2 changed files with 2 additions and 2 deletions


@@ -172,7 +172,7 @@ def get_cdn_group(
             bounding boxes, attention mask and meta information for denoising. If not in training mode or 'num_dn'
             is less than or equal to 0, the function returns None for all elements in the tuple.
     """
-    if (not training) or num_dn <= 0:
+    if (not training) or num_dn <= 0 or batch is None:
         return None, None, None, None
     gt_groups = batch["gt_groups"]
     total_num = sum(gt_groups)
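
Why the extra `batch is None` guard matters: without it, the very next statement indexes `batch["gt_groups"]` and raises a TypeError whenever a training-mode forward pass reaches the denoising path without a real target dict, which is presumably what happens during AutoBatch's memory profiling of RT-DETR. The snippet below is a minimal, hypothetical sketch of just the patched early exit; `get_cdn_group_guard` is an illustrative stand-in, not the library function.

def get_cdn_group_guard(batch, num_dn=100, training=True):
    """Illustrative stand-in mirroring only the patched early-exit condition."""
    if (not training) or num_dn <= 0 or batch is None:
        # Nothing to denoise: return placeholders without ever touching `batch`.
        return None, None, None, None
    gt_groups = batch["gt_groups"]  # safe here: batch is guaranteed to be a dict
    return gt_groups, sum(gt_groups), None, None

# Before the fix this call crashed with "'NoneType' object is not subscriptable";
# with the added guard it returns the placeholder tuple instead.
print(get_cdn_group_guard(batch=None))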


@@ -667,7 +667,7 @@ def profile(input, ops, n=10, device=None, max_num_obj=0):
             m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
             tf, tb, t = 0, 0, [0, 0, 0]  # dt forward, backward
             try:
-                flops = thop.profile(m, inputs=[x], verbose=False)[0] / 1e9 * 2  # GFLOPs
+                flops = thop.profile(deepcopy(m), inputs=[x], verbose=False)[0] / 1e9 * 2  # GFLOPs
             except Exception:
                 flops = 0
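
The second change hands `thop.profile` a deep copy of the module, so any side effects of thop's hook-based FLOPs counting (it registers forward hooks and bookkeeping buffers on submodules) land on a throwaway copy rather than on the model that is subsequently timed and used by AutoBatch. A minimal sketch of that idea on a toy module, assuming `torch` and `thop` are installed:

from copy import deepcopy

import thop
import torch
import torch.nn as nn

m = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 16, 3))
x = torch.randn(1, 3, 64, 64)

# Profile a deep copy, mirroring the one-line change above.
flops = thop.profile(deepcopy(m), inputs=[x], verbose=False)[0] / 1e9 * 2  # GFLOPs
print(f"{flops:.4f} GFLOPs")

# The original module stays untouched: any hooks or buffers thop leaves behind
# only ever existed on the throwaway copy.
assert not any("total_ops" in k or "total_params" in k for k in m.state_dict())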