Set workers=0 for MPS Train and Val modes (#4697)

This commit is contained in:
Glenn Jocher 2023-09-02 04:25:03 +02:00 committed by GitHub
parent 2bc6e647c7
commit a4fabfdacf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 88 additions and 74 deletions

View file

@ -1,17 +1,16 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import os
import re
from pathlib import Path
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr
from ultralytics.utils import LOGGER, ROOT, SETTINGS, TESTS_RUNNING, colorstr
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS['mlflow'] is True # verify integration is enabled
import mlflow
assert hasattr(mlflow, '__version__') # verify package is not directory
import os
import re
except (ImportError, AssertionError):
mlflow = None
@ -56,11 +55,10 @@ def on_fit_epoch_end(trainer):
def on_train_end(trainer):
"""Called at end of train loop to log model artifact info."""
if mlflow:
root_dir = Path(__file__).resolve().parents[3]
run.log_artifact(trainer.last)
run.log_artifact(trainer.best)
run.pyfunc.log_model(artifact_path=experiment_name,
code_path=[str(root_dir)],
code_path=[str(ROOT.parent)],
artifacts={'model_path': str(trainer.save_dir)},
python_model=run.pyfunc.PythonModel())