ultralytics 8.0.48 Edge TPU fix and Metrics updates (#1171)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: majid nasiri <majnasai@gmail.com>
This commit is contained in:
Glenn Jocher 2023-02-27 21:34:22 -08:00 committed by GitHub
parent a58f766f94
commit 74e4c94806
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 426 additions and 245 deletions

View file

@ -43,7 +43,7 @@ class YOLO:
cfg (str): The model configuration if loaded from *.yaml file.
ckpt_path (str): The checkpoint file path.
overrides (dict): Overrides for the trainer object.
metrics_data (Any): The data for metrics.
metrics (Any): The data for metrics.
Methods:
__call__(source=None, stream=False, **kwargs):
@ -67,7 +67,7 @@ class YOLO:
list(ultralytics.yolo.engine.results.Results): The prediction results.
"""
def __init__(self, model='yolov8n.pt', task=None) -> None:
def __init__(self, model='yolov8n.pt', task=None, session=None) -> None:
"""
Initializes the YOLO model.
@ -83,7 +83,8 @@ class YOLO:
self.cfg = None # if loaded from *.yaml
self.ckpt_path = None
self.overrides = {} # overrides for trainer object
self.metrics_data = None
self.metrics = None # validation/training metrics
self.session = session # HUB session
# Load or create new YOLO model
suffix = Path(model).suffix
@ -184,6 +185,7 @@ class YOLO:
self._check_is_pytorch_model()
self.model.fuse()
@smart_inference_mode()
def predict(self, source=None, stream=False, **kwargs):
"""
Perform prediction using the YOLO model.
@ -217,7 +219,6 @@ class YOLO:
is_cli = sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
@smart_inference_mode()
def track(self, source=None, stream=False, **kwargs):
from ultralytics.tracker import register_tracker
register_tracker(self)
@ -252,7 +253,7 @@ class YOLO:
validator = TASK_MAP[self.task][2](args=args)
validator(model=self.model)
self.metrics_data = validator.metrics
self.metrics = validator.metrics
return validator.metrics
@ -314,12 +315,13 @@ class YOLO:
if not overrides.get('resume'): # manually set model only if not resuming
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
self.model = self.trainer.model
self.trainer.hub_session = self.session # attach optional HUB session
self.trainer.train()
# update model and cfg after training
if RANK in {0, -1}:
self.model, _ = attempt_load_one_weight(str(self.trainer.best))
self.overrides = self.model.args
self.metrics_data = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
self.metrics = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
def to(self, device):
"""
@ -352,15 +354,6 @@ class YOLO:
"""
return self.model.transforms if hasattr(self.model, 'transforms') else None
@property
def metrics(self):
"""
Returns metrics if computed
"""
if not self.metrics_data:
LOGGER.info('No metrics data found! Run training or validation operation first.')
return self.metrics_data
@staticmethod
def add_callback(event: str, func):
"""