New guess_model_task() function (#614)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
520825c4b2
commit
59d4335664
6 changed files with 63 additions and 29 deletions
|
|
@ -66,7 +66,7 @@ import torch
|
|||
|
||||
import ultralytics
|
||||
from ultralytics.nn.modules import Detect, Segment
|
||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel
|
||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, guess_model_task
|
||||
from ultralytics.yolo.cfg import get_cfg
|
||||
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
|
||||
from ultralytics.yolo.data.utils import check_det_dataset
|
||||
|
|
@ -74,7 +74,7 @@ from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, callbacks, colorstr, get
|
|||
from ultralytics.yolo.utils.checks import check_imgsz, check_requirements, check_version, check_yaml
|
||||
from ultralytics.yolo.utils.files import file_size
|
||||
from ultralytics.yolo.utils.ops import Profile
|
||||
from ultralytics.yolo.utils.torch_utils import guess_task_from_model_yaml, select_device, smart_inference_mode
|
||||
from ultralytics.yolo.utils.torch_utils import select_device, smart_inference_mode
|
||||
|
||||
MACOS = platform.system() == 'Darwin' # macOS environment
|
||||
|
||||
|
|
@ -235,7 +235,7 @@ class Exporter:
|
|||
# Finish
|
||||
f = [str(x) for x in f if x] # filter out '' and None
|
||||
if any(f):
|
||||
task = guess_task_from_model_yaml(model)
|
||||
task = guess_model_task(model)
|
||||
s = "-WARNING ⚠️ not yet supported for YOLOv8 exported models"
|
||||
LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
|
||||
f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue