ultralytics 8.0.188 fix .grad attribute leaf Tensor Warning (#5094)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Glenn Jocher 2023-09-26 20:28:45 +02:00 committed by GitHub
parent f2ed207571
commit 19c3314e68
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 78 additions and 41 deletions

View file

@ -8,8 +8,7 @@ from typing import Union
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
from ultralytics.hub.utils import HUB_WEB_ROOT
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, callbacks, emojis, yaml_load
from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, callbacks, checks, emojis, yaml_load
from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS
@ -139,7 +138,7 @@ class Model(nn.Module):
self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
self.ckpt_path = self.model.pt_path
else:
weights = check_file(weights)
weights = checks.check_file(weights)
self.model, self.ckpt = weights, None
self.task = task or guess_model_task(weights)
self.ckpt_path = weights
@ -204,11 +203,11 @@ class Model(nn.Module):
Args:
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
Accepts all source types accepted by the YOLO model.
Accepts all source types accepted by the YOLO model.
stream (bool): Whether to stream the predictions or not. Defaults to False.
predictor (BasePredictor): Customized predictor.
**kwargs : Additional keyword arguments passed to the predictor.
Check the 'configuration' section in the documentation for all available options.
Check the 'configuration' section in the documentation for all available options.
Returns:
(List[ultralytics.engine.results.Results]): The prediction results.
@ -251,8 +250,7 @@ class Model(nn.Module):
if not hasattr(self.predictor, 'trackers'):
from ultralytics.trackers import register_tracker
register_tracker(self, persist)
# ByteTrack-based method needs low confidence predictions as input
kwargs['conf'] = kwargs.get('conf') or 0.1
kwargs['conf'] = kwargs.get('conf') or 0.1 # ByteTrack-based method needs low confidence predictions as input
kwargs['mode'] = 'track'
return self.predict(source=source, stream=stream, **kwargs)
@ -266,7 +264,6 @@ class Model(nn.Module):
"""
custom = {'rect': True} # method defaults
args = {**self.overrides, **custom, **kwargs, 'mode': 'val'} # highest priority args on the right
args['imgsz'] = check_imgsz(args['imgsz'], max_dim=1)
validator = (validator or self._smart_load('validator'))(args=args, _callbacks=self.callbacks)
validator(model=self.model)
@ -321,9 +318,9 @@ class Model(nn.Module):
if any(kwargs):
LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
kwargs = self.session.train_args
check_pip_update_available()
checks.check_pip_update_available()
overrides = yaml_load(check_yaml(kwargs['cfg'])) if kwargs.get('cfg') else self.overrides
overrides = yaml_load(checks.check_yaml(kwargs['cfg'])) if kwargs.get('cfg') else self.overrides
custom = {'data': TASK2DATA[self.task]} # method defaults
args = {**overrides, **custom, **kwargs, 'mode': 'train'} # highest priority args on the right
if args.get('resume'):
@ -366,7 +363,7 @@ class Model(nn.Module):
self._check_is_pytorch_model()
self = super()._apply(fn) # noqa
self.predictor = None # reset predictor as device may have changed
self.overrides['device'] = str(self.device) # i.e. device(type='cuda', index=0) -> 'cuda:0'
self.overrides['device'] = self.device # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0'
return self
@property