Start export implementation (#110)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
c1b38428bc
commit
92dad1c1b5
32 changed files with 827 additions and 222 deletions
|
|
@ -1,13 +1,13 @@
|
|||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
from ultralytics import yolo # noqa required for python usage
|
||||
from ultralytics.nn.tasks import ClassificationModel, DetectionModel, SegmentationModel, attempt_load_weights
|
||||
# from ultralytics.yolo.data.utils import check_dataset, check_dataset_yaml
|
||||
from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG
|
||||
from ultralytics.yolo.utils import HELP_MSG, LOGGER
|
||||
from ultralytics.yolo.configs import get_config
|
||||
from ultralytics.yolo.engine.exporter import export_model
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, HELP_MSG, LOGGER
|
||||
from ultralytics.yolo.utils.checks import check_yaml
|
||||
from ultralytics.yolo.utils.configs import get_config
|
||||
from ultralytics.yolo.utils.files import yaml_load
|
||||
from ultralytics.yolo.utils.torch_utils import smart_inference_mode
|
||||
|
||||
|
|
@ -36,7 +36,7 @@ class YOLO:
|
|||
type (str): Type/version of models to use
|
||||
"""
|
||||
if init_key != YOLO.__init_key:
|
||||
raise Exception(HELP_MSG)
|
||||
raise SyntaxError(HELP_MSG)
|
||||
|
||||
self.type = type
|
||||
self.ModelClass = None
|
||||
|
|
@ -46,7 +46,8 @@ class YOLO:
|
|||
self.model = None
|
||||
self.trainer = None
|
||||
self.task = None
|
||||
self.ckpt = None
|
||||
self.ckpt = None # if loaded from *.pt
|
||||
self.cfg = None # if loaded from *.yaml
|
||||
self.overrides = {}
|
||||
self.init_disabled = False
|
||||
|
||||
|
|
@ -59,12 +60,12 @@ class YOLO:
|
|||
cfg (str): model configuration file
|
||||
"""
|
||||
cfg = check_yaml(cfg) # check YAML
|
||||
with open(cfg, encoding='ascii', errors='ignore') as f:
|
||||
cfg = yaml.safe_load(f) # model dict
|
||||
cfg_dict = yaml_load(cfg) # model dict
|
||||
obj = cls(init_key=cls.__init_key)
|
||||
obj.task = obj._guess_task_from_head(cfg["head"][-1][-2])
|
||||
obj.task = obj._guess_task_from_head(cfg_dict["head"][-1][-2])
|
||||
obj.ModelClass, obj.TrainerClass, obj.ValidatorClass, obj.PredictorClass = obj._guess_ops_from_task(obj.task)
|
||||
obj.model = obj.ModelClass(cfg) # initialize
|
||||
obj.model = obj.ModelClass(cfg_dict) # initialize
|
||||
obj.cfg = cfg
|
||||
|
||||
return obj
|
||||
|
||||
|
|
@ -116,13 +117,14 @@ class YOLO:
|
|||
LOGGER.info("model not initialized!")
|
||||
self.model.fuse()
|
||||
|
||||
@smart_inference_mode()
|
||||
def predict(self, source, **kwargs):
|
||||
"""
|
||||
Visualize prection.
|
||||
Visualize prediction.
|
||||
|
||||
Args:
|
||||
source (str): Accepts all source types accepted by yolo
|
||||
**kwargs : Any other args accepted by the predictors. Too see all args check 'configuration' section in the docs
|
||||
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in the docs
|
||||
"""
|
||||
overrides = self.overrides.copy()
|
||||
overrides.update(kwargs)
|
||||
|
|
@ -131,7 +133,7 @@ class YOLO:
|
|||
|
||||
# check size type
|
||||
sz = predictor.args.imgsz
|
||||
if type(sz) != int: # recieved listConfig
|
||||
if type(sz) != int: # received listConfig
|
||||
predictor.args.imgsz = [sz[0], sz[0]] if len(sz) == 1 else [sz[0], sz[1]] # expand
|
||||
else:
|
||||
predictor.args.imgsz = [sz, sz]
|
||||
|
|
@ -139,16 +141,17 @@ class YOLO:
|
|||
predictor.setup(model=self.model, source=source)
|
||||
predictor()
|
||||
|
||||
@smart_inference_mode()
|
||||
def val(self, data=None, **kwargs):
|
||||
"""
|
||||
Validate a model on a given dataset
|
||||
|
||||
Args:
|
||||
data (str): The dataset to validate on. Accepts all formats accepted by yolo
|
||||
kwargs: Any other args accepted by the validators. Too see all args check 'configuration' section in the docs
|
||||
kwargs: Any other args accepted by the validators. To see all args check 'configuration' section in the docs
|
||||
"""
|
||||
if not self.model:
|
||||
raise Exception("model not initialized!")
|
||||
raise ModuleNotFoundError("model not initialized!")
|
||||
|
||||
overrides = self.overrides.copy()
|
||||
overrides.update(kwargs)
|
||||
|
|
@ -160,6 +163,51 @@ class YOLO:
|
|||
validator = self.ValidatorClass(args=args)
|
||||
validator(model=self.model)
|
||||
|
||||
@smart_inference_mode()
|
||||
def export(self, format='', save_dir='', **kwargs):
|
||||
"""
|
||||
Export model.
|
||||
|
||||
Args:
|
||||
format (str): Export format
|
||||
**kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in the docs
|
||||
"""
|
||||
|
||||
overrides = self.overrides.copy()
|
||||
overrides.update(kwargs)
|
||||
args = get_config(config=DEFAULT_CONFIG, overrides=overrides)
|
||||
args.task = self.task
|
||||
args.format = format
|
||||
|
||||
file = self.ckpt or Path(Path(self.cfg).name)
|
||||
if save_dir:
|
||||
file = Path(save_dir) / file.name
|
||||
file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
export_model(
|
||||
model=self.model,
|
||||
file=file,
|
||||
data=args.data, # 'dataset.yaml path'
|
||||
imgsz=args.imgsz or (640, 640), # image (height, width)
|
||||
batch_size=1, # batch size
|
||||
device=args.device, # cuda device, i.e. 0 or 0,1,2,3 or cpu
|
||||
format=args.format, # include formats
|
||||
half=args.half or False, # FP16 half-precision export
|
||||
keras=args.keras or False, # use Keras
|
||||
optimize=args.optimize or False, # TorchScript: optimize for mobile
|
||||
int8=args.int8 or False, # CoreML/TF INT8 quantization
|
||||
dynamic=args.dynamic or False, # ONNX/TF/TensorRT: dynamic axes
|
||||
opset=args.opset or 17, # ONNX: opset version
|
||||
verbose=False, # TensorRT: verbose log
|
||||
workspace=args.workspace or 4, # TensorRT: workspace size (GB)
|
||||
nms=False, # TF: add NMS to model
|
||||
agnostic_nms=False, # TF: add agnostic NMS to model
|
||||
topk_per_class=100, # TF.js NMS: topk per class to keep
|
||||
topk_all=100, # TF.js NMS: topk for all classes to keep
|
||||
iou_thres=0.45, # TF.js NMS: IoU threshold
|
||||
conf_thres=0.25, # TF.js NMS: confidence threshold
|
||||
)
|
||||
|
||||
def train(self, **kwargs):
|
||||
"""
|
||||
Trains the model on given dataset.
|
||||
|
|
@ -178,7 +226,7 @@ class YOLO:
|
|||
overrides["task"] = self.task
|
||||
overrides["mode"] = "train"
|
||||
if not overrides.get("data"):
|
||||
raise AttributeError("dataset not provided! Please check if you have defined `data` in you configs")
|
||||
raise AttributeError("dataset not provided! Please define `data` in config.yaml or pass as an argument.")
|
||||
|
||||
self.trainer = self.TrainerClass(overrides=overrides)
|
||||
self.trainer.model = self.trainer.load_model(weights=self.ckpt,
|
||||
|
|
@ -189,11 +237,11 @@ class YOLO:
|
|||
|
||||
def resume(self, task=None, model=None):
|
||||
"""
|
||||
Resume a training task. Requires either `task` or `model`. `model` takes the higher precederence.
|
||||
Resume a training task. Requires either `task` or `model`. `model` takes the higher precedence.
|
||||
Args:
|
||||
task (str): The task type you want to resume. Automatically finds the last run to resume if `model` is not specified.
|
||||
model (str): The model checkpoint to resume from. If not found, the last run of the given task type is resumed.
|
||||
If `model` is speficied
|
||||
If `model` is specified
|
||||
"""
|
||||
if task:
|
||||
if task.lower() not in MODEL_MAP:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue