New guess_model_task() function (#614)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-01-25 02:24:28 +01:00 committed by GitHub
parent 520825c4b2
commit 59d4335664
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 63 additions and 29 deletions

View file

@ -66,7 +66,7 @@ import torch
import ultralytics
from ultralytics.nn.modules import Detect, Segment
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, guess_model_task
from ultralytics.yolo.cfg import get_cfg
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
from ultralytics.yolo.data.utils import check_det_dataset
@ -74,7 +74,7 @@ from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, get
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
from ultralytics.yolo.utils.files import file_size
from ultralytics.yolo.utils.ops import Profile
from ultralytics.yolo.utils.torch_utils import guess_task_from_model_yaml, select_device, smart_inference_mode
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
MACOS = platform.system() == 'Darwin' # macOS environment
@ -235,7 +235,7 @@ class Exporter:
# Finish
f = [str(x) for x in f if x] # filter out '' and None
if any(f):
task = guess_task_from_model_yaml(model)
task = guess_model_task(model)
s = "-WARNING ⚠️ not yet supported for YOLOv8 exported models"
LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"