ultralytics 8.0.67 Pose speeds, Comet and ClearML updates (#1871)
Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Victor Sonck <victor.sonck@gmail.com> Co-authored-by: Danny Kim <dh031200@gmail.com>
This commit is contained in:
parent
1cb92d7f42
commit
2725545090
28 changed files with 547 additions and 146 deletions
|
|
@ -78,7 +78,7 @@ class YOLO:
|
|||
task (Any, optional): Task type for the YOLO model. Defaults to None.
|
||||
|
||||
"""
|
||||
self._reset_callbacks()
|
||||
self.callbacks = callbacks.get_default_callbacks()
|
||||
self.predictor = None # reuse predictor
|
||||
self.model = None # model object
|
||||
self.trainer = None # trainer object
|
||||
|
|
@ -238,7 +238,7 @@ class YOLO:
|
|||
overrides['save'] = kwargs.get('save', False) # not save files by default
|
||||
if not self.predictor:
|
||||
self.task = overrides.get('task') or self.task
|
||||
self.predictor = TASK_MAP[self.task][3](overrides=overrides)
|
||||
self.predictor = TASK_MAP[self.task][3](overrides=overrides, _callbacks=self.callbacks)
|
||||
self.predictor.setup_model(model=self.model, verbose=is_cli)
|
||||
else: # only update args if predictor is already setup
|
||||
self.predictor.args = get_cfg(self.predictor.args, overrides)
|
||||
|
|
@ -277,7 +277,7 @@ class YOLO:
|
|||
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
|
||||
args.imgsz = check_imgsz(args.imgsz, max_dim=1)
|
||||
|
||||
validator = TASK_MAP[self.task][2](args=args)
|
||||
validator = TASK_MAP[self.task][2](args=args, _callbacks=self.callbacks)
|
||||
validator(model=self.model)
|
||||
self.metrics = validator.metrics
|
||||
|
||||
|
|
@ -316,7 +316,7 @@ class YOLO:
|
|||
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
|
||||
if args.batch == DEFAULT_CFG.batch:
|
||||
args.batch = 1 # default to 1 if not modified
|
||||
return Exporter(overrides=args)(model=self.model)
|
||||
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)
|
||||
|
||||
def train(self, **kwargs):
|
||||
"""
|
||||
|
|
@ -344,7 +344,7 @@ class YOLO:
|
|||
overrides['resume'] = self.ckpt_path
|
||||
|
||||
self.task = overrides.get('task') or self.task
|
||||
self.trainer = TASK_MAP[self.task][1](overrides=overrides)
|
||||
self.trainer = TASK_MAP[self.task][1](overrides=overrides, _callbacks=self.callbacks)
|
||||
if not overrides.get('resume'): # manually set model only if not resuming
|
||||
self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
|
||||
self.model = self.trainer.model
|
||||
|
|
@ -387,19 +387,17 @@ class YOLO:
|
|||
"""
|
||||
return self.model.transforms if hasattr(self.model, 'transforms') else None
|
||||
|
||||
@staticmethod
|
||||
def add_callback(event: str, func):
|
||||
def add_callback(self, event: str, func):
|
||||
"""
|
||||
Add callback
|
||||
"""
|
||||
callbacks.default_callbacks[event].append(func)
|
||||
self.callbacks[event].append(func)
|
||||
|
||||
@staticmethod
|
||||
def _reset_ckpt_args(args):
|
||||
include = {'imgsz', 'data', 'task', 'single_cls'} # only remember these arguments when loading a PyTorch model
|
||||
return {k: v for k, v in args.items() if k in include}
|
||||
|
||||
@staticmethod
|
||||
def _reset_callbacks():
|
||||
def _reset_callbacks(self):
|
||||
for event in callbacks.default_callbacks.keys():
|
||||
callbacks.default_callbacks[event] = [callbacks.default_callbacks[event][0]]
|
||||
self.callbacks[event] = [callbacks.default_callbacks[event][0]]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue