Fix TensorBoard graph UserWarning catch (#4513)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-08-23 12:52:24 +02:00 committed by GitHub
parent c7ceb84fb6
commit 3c40e7a9fc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 31 additions and 11 deletions

View file

@ -32,9 +32,9 @@ def _log_tensorboard_graph(trainer):
imgsz = trainer.args.imgsz
imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
p = next(trainer.model.parameters()) # for device, type
im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input (WARNING: must be zeros, not empty)
with warnings.catch_warnings(category=UserWarning):
warnings.simplefilter('ignore') # suppress jit trace warning
im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty)
with warnings.catch_warnings():
warnings.simplefilter('ignore', category=UserWarning) # suppress jit trace warning
WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
except Exception as e:
LOGGER.warning(f'WARNING ⚠️ TensorBoard graph visualization failure {e}')