Set workers=0 for MPS Train and Val modes (#4697)
This commit is contained in:
parent
2bc6e647c7
commit
a4fabfdacf
14 changed files with 88 additions and 74 deletions
|
|
@ -1,17 +1,16 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr
|
||||
from ultralytics.utils import LOGGER, ROOT, SETTINGS, TESTS_RUNNING, colorstr
|
||||
|
||||
try:
|
||||
assert not TESTS_RUNNING # do not log pytest
|
||||
assert SETTINGS['mlflow'] is True # verify integration is enabled
|
||||
import mlflow
|
||||
|
||||
assert hasattr(mlflow, '__version__') # verify package is not directory
|
||||
|
||||
import os
|
||||
import re
|
||||
|
||||
except (ImportError, AssertionError):
|
||||
mlflow = None
|
||||
|
||||
|
|
@ -56,11 +55,10 @@ def on_fit_epoch_end(trainer):
|
|||
def on_train_end(trainer):
|
||||
"""Called at end of train loop to log model artifact info."""
|
||||
if mlflow:
|
||||
root_dir = Path(__file__).resolve().parents[3]
|
||||
run.log_artifact(trainer.last)
|
||||
run.log_artifact(trainer.best)
|
||||
run.pyfunc.log_model(artifact_path=experiment_name,
|
||||
code_path=[str(root_dir)],
|
||||
code_path=[str(ROOT.parent)],
|
||||
artifacts={'model_path': str(trainer.save_dir)},
|
||||
python_model=run.pyfunc.PythonModel())
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue