ultralytics 8.0.48 Edge TPU fix and Metrics updates (#1171)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: majid nasiri <majnasai@gmail.com>
This commit is contained in:
parent
a58f766f94
commit
74e4c94806
23 changed files with 426 additions and 245 deletions
|
|
@ -43,7 +43,7 @@ class YOLO:
|
|||
cfg (str): The model configuration if loaded from *.yaml file.
|
||||
ckpt_path (str): The checkpoint file path.
|
||||
overrides (dict): Overrides for the trainer object.
|
||||
metrics_data (Any): The data for metrics.
|
||||
metrics (Any): The data for metrics.
|
||||
|
||||
Methods:
|
||||
__call__(source=None, stream=False, **kwargs):
|
||||
|
|
@ -67,7 +67,7 @@ class YOLO:
|
|||
list(ultralytics.yolo.engine.results.Results): The prediction results.
|
||||
"""
|
||||
|
||||
def __init__(self, model='yolov8n.pt', task=None) -> None:
|
||||
def __init__(self, model='yolov8n.pt', task=None, session=None) -> None:
|
||||
"""
|
||||
Initializes the YOLO model.
|
||||
|
||||
|
|
@ -83,7 +83,8 @@ class YOLO:
|
|||
self.cfg = None # if loaded from *.yaml
|
||||
self.ckpt_path = None
|
||||
self.overrides = {} # overrides for trainer object
|
||||
self.metrics_data = None
|
||||
self.metrics = None # validation/training metrics
|
||||
self.session = session # HUB session
|
||||
|
||||
# Load or create new YOLO model
|
||||
suffix = Path(model).suffix
|
||||
|
|
@ -184,6 +185,7 @@ class YOLO:
|
|||
self._check_is_pytorch_model()
|
||||
self.model.fuse()
|
||||
|
||||
@smart_inference_mode()
|
||||
def predict(self, source=None, stream=False, **kwargs):
|
||||
"""
|
||||
Perform prediction using the YOLO model.
|
||||
|
|
@ -217,7 +219,6 @@ class YOLO:
|
|||
is_cli = sys.argv[0].endswith('yolo') or sys.argv[0].endswith('ultralytics')
|
||||
return self.predictor.predict_cli(source=source) if is_cli else self.predictor(source=source, stream=stream)
|
||||
|
||||
@smart_inference_mode()
|
||||
def track(self, source=None, stream=False, **kwargs):
|
||||
from ultralytics.tracker import register_tracker
|
||||
register_tracker(self)
|
||||
|
|
@ -252,7 +253,7 @@ class YOLO:
|
|||
|
||||
validator = TASK_MAP[self.task][2](args=args)
|
||||
validator(model=self.model)
|
||||
self.metrics_data = validator.metrics
|
||||
self.metrics = validator.metrics
|
||||
|
||||
return validator.metrics
|
||||
|
||||
|
|
@ -314,12 +315,13 @@ class YOLO:
|
|||
if not overrides.get('resume'): # manually set model only if not resuming
|
||||
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
|
||||
self.model = self.trainer.model
|
||||
self.trainer.hub_session = self.session # attach optional HUB session
|
||||
self.trainer.train()
|
||||
# update model and cfg after training
|
||||
if RANK in {0, -1}:
|
||||
self.model, _ = attempt_load_one_weight(str(self.trainer.best))
|
||||
self.overrides = self.model.args
|
||||
self.metrics_data = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
|
||||
self.metrics = getattr(self.trainer.validator, 'metrics', None) # TODO: no metrics returned by DDP
|
||||
|
||||
def to(self, device):
|
||||
"""
|
||||
|
|
@ -352,15 +354,6 @@ class YOLO:
|
|||
"""
|
||||
return self.model.transforms if hasattr(self.model, 'transforms') else None
|
||||
|
||||
@property
|
||||
def metrics(self):
|
||||
"""
|
||||
Returns metrics if computed
|
||||
"""
|
||||
if not self.metrics_data:
|
||||
LOGGER.info('No metrics data found! Run training or validation operation first.')
|
||||
return self.metrics_data
|
||||
|
||||
@staticmethod
|
||||
def add_callback(event: str, func):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue