Disable FP16 val on AMP fail and improve AMP checks (#16306)
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
Co-authored-by: Ultralytics Assistant <135830346+UltralyticsAssistant@users.noreply.github.com>
parent ba438aea5a
commit 6f2bb65953

2 changed files with 5 additions and 3 deletions
ultralytics/engine/validator.py

```diff
@@ -110,7 +110,8 @@ class BaseValidator:
         if self.training:
             self.device = trainer.device
             self.data = trainer.data
-            self.args.half = self.device.type != "cpu"  # force FP16 val during training
+            # force FP16 val during training
+            self.args.half = self.device.type != "cpu" and self.args.amp
             model = trainer.ema.ema or trainer.model
             model = model.half() if self.args.half else model.float()
             # self.model = model
```
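With this change, FP16 validation during training requires both a non-CPU device and a passing AMP check; when the AMP check fails, validation falls back to FP32 instead of silently running half precision on a setup where AMP is known to misbehave. A minimal standalone sketch of the gate, where `select_val_half` and `amp_ok` are hypothetical stand-ins for the trainer attributes feeding `self.args.half` and `self.args.amp`:

```python
import torch

# Hypothetical helper mirroring the new gate; `amp_ok` stands in for the
# AMP-check result that ends up in self.args.amp.
def select_val_half(device: torch.device, amp_ok: bool) -> bool:
    """Validate in FP16 only on a non-CPU device AND when the AMP check passed."""
    return device.type != "cpu" and amp_ok

model = torch.nn.Linear(8, 2)
use_half = select_val_half(torch.device("cpu"), amp_ok=False)
model = model.half() if use_half else model.float()  # FP32 fallback when AMP fails
print(next(model.parameters()).dtype)  # torch.float32
```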
ultralytics/utils/checks.py

```diff
@@ -656,9 +656,10 @@ def check_amp(model):
 
     def amp_allclose(m, im):
         """All close FP32 vs AMP results."""
-        a = m(im, device=device, verbose=False)[0].boxes.data  # FP32 inference
+        batch = [im] * 8
+        a = m(batch, imgsz=128, device=device, verbose=False)[0].boxes.data  # FP32 inference
         with autocast(enabled=True):
-            b = m(im, device=device, verbose=False)[0].boxes.data  # AMP inference
+            b = m(batch, imgsz=128, device=device, verbose=False)[0].boxes.data  # AMP inference
         del m
         return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5)  # close to 0.5 absolute tolerance
 
```
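The AMP check now runs both the FP32 and AMP passes over a batch of 8 copies of the sample image at `imgsz=128`, which makes the agreement comparison less sensitive to single-image noise. A sketch of the comparison logic, using a toy conv net in place of a YOLO model (an assumption kept for self-containment; the real check compares the predictions' `.boxes.data`):

```python
import torch

# FP32-vs-reduced-precision agreement check, sketched with a toy net.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1).to(device)
batch = torch.rand(8, 3, 128, 128, device=device)  # batch of 8 at imgsz=128, as in the new check

a = net(batch)  # FP32 inference
with torch.autocast(device_type=device.type, enabled=True):
    b = net(batch)  # reduced-precision inference (FP16 on CUDA, bfloat16 on CPU)

ok = a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5)  # 0.5 absolute tolerance
print(f"AMP check {'passed' if ok else 'failed'}: max diff {(a - b.float()).abs().max():.4f}")
```

The loose `atol=0.5` tolerance reflects that detection box coordinates can legitimately shift by fractions of a pixel under mixed precision; the check only needs to flag gross divergence, not bit-exact agreement.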