ultralytics 8.3.2 fix AMP checks with imgsz=256 (#16583)

This commit is contained in:
Glenn Jocher 2024-10-01 11:53:11 +02:00 committed by GitHub
parent c327b0aae1
commit 5af8a5c0fb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 13 additions and 3 deletions

View file

@ -657,9 +657,10 @@ def check_amp(model):
def amp_allclose(m, im):
"""All close FP32 vs AMP results."""
batch = [im] * 8
a = m(batch, imgsz=128, device=device, verbose=False)[0].boxes.data # FP32 inference
imgsz = max(256, int(model.stride.max() * 4)) # max stride P5-32 and P6-64
a = m(batch, imgsz=imgsz, device=device, verbose=False)[0].boxes.data # FP32 inference
with autocast(enabled=True):
b = m(batch, imgsz=128, device=device, verbose=False)[0].boxes.data # AMP inference
b = m(batch, imgsz=imgsz, device=device, verbose=False)[0].boxes.data # AMP inference
del m
return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5) # close to 0.5 absolute tolerance