ultralytics 8.1.4 RTDETR TensorBoard graph visualization fix (#7725)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
6535bcde2b
commit
7a0c27c7d7
8 changed files with 65 additions and 26 deletions
|
|
@@ -1,14 +1,21 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
import contextlib
|
||||
|
||||
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr
|
||||
|
||||
try:
|
||||
# WARNING: do not move import due to protobuf issue in https://github.com/ultralytics/ultralytics/pull/4674
|
||||
# WARNING: do not move SummaryWriter import due to protobuf bug https://github.com/ultralytics/ultralytics/pull/4674
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
assert not TESTS_RUNNING # do not log pytest
|
||||
assert SETTINGS["tensorboard"] is True # verify integration is enabled
|
||||
WRITER = None # TensorBoard SummaryWriter instance
|
||||
PREFIX = colorstr("TensorBoard: ")
|
||||
|
||||
# Imports below only required if TensorBoard enabled
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from ultralytics.utils.torch_utils import de_parallel, torch
|
||||
|
||||
except (ImportError, AssertionError, TypeError, AttributeError):
|
||||
# TypeError for handling 'Descriptors cannot be created directly.' protobuf errors in Windows
|
||||
|
|
@@ -25,20 +32,37 @@ def _log_scalars(scalars, step=0):
|
|||
|
||||
def _log_tensorboard_graph(trainer):
|
||||
"""Log model graph to TensorBoard."""
|
||||
try:
|
||||
import warnings
|
||||
|
||||
from ultralytics.utils.torch_utils import de_parallel, torch
|
||||
# Input image
|
||||
imgsz = trainer.args.imgsz
|
||||
imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
|
||||
p = next(trainer.model.parameters()) # for device, type
|
||||
im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty)
|
||||
|
||||
imgsz = trainer.args.imgsz
|
||||
imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
|
||||
p = next(trainer.model.parameters()) # for device, type
|
||||
im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", category=UserWarning) # suppress jit trace warning
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", category=UserWarning) # suppress jit trace warning
|
||||
warnings.simplefilter("ignore", category=torch.jit.TracerWarning) # suppress jit trace warning
|
||||
|
||||
# Try simple method first (YOLO)
|
||||
with contextlib.suppress(Exception):
|
||||
WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
|
||||
except Exception as e:
|
||||
LOGGER.warning(f"WARNING ⚠️ TensorBoard graph visualization failure {e}")
|
||||
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
|
||||
return
|
||||
|
||||
# Fallback to TorchScript export steps (RTDETR)
|
||||
try:
|
||||
model = deepcopy(de_parallel(trainer.model))
|
||||
model.eval()
|
||||
model = model.fuse(verbose=False)
|
||||
for m in model.modules():
|
||||
if hasattr(m, "export"): # Detect, RTDETRDecoder (Segment and Pose use Detect base class)
|
||||
m.export = True
|
||||
m.format = "torchscript"
|
||||
model(im) # dry run
|
||||
WRITER.add_graph(torch.jit.trace(model, im, strict=False), [])
|
||||
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
|
||||
except Exception as e:
|
||||
LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard graph visualization failure {e}")
|
||||
|
||||
|
||||
def on_pretrain_routine_start(trainer):
|
||||
|
|
@@ -47,10 +71,9 @@ def on_pretrain_routine_start(trainer):
|
|||
try:
|
||||
global WRITER
|
||||
WRITER = SummaryWriter(str(trainer.save_dir))
|
||||
prefix = colorstr("TensorBoard: ")
|
||||
LOGGER.info(f"{prefix}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
|
||||
LOGGER.info(f"{PREFIX}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
|
||||
except Exception as e:
|
||||
LOGGER.warning(f"WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}")
|
||||
LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard not initialized correctly, not logging this run. {e}")
|
||||
|
||||
|
||||
def on_train_start(trainer):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue