ultralytics 8.0.188 fix .grad attribute leaf Tensor Warning (#5094)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
f2ed207571
commit
19c3314e68
11 changed files with 78 additions and 41 deletions
|
|
@ -8,8 +8,7 @@ from typing import Union
|
|||
from ultralytics.cfg import TASK2DATA, get_cfg, get_save_dir
|
||||
from ultralytics.hub.utils import HUB_WEB_ROOT
|
||||
from ultralytics.nn.tasks import attempt_load_one_weight, guess_model_task, nn, yaml_model_load
|
||||
from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, callbacks, emojis, yaml_load
|
||||
from ultralytics.utils.checks import check_file, check_imgsz, check_pip_update_available, check_yaml
|
||||
from ultralytics.utils import ASSETS, DEFAULT_CFG_DICT, LOGGER, RANK, callbacks, checks, emojis, yaml_load
|
||||
from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS
|
||||
|
||||
|
||||
|
|
@ -139,7 +138,7 @@ class Model(nn.Module):
|
|||
self.overrides = self.model.args = self._reset_ckpt_args(self.model.args)
|
||||
self.ckpt_path = self.model.pt_path
|
||||
else:
|
||||
weights = check_file(weights)
|
||||
weights = checks.check_file(weights)
|
||||
self.model, self.ckpt = weights, None
|
||||
self.task = task or guess_model_task(weights)
|
||||
self.ckpt_path = weights
|
||||
|
|
@ -204,11 +203,11 @@ class Model(nn.Module):
|
|||
|
||||
Args:
|
||||
source (str | int | PIL | np.ndarray): The source of the image to make predictions on.
|
||||
Accepts all source types accepted by the YOLO model.
|
||||
Accepts all source types accepted by the YOLO model.
|
||||
stream (bool): Whether to stream the predictions or not. Defaults to False.
|
||||
predictor (BasePredictor): Customized predictor.
|
||||
**kwargs : Additional keyword arguments passed to the predictor.
|
||||
Check the 'configuration' section in the documentation for all available options.
|
||||
Check the 'configuration' section in the documentation for all available options.
|
||||
|
||||
Returns:
|
||||
(List[ultralytics.engine.results.Results]): The prediction results.
|
||||
|
|
@ -251,8 +250,7 @@ class Model(nn.Module):
|
|||
if not hasattr(self.predictor, 'trackers'):
|
||||
from ultralytics.trackers import register_tracker
|
||||
register_tracker(self, persist)
|
||||
# ByteTrack-based method needs low confidence predictions as input
|
||||
kwargs['conf'] = kwargs.get('conf') or 0.1
|
||||
kwargs['conf'] = kwargs.get('conf') or 0.1 # ByteTrack-based method needs low confidence predictions as input
|
||||
kwargs['mode'] = 'track'
|
||||
return self.predict(source=source, stream=stream, **kwargs)
|
||||
|
||||
|
|
@ -266,7 +264,6 @@ class Model(nn.Module):
|
|||
"""
|
||||
custom = {'rect': True} # method defaults
|
||||
args = {**self.overrides, **custom, **kwargs, 'mode': 'val'} # highest priority args on the right
|
||||
args['imgsz'] = check_imgsz(args['imgsz'], max_dim=1)
|
||||
|
||||
validator = (validator or self._smart_load('validator'))(args=args, _callbacks=self.callbacks)
|
||||
validator(model=self.model)
|
||||
|
|
@ -321,9 +318,9 @@ class Model(nn.Module):
|
|||
if any(kwargs):
|
||||
LOGGER.warning('WARNING ⚠️ using HUB training arguments, ignoring local training arguments.')
|
||||
kwargs = self.session.train_args
|
||||
check_pip_update_available()
|
||||
checks.check_pip_update_available()
|
||||
|
||||
overrides = yaml_load(check_yaml(kwargs['cfg'])) if kwargs.get('cfg') else self.overrides
|
||||
overrides = yaml_load(checks.check_yaml(kwargs['cfg'])) if kwargs.get('cfg') else self.overrides
|
||||
custom = {'data': TASK2DATA[self.task]} # method defaults
|
||||
args = {**overrides, **custom, **kwargs, 'mode': 'train'} # highest priority args on the right
|
||||
if args.get('resume'):
|
||||
|
|
@ -366,7 +363,7 @@ class Model(nn.Module):
|
|||
self._check_is_pytorch_model()
|
||||
self = super()._apply(fn) # noqa
|
||||
self.predictor = None # reset predictor as device may have changed
|
||||
self.overrides['device'] = str(self.device) # i.e. device(type='cuda', index=0) -> 'cuda:0'
|
||||
self.overrides['device'] = self.device # was str(self.device) i.e. device(type='cuda', index=0) -> 'cuda:0'
|
||||
return self
|
||||
|
||||
@property
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue