Set workers=0 for MPS Train and Val modes (#4697)

This commit is contained in:
Glenn Jocher 2023-09-02 04:25:03 +02:00 committed by GitHub
parent 2bc6e647c7
commit a4fabfdacf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 88 additions and 74 deletions

View file

@ -1,10 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
from ultralytics.utils.torch_utils import model_info_for_loggers
try:
assert not TESTS_RUNNING # do not log pytest
@ -13,11 +9,12 @@ try:
from neptune.types import File
assert hasattr(neptune, '__version__')
run = None # NeptuneAI experiment logger instance
except (ImportError, AssertionError):
neptune = None
run = None # NeptuneAI experiment logger instance
def _log_scalars(scalars, step=0):
"""Log scalars to the NeptuneAI experiment logger."""
@ -42,6 +39,9 @@ def _log_plot(title, plot_path):
title (str) Title of the plot
plot_path (PosixPath or str) Path to the saved image file
"""
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
img = mpimg.imread(plot_path)
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect='auto', xticks=[], yticks=[]) # no ticks
@ -70,6 +70,7 @@ def on_train_epoch_end(trainer):
def on_fit_epoch_end(trainer):
"""Callback function called at end of each fit (train+val) epoch."""
if run and trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers
run['Configuration/Model'] = model_info_for_loggers(trainer)
_log_scalars(trainer.metrics, trainer.epoch + 1)