ultralytics 8.0.41 TF SavedModel and EdgeTPU export (#1034)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Noobtoss <96134731+Noobtoss@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
parent
4b866c9718
commit
f6e393c1d2
64 changed files with 604 additions and 351 deletions
|
|
@ -29,14 +29,45 @@ MODEL_MAP = {
|
|||
|
||||
class YOLO:
|
||||
"""
|
||||
YOLO
|
||||
YOLO (You Only Look Once) object detection model.
|
||||
|
||||
A python interface which emulates a model-like behaviour by wrapping trainers.
|
||||
"""
|
||||
Args:
|
||||
model (str or Path): Path to the model file to load or create.
|
||||
type (str): Type/version of models to use. Defaults to "v8".
|
||||
|
||||
Attributes:
|
||||
type (str): Type/version of models being used.
|
||||
ModelClass (Any): Model class.
|
||||
TrainerClass (Any): Trainer class.
|
||||
ValidatorClass (Any): Validator class.
|
||||
PredictorClass (Any): Predictor class.
|
||||
predictor (Any): Predictor object.
|
||||
model (Any): Model object.
|
||||
trainer (Any): Trainer object.
|
||||
task (str): Type of model task.
|
||||
ckpt (Any): Checkpoint object if model loaded from *.pt file.
|
||||
cfg (str): Model configuration if loaded from *.yaml file.
|
||||
ckpt_path (str): Checkpoint file path.
|
||||
overrides (dict): Overrides for trainer object.
|
||||
metrics_data (Any): Data for metrics.
|
||||
|
||||
Methods:
|
||||
__call__(): Alias for predict method.
|
||||
_new(cfg, verbose=True): Initializes a new model and infers the task type from the model definitions.
|
||||
_load(weights): Initializes a new model and infers the task type from the model head.
|
||||
_check_is_pytorch_model(): Raises TypeError if model is not a PyTorch model.
|
||||
reset(): Resets the model modules.
|
||||
info(verbose=False): Logs model info.
|
||||
fuse(): Fuse model for faster inference.
|
||||
predict(source=None, stream=False, **kwargs): Perform prediction using the YOLO model.
|
||||
|
||||
Returns:
|
||||
List[ultralytics.yolo.engine.results.Results]: The prediction results.
|
||||
"""
|
||||
|
||||
def __init__(self, model='yolov8n.pt', type='v8') -> None:
|
||||
"""
|
||||
Initializes the YOLO object.
|
||||
Initializes the YOLO model.
|
||||
|
||||
Args:
|
||||
model (str, Path): model to load or create
|
||||
|
|
@ -97,11 +128,12 @@ class YOLO:
|
|||
self.task = self.model.args['task']
|
||||
self.overrides = self.model.args
|
||||
self._reset_ckpt_args(self.overrides)
|
||||
self.ckpt_path = self.model.pt_path
|
||||
else:
|
||||
check_file(weights)
|
||||
weights = check_file(weights)
|
||||
self.model, self.ckpt = weights, None
|
||||
self.task = guess_model_task(weights)
|
||||
self.ckpt_path = weights
|
||||
self.ckpt_path = weights
|
||||
self.overrides['model'] = weights
|
||||
self.ModelClass, self.TrainerClass, self.ValidatorClass, self.PredictorClass = self._assign_ops_from_task()
|
||||
|
||||
|
|
@ -204,7 +236,6 @@ class YOLO:
|
|||
|
||||
return validator.metrics
|
||||
|
||||
@smart_inference_mode()
|
||||
def export(self, **kwargs):
|
||||
"""
|
||||
Export model.
|
||||
|
|
@ -279,6 +310,13 @@ class YOLO:
|
|||
"""
|
||||
return self.model.names if hasattr(self.model, 'names') else None
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
"""
|
||||
Returns device if PyTorch model
|
||||
"""
|
||||
return next(self.model.parameters()).device if isinstance(self.model, nn.Module) else None
|
||||
|
||||
@property
|
||||
def transforms(self):
|
||||
"""
|
||||
|
|
@ -293,7 +331,6 @@ class YOLO:
|
|||
"""
|
||||
if not self.metrics_data:
|
||||
LOGGER.info('No metrics data found! Run training or validation operation first.')
|
||||
|
||||
return self.metrics_data
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -306,7 +343,7 @@ class YOLO:
|
|||
@staticmethod
|
||||
def _reset_ckpt_args(args):
|
||||
for arg in 'augment', 'verbose', 'project', 'name', 'exist_ok', 'resume', 'batch', 'epochs', 'cache', \
|
||||
'save_json', 'half', 'v5loader', 'device', 'cfg', 'save', 'rect', 'plots', 'opset':
|
||||
'save_json', 'half', 'v5loader', 'device', 'cfg', 'save', 'rect', 'plots', 'opset', 'simplify':
|
||||
args.pop(arg, None)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue