ultralytics 8.3.8 replace contextlib with try for speed (#16782)
Signed-off-by: UltralyticsAssistant <web@ultralytics.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
1e6c454460
commit
a6a577961f
12 changed files with 115 additions and 88 deletions
|
|
@ -1,6 +1,5 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import contextlib
|
||||
|
||||
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr
|
||||
|
||||
|
|
@ -45,26 +44,27 @@ def _log_tensorboard_graph(trainer):
|
|||
warnings.simplefilter("ignore", category=torch.jit.TracerWarning) # suppress jit trace warning
|
||||
|
||||
# Try simple method first (YOLO)
|
||||
with contextlib.suppress(Exception):
|
||||
try:
|
||||
trainer.model.eval() # place in .eval() mode to avoid BatchNorm statistics changes
|
||||
WRITER.add_graph(torch.jit.trace(de_parallel(trainer.model), im, strict=False), [])
|
||||
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
|
||||
return
|
||||
|
||||
# Fallback to TorchScript export steps (RTDETR)
|
||||
try:
|
||||
model = deepcopy(de_parallel(trainer.model))
|
||||
model.eval()
|
||||
model = model.fuse(verbose=False)
|
||||
for m in model.modules():
|
||||
if hasattr(m, "export"): # Detect, RTDETRDecoder (Segment and Pose use Detect base class)
|
||||
m.export = True
|
||||
m.format = "torchscript"
|
||||
model(im) # dry run
|
||||
WRITER.add_graph(torch.jit.trace(model, im, strict=False), [])
|
||||
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
|
||||
except Exception as e:
|
||||
LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard graph visualization failure {e}")
|
||||
except: # noqa E722
|
||||
# Fallback to TorchScript export steps (RTDETR)
|
||||
try:
|
||||
model = deepcopy(de_parallel(trainer.model))
|
||||
model.eval()
|
||||
model = model.fuse(verbose=False)
|
||||
for m in model.modules():
|
||||
if hasattr(m, "export"): # Detect, RTDETRDecoder (Segment and Pose use Detect base class)
|
||||
m.export = True
|
||||
m.format = "torchscript"
|
||||
model(im) # dry run
|
||||
WRITER.add_graph(torch.jit.trace(model, im, strict=False), [])
|
||||
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
|
||||
except Exception as e:
|
||||
LOGGER.warning(f"{PREFIX}WARNING ⚠️ TensorBoard graph visualization failure {e}")
|
||||
|
||||
|
||||
def on_pretrain_routine_start(trainer):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue