ultralytics 8.3.2 fix AMP checks with imgsz=256 (#16583)
parent c327b0aae1
commit 5af8a5c0fb
4 changed files with 13 additions and 3 deletions
@@ -111,6 +111,7 @@ torch.set_printoptions(linewidth=320, precision=4, profile="default")
 np.set_printoptions(linewidth=320, formatter={"float_kind": "{:11.5g}".format})  # format short g, %precision=5
 cv2.setNumThreads(0)  # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
 os.environ["NUMEXPR_MAX_THREADS"] = str(NUM_THREADS)  # NumExpr max threads
 os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # for deterministic training to avoid CUDA warning
 os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # suppress verbose TF compiler warnings in Colab
 os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR"  # suppress "NNPACK.cpp could not initialize NNPACK" warnings
+os.environ["KINETO_LOG_LEVEL"] = "5"  # suppress verbose PyTorch profiler output when computing FLOPs
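The one added line in this hunk is KINETO_LOG_LEVEL, which quiets the Kineto backend that torch.profiler uses for FLOPs estimation. As a standalone sketch (none of the snippet below is ultralytics code, and the sort key is illustrative), the variable has to be exported before the profiler runs, which is why the commit sets it at module import time alongside the other environment tweaks:

import os

# "5" is the value the commit sets; assumption: higher values mean quieter
# Kineto logs, so this keeps only the most severe messages.
os.environ["KINETO_LOG_LEVEL"] = "5"

import torch  # imported after the env var so the profiler backend picks it up
from torch.profiler import profile

# FLOPs estimation via the profiler is what triggers the verbose Kineto
# output that the commit silences.
with profile(with_flops=True) as prof:
    torch.randn(64, 64) @ torch.randn(64, 64)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=3))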
@@ -657,9 +657,10 @@ def check_amp(model):
     def amp_allclose(m, im):
         """All close FP32 vs AMP results."""
         batch = [im] * 8
-        a = m(batch, imgsz=128, device=device, verbose=False)[0].boxes.data  # FP32 inference
+        imgsz = max(256, int(model.stride.max() * 4))  # max stride P5-32 and P6-64
+        a = m(batch, imgsz=imgsz, device=device, verbose=False)[0].boxes.data  # FP32 inference
         with autocast(enabled=True):
-            b = m(batch, imgsz=128, device=device, verbose=False)[0].boxes.data  # AMP inference
+            b = m(batch, imgsz=imgsz, device=device, verbose=False)[0].boxes.data  # AMP inference
         del m
         return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5)  # close to 0.5 absolute tolerance
 
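This hunk replaces the hard-coded imgsz=128 with a value derived from the model's maximum stride, so the AMP comparison never runs below 256 pixels and the size stays a multiple of the stride; this is where the title's imgsz=256 comes from. A short standalone sketch of the arithmetic (stride values assumed from the "P5-32 and P6-64" comment in the diff, not part of the commit):

for max_stride in (32, 64):  # P5 models downsample to stride 32, P6 to 64
    imgsz = max(256, int(max_stride * 4))
    assert imgsz % max_stride == 0  # inference size stays stride-aligned
    print(f"max stride {max_stride} -> AMP check imgsz={imgsz}")
# max stride 32 -> AMP check imgsz=256
# max stride 64 -> AMP check imgsz=256

Both stride families land on 256, matching the commit title.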