Add model.eval() in TensorBoad graph visualization to avoid BN stats changes (#8629)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
9c42596145
commit
609a0cefbf
1 changed files with 1 additions and 0 deletions
|
|
@ -45,6 +45,7 @@ def _log_tensorboard_graph(trainer):
|
||||||
|
|
||||||
# Try simple method first (YOLO)
|
# Try simple method first (YOLO)
|
||||||
with contextlib.suppress(Exception):
|
with contextlib.suppress(Exception):
|
||||||
|
trainer.model.eval() # place in .eval() mode to avoid BatchNorm statistics changes
|
||||||
WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
|
WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
|
||||||
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
|
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
|
||||||
return
|
return
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue