CoreML NMS and half fixes (#143)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
55bdca6768
commit
172cef2d20
18 changed files with 95 additions and 414 deletions
|
|
@ -67,7 +67,7 @@ import torch
|
|||
|
||||
import ultralytics
|
||||
from ultralytics.nn.modules import Detect, Segment
|
||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_weights
|
||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel
|
||||
from ultralytics.yolo.configs import get_config
|
||||
from ultralytics.yolo.data.dataloaders.stream_loaders import LoadImages
|
||||
from ultralytics.yolo.data.utils import check_dataset
|
||||
|
|
@ -154,7 +154,7 @@ class Exporter:
|
|||
# Load PyTorch model
|
||||
self.device = select_device(self.args.device or 'cpu')
|
||||
if self.args.half:
|
||||
if self.device.type == 'cpu' or not coreml:
|
||||
if self.device.type == 'cpu' and not coreml:
|
||||
LOGGER.info('half=True only compatible with GPU or CoreML export, i.e. use device=0 or format=coreml')
|
||||
self.args.half = False
|
||||
assert not self.args.dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic'
|
||||
|
|
@ -769,17 +769,22 @@ class Exporter:
|
|||
def export(cfg):
|
||||
cfg.model = cfg.model or "yolov8n.yaml"
|
||||
cfg.format = cfg.format or "torchscript"
|
||||
exporter = Exporter(cfg)
|
||||
|
||||
model = None
|
||||
if isinstance(cfg.model, (str, Path)):
|
||||
if Path(cfg.model).suffix == '.yaml':
|
||||
model = DetectionModel(cfg.model)
|
||||
elif Path(cfg.model).suffix == '.pt':
|
||||
model = attempt_load_weights(cfg.model, fuse=True)
|
||||
else:
|
||||
TypeError(f'Unsupported model type {cfg.model}')
|
||||
exporter(model=model)
|
||||
# exporter = Exporter(cfg)
|
||||
#
|
||||
# model = None
|
||||
# if isinstance(cfg.model, (str, Path)):
|
||||
# if Path(cfg.model).suffix == '.yaml':
|
||||
# model = DetectionModel(cfg.model)
|
||||
# elif Path(cfg.model).suffix == '.pt':
|
||||
# model = attempt_load_weights(cfg.model, fuse=True)
|
||||
# else:
|
||||
# TypeError(f'Unsupported model type {cfg.model}')
|
||||
# exporter(model=model)
|
||||
|
||||
from ultralytics import YOLO
|
||||
model = YOLO(cfg.model)
|
||||
model.export(**cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue