ultralytics 8.0.41 TF SavedModel and EdgeTPU export (#1034)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Noobtoss <96134731+Noobtoss@users.noreply.github.com>
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
Glenn Jocher 2023-02-20 01:27:28 +01:00 committed by GitHub
parent 4b866c9718
commit f6e393c1d2
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
64 changed files with 604 additions and 351 deletions

View file

@ -29,14 +29,45 @@ MODEL_MAP = {
class YOLO:
"""
YOLO
YOLO (You Only Look Once) object detection model.
A python interface which emulates a model-like behaviour by wrapping trainers.
"""
Args:
model (str or Path): Path to the model file to load or create.
type (str): Type/version of models to use. Defaults to "v8".
Attributes:
type (str): Type/version of models being used.
ModelClass (Any): Model class.
TrainerClass (Any): Trainer class.
ValidatorClass (Any): Validator class.
PredictorClass (Any): Predictor class.
predictor (Any): Predictor object.
model (Any): Model object.
trainer (Any): Trainer object.
task (str): Type of model task.
ckpt (Any): Checkpoint object if model loaded from *.pt file.
cfg (str): Model configuration if loaded from *.yaml file.
ckpt_path (str): Checkpoint file path.
overrides (dict): Overrides for trainer object.
metrics_data (Any): Data for metrics.
Methods:
__call__(): Alias for predict method.
_new(cfg, verbose=True): Initializes a new model and infers the task type from the model definitions.
_load(weights): Initializes a new model and infers the task type from the model head.
_check_is_pytorch_model(): Raises TypeError if model is not a PyTorch model.
reset(): Resets the model modules.
info(verbose=False): Logs model info.
fuse(): Fuse model for faster inference.
predict(source=None, stream=False, **kwargs): Perform prediction using the YOLO model.
Returns:
List[ultralytics.yolo.engine.results.Results]: The prediction results.
"""
def __init__(self, model='yolov8n.pt', type='v8') -> None:
"""
Initializes the YOLO object.
Initializes the YOLO model.
Args:
model (str, Path): model to load or create
@ -97,11 +128,12 @@ class YOLO:
self.task = self.model.args['task']
self.overrides = self.model.args
self._reset_ckpt_args(self.overrides)
self.ckpt_path = self.model.pt_path
else:
check_file(weights)
weights = check_file(weights)
self.model, self.ckpt = weights, None
self.task = guess_model_task(weights)
self.ckpt_path = weights
self.ckpt_path = weights
self.overrides['model'] = weights
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._assign_ops_from_task()
@ -204,7 +236,6 @@ class YOLO:
return validator.metrics
@smart_inference_mode()
def export(self, **kwargs):
"""
Export model.
@ -279,6 +310,13 @@ class YOLO:
"""
return self.model.names if hasattr(self.model, 'names') else None
@property
def device(self):
"""
Returns device if PyTorch model
"""
return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None
@property
def transforms(self):
"""
@ -293,7 +331,6 @@ class YOLO:
"""
if not self.metrics_data:
LOGGER.info('No metrics data found! Run training or validation operation first.')
return self.metrics_data
@staticmethod
@ -306,7 +343,7 @@ class YOLO:
@staticmethod
def _reset_ckpt_args(args):
for arg in 'augment', 'verbose', 'project', 'name', 'exist_ok', 'resume', 'batch', 'epochs', 'cache', \
'save_json', 'half', 'v5loader', 'device', 'cfg', 'save', 'rect', 'plots', 'opset':
'save_json', 'half', 'v5loader', 'device', 'cfg', 'save', 'rect', 'plots', 'opset', 'simplify':
args.pop(arg, None)
@staticmethod