ultralytics 8.0.179 base Model class from nn.Module (#4911)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
c8de4fe634
commit
c17106db1f
7 changed files with 101 additions and 56 deletions
|
|
@@ -16,7 +16,7 @@ import torch.distributed as dist
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, RANK, __version__
|
||||
from ultralytics.utils import DEFAULT_CFG_DICT, DEFAULT_CFG_KEYS, LOGGER, __version__
|
||||
from ultralytics.utils.checks import check_version
|
||||
|
||||
try:
|
||||
|
|
@@ -60,13 +60,48 @@ def get_cpu_info():
|
|||
|
||||
|
||||
def select_device(device='', batch=0, newline=False, verbose=True):
|
||||
"""Selects PyTorch Device. Options are device = None or 'cpu' or 0 or '0' or '0,1,2,3'."""
|
||||
"""
|
||||
Selects the appropriate PyTorch device based on the provided arguments.
|
||||
|
||||
The function takes a string specifying the device or a torch.device object and returns a torch.device object
|
||||
representing the selected device. The function also validates the number of available devices and raises an
|
||||
exception if the requested device(s) are not available.
|
||||
|
||||
Args:
|
||||
device (str | torch.device, optional): Device string or torch.device object.
|
||||
Options are 'None', 'cpu', 'cuda', '0' or '0,1,2,3'. Defaults to an empty string, which auto-selects
|
||||
the first available GPU, or CPU if no GPU is available.
|
||||
batch (int, optional): Batch size being used in your model. Defaults to 0.
|
||||
newline (bool, optional): If True, adds a newline at the end of the log string. Defaults to False.
|
||||
verbose (bool, optional): If True, logs the device information. Defaults to True.
|
||||
|
||||
Returns:
|
||||
torch.device: Selected device.
|
||||
|
||||
Raises:
|
||||
ValueError: If the specified device is not available or if the batch size is not a multiple of the number of
|
||||
devices when using multiple GPUs.
|
||||
|
||||
Examples:
|
||||
>>> select_device('cuda:0')
|
||||
device(type='cuda', index=0)
|
||||
|
||||
>>> select_device('cpu')
|
||||
device(type='cpu')
|
||||
|
||||
Note:
|
||||
Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use.
|
||||
"""
|
||||
|
||||
if isinstance(device, torch.device):
|
||||
return device
|
||||
|
||||
s = f'Ultralytics YOLOv{__version__} 🚀 Python-{platform.python_version()} torch-{torch.__version__} '
|
||||
device = str(device).lower()
|
||||
for remove in 'cuda:', 'none', '(', ')', '[', ']', "'", ' ':
|
||||
device = device.replace(remove, '') # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
|
||||
cpu = device == 'cpu'
|
||||
mps = device == 'mps' # Apple Metal Performance Shaders (MPS)
|
||||
mps = device in ('mps', 'mps:0') # Apple Metal Performance Shaders (MPS)
|
||||
if cpu or mps:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # force torch.cuda.is_available() = False
|
||||
elif device: # non-cpu device requested
|
||||
|
|
@@ -105,7 +140,7 @@ def select_device(device='', batch=0, newline=False, verbose=True):
|
|||
s += f'CPU ({get_cpu_info()})\n'
|
||||
arg = 'cpu'
|
||||
|
||||
if verbose and RANK == -1:
|
||||
if verbose:
|
||||
LOGGER.info(s if newline else s.rstrip())
|
||||
return torch.device(arg)
|
||||
|
||||
|
|
@@ -204,12 +239,15 @@ def model_info_for_loggers(trainer):
|
|||
"""
|
||||
Return model info dict with useful model information.
|
||||
|
||||
Example for YOLOv8n:
|
||||
{'model/parameters': 3151904,
|
||||
'model/GFLOPs': 8.746,
|
||||
'model/speed_ONNX(ms)': 41.244,
|
||||
'model/speed_TensorRT(ms)': 3.211,
|
||||
'model/speed_PyTorch(ms)': 18.755}
|
||||
Example:
|
||||
YOLOv8n info for loggers
|
||||
```python
|
||||
results = {'model/parameters': 3151904,
|
||||
'model/GFLOPs': 8.746,
|
||||
'model/speed_ONNX(ms)': 41.244,
|
||||
'model/speed_TensorRT(ms)': 3.211,
|
||||
'model/speed_PyTorch(ms)': 18.755}
|
||||
```
|
||||
"""
|
||||
if trainer.args.profile: # profile ONNX and TensorRT times
|
||||
from ultralytics.utils.benchmarks import ProfileModels
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue