ultralytics 8.1.4 RTDETR TensorBoard graph visualization fix (#7725)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
Glenn Jocher 2024-01-21 22:28:24 +01:00 committed by GitHub
parent 6535bcde2b
commit 7a0c27c7d7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 65 additions and 26 deletions

View file

@ -1,14 +1,21 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import contextlib
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr
try:
# WARNING: do not move import due to protobuf issue in https://github.com/ultralytics/ultralytics/pull/4674
# WARNING: do not move SummaryWriter import due to protobuf bug https://github.com/ultralytics/ultralytics/pull/4674
from torch.utils.tensorboard import SummaryWriter
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS["tensorboard"] is True # verify integration is enabled
WRITER = None # TensorBoard SummaryWriter instance
PREFIX = colorstr("TensorBoard: ")
# Imports below only required if TensorBoard enabled
import warnings
from copy import deepcopy
from ultralytics.utils.torch_utils import de_parallel, torch
except (ImportError, AssertionError, TypeError, AttributeError):
# TypeError for handling 'Descriptors cannot be created directly.' protobuf errors in Windows
@ -25,20 +32,37 @@ def _log_scalars(scalars, step=0):
def _log_tensorboard_graph(trainer):
"""Log model graph to TensorBoard."""
try:
import warnings
from ultralytics.utils.torch_utils import de_parallel, torch
# Input image
imgsz = trainer.args.imgsz
imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
p = next(trainer.model.parameters()) # for device, type
im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty)
imgsz = trainer.args.imgsz
imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
p = next(trainer.model.parameters()) # for device, type
im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty)
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning) # suppress jit trace warning
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning) # suppress jit trace warning
warnings.simplefilter("ignore", category=torch.jit.TracerWarning) # suppress jit trace warning
# Try simple method first (YOLO)
with contextlib.suppress(Exception):
WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
except Exception as e:
LOGGER.warning(f"WARNING ⚠️ TensorBoard graph visualization failure {e}")
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
return
# Fallback to TorchScript export steps (RTDETR)
try:
model = deepcopy(de_parallel(trainer.model))
model.eval()
model = model.fuse(verbose=False)
for m in model.modules():
if hasattr(m, "export"): # Detect, RTDETRDecoder (Segment and Pose use Detect base class)
m.export = True
m.format = "torchscript"
model(im) # dry run
WRITER.add_graph(torch.jit.trace(model, im, strict=False), [])
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
except Exception as e:
LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard graph visualization failure {e}")
def on_pretrain_routine_start(trainer):
@ -47,10 +71,9 @@ def on_pretrain_routine_start(trainer):
try:
global WRITER
WRITER = SummaryWriter(str(trainer.save_dir))
prefix = colorstr("TensorBoard: ")
LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
LOGGER.info(f"{PREFIX}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
except Exception as e:
LOGGER.warning(f"WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}")
LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}")
def on_train_start(trainer):