Fix model re-fuse() in inference loops (#466)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ayush Chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
parent
cc3c774bde
commit
a86218b767
22 changed files with 135 additions and 66 deletions
|
|
@ -7,7 +7,7 @@ from ultralytics.nn.tasks import ClassificationModel, DetectionModel, Segmentati
|
|||
from ultralytics.yolo.configs import get_config
|
||||
from ultralytics.yolo.engine.exporter import Exporter
|
||||
from ultralytics.yolo.utils import DEFAULT_CONFIG, LOGGER, yaml_load
|
||||
from ultralytics.yolo.utils.checks import check_imgsz, check_yaml
|
||||
from ultralytics.yolo.utils.checks import check_yaml
|
||||
from ultralytics.yolo.utils.torch_utils import guess_task_from_head, smart_inference_mode
|
||||
|
||||
# Map head to model, trainer, validator, and predictor classes
|
||||
|
|
@ -43,6 +43,7 @@ class YOLO:
|
|||
self.TrainerClass = None # trainer class
|
||||
self.ValidatorClass = None # validator class
|
||||
self.PredictorClass = None # predictor class
|
||||
self.predictor = None # reuse predictor
|
||||
self.model = None # model object
|
||||
self.trainer = None # trainer object
|
||||
self.task = None # task type
|
||||
|
|
@ -131,11 +132,12 @@ class YOLO:
|
|||
overrides.update(kwargs)
|
||||
overrides["mode"] = "predict"
|
||||
overrides["save"] = kwargs.get("save", False) # not save files by default
|
||||
predictor = self.PredictorClass(overrides=overrides)
|
||||
|
||||
predictor.args.imgsz = check_imgsz(predictor.args.imgsz, min_dim=2) # check image size
|
||||
predictor.setup(model=self.model, source=source)
|
||||
return predictor(stream=stream, verbose=verbose)
|
||||
if not self.predictor:
|
||||
self.predictor = self.PredictorClass(overrides=overrides)
|
||||
self.predictor.setup_model(model=self.model)
|
||||
else: # only update args if predictor is already setup
|
||||
self.predictor.args = get_config(self.predictor.args, overrides)
|
||||
return self.predictor(source=source, stream=stream, verbose=verbose)
|
||||
|
||||
@smart_inference_mode()
|
||||
def val(self, data=None, **kwargs):
|
||||
|
|
@ -170,6 +172,7 @@ class YOLO:
|
|||
args = get_config(config=DEFAULT_CONFIG, overrides=overrides)
|
||||
args.task = self.task
|
||||
|
||||
print(args)
|
||||
exporter = Exporter(overrides=args)
|
||||
exporter(model=self.model)
|
||||
|
||||
|
|
@ -224,10 +227,14 @@ class YOLO:
|
|||
def _reset_ckpt_args(args):
|
||||
args.pop("project", None)
|
||||
args.pop("name", None)
|
||||
args.pop("exist_ok", None)
|
||||
args.pop("resume", None)
|
||||
args.pop("batch", None)
|
||||
args.pop("epochs", None)
|
||||
args.pop("cache", None)
|
||||
args.pop("save_json", None)
|
||||
args.pop("half", None)
|
||||
args.pop("v5loader", None)
|
||||
|
||||
# set device to '' to prevent from auto DDP usage
|
||||
args["device"] = ''
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue