Fix TensorBoard graph UserWarning catch (#4513)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
c7ceb84fb6
commit
3c40e7a9fc
3 changed files with 31 additions and 11 deletions
|
|
@ -32,9 +32,9 @@ def _log_tensorboard_graph(trainer):
|
|||
imgsz = trainer.args.imgsz
|
||||
imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
|
||||
p = next(trainer.model.parameters()) # for device, type
|
||||
im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input (WARNING: must be zeros, not empty)
|
||||
with warnings.catch_warnings(category=UserWarning):
|
||||
warnings.simplefilter('ignore') # suppress jit trace warning
|
||||
im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', category=UserWarning) # suppress jit trace warning
|
||||
WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'WARNING ⚠️ TensorBoard graph visualization failure {e}')
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue