Fix TensorBoard graph UserWarning catch (#4513)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
c7ceb84fb6
commit
3c40e7a9fc
3 changed files with 31 additions and 11 deletions
|
|
@ -10,6 +10,7 @@ import math
|
|||
import os
|
||||
import subprocess
|
||||
import time
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
|
@ -378,7 +379,9 @@ class BaseTrainer:
|
|||
|
||||
self.lr = {f'lr/pg{ir}': x['lr'] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
|
||||
|
||||
self.scheduler.step()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore') # suppress 'Detected lr_scheduler.step() before optimizer.step()'
|
||||
self.scheduler.step()
|
||||
self.run_callbacks('on_train_epoch_end')
|
||||
|
||||
if RANK in (-1, 0):
|
||||
|
|
|
|||
|
|
@ -32,9 +32,9 @@ def _log_tensorboard_graph(trainer):
|
|||
imgsz = trainer.args.imgsz
|
||||
imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
|
||||
p = next(trainer.model.parameters()) # for device, type
|
||||
im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input (WARNING: must be zeros, not empty)
|
||||
with warnings.catch_warnings(category=UserWarning):
|
||||
warnings.simplefilter('ignore') # suppress jit trace warning
|
||||
im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter('ignore', category=UserWarning) # suppress jit trace warning
|
||||
WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
|
||||
except Exception as e:
|
||||
LOGGER.warning(f'WARNING ⚠️ TensorBoard graph visualization failure {e}')
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue