ultralytics 8.2.29 new fractional AutoBatch feature (#13446)
Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com> Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com> Co-authored-by: UltralyticsAssistant <web@ultralytics.com>
This commit is contained in:
parent
2fe0946376
commit
6a234f3639
12 changed files with 92 additions and 49 deletions
|
|
@@ -1,6 +1,6 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
|
||||
__version__ = "8.2.28"
|
||||
__version__ = "8.2.29"
|
||||
|
||||
import os
|
||||
|
||||
|
|
|
|||
|
|
@@ -95,10 +95,19 @@ CLI_HELP_MSG = f"""
|
|||
"""
|
||||
|
||||
# Define keys for arg type checks
|
||||
CFG_FLOAT_KEYS = {"warmup_epochs", "box", "cls", "dfl", "degrees", "shear", "time", "workspace"}
|
||||
CFG_FRACTION_KEYS = {
|
||||
CFG_FLOAT_KEYS = { # integer or float arguments, i.e. x=2 and x=2.0
|
||||
"warmup_epochs",
|
||||
"box",
|
||||
"cls",
|
||||
"dfl",
|
||||
"degrees",
|
||||
"shear",
|
||||
"time",
|
||||
"workspace",
|
||||
"batch",
|
||||
}
|
||||
CFG_FRACTION_KEYS = { # fractional float arguments with 0.0<=values<=1.0
|
||||
"dropout",
|
||||
"iou",
|
||||
"lr0",
|
||||
"lrf",
|
||||
"momentum",
|
||||
|
|
@@ -121,11 +130,10 @@ CFG_FRACTION_KEYS = {
|
|||
"conf",
|
||||
"iou",
|
||||
"fraction",
|
||||
} # fraction floats 0.0 - 1.0
|
||||
CFG_INT_KEYS = {
|
||||
}
|
||||
CFG_INT_KEYS = { # integer-only arguments
|
||||
"epochs",
|
||||
"patience",
|
||||
"batch",
|
||||
"workers",
|
||||
"seed",
|
||||
"close_mosaic",
|
||||
|
|
@@ -136,7 +144,7 @@ CFG_INT_KEYS = {
|
|||
"nbs",
|
||||
"save_period",
|
||||
}
|
||||
CFG_BOOL_KEYS = {
|
||||
CFG_BOOL_KEYS = { # boolean-only arguments
|
||||
"save",
|
||||
"exist_ok",
|
||||
"verbose",
|
||||
|
|
|
|||
|
|
@@ -1,6 +1,6 @@
|
|||
## Models
|
||||
|
||||
Welcome to the Ultralytics Models directory! Here you will find a wide variety of pre-configured model configuration files (`*.yaml`s) that can be used to create custom YOLO models. The models in this directory have been expertly crafted and fine-tuned by the Ultralytics team to provide the best performance for a wide range of object detection and image segmentation tasks.
|
||||
Welcome to the [Ultralytics](https://ultralytics.com) Models directory! Here you will find a wide variety of pre-configured model configuration files (`*.yaml`s) that can be used to create custom YOLO models. The models in this directory have been expertly crafted and fine-tuned by the Ultralytics team to provide the best performance for a wide range of object detection and image segmentation tasks.
|
||||
|
||||
These model configurations cover a wide range of scenarios, from simple object detection to more complex tasks like instance segmentation and object tracking. They are also designed to run efficiently on a variety of hardware platforms, from CPUs to GPUs. Whether you are a seasoned machine learning practitioner or just getting started with YOLO, this directory provides a great starting point for your custom model development needs.
|
||||
|
||||
|
|
@@ -8,26 +8,34 @@ To get started, simply browse through the models in this directory and find one
|
|||
|
||||
### Usage
|
||||
|
||||
Model `*.yaml` files may be used directly in the Command Line Interface (CLI) with a `yolo` command:
|
||||
Model `*.yaml` files may be used directly in the [Command Line Interface (CLI)](https://docs.ultralytics.com/usage/cli) with a `yolo` command:
|
||||
|
||||
```bash
|
||||
# Train a YOLOv8n model using the coco8 dataset for 100 epochs
|
||||
yolo task=detect mode=train model=yolov8n.yaml data=coco8.yaml epochs=100
|
||||
```
|
||||
|
||||
They may also be used directly in a Python environment, and accepts the same [arguments](https://docs.ultralytics.com/usage/cfg/) as in the CLI example above:
|
||||
They may also be used directly in a Python environment, and accept the same [arguments](https://docs.ultralytics.com/usage/cfg/) as in the CLI example above:
|
||||
|
||||
```python
|
||||
from ultralytics import YOLO
|
||||
|
||||
model = YOLO("model.yaml") # build a YOLOv8n model from scratch
|
||||
# YOLO("model.pt") use pre-trained model if available
|
||||
model.info() # display model information
|
||||
model.train(data="coco8.yaml", epochs=100) # train the model
|
||||
# Initialize a YOLOv8n model from a YAML configuration file
|
||||
model = YOLO("model.yaml")
|
||||
|
||||
# If a pre-trained model is available, use it instead
|
||||
# model = YOLO("model.pt")
|
||||
|
||||
# Display model information
|
||||
model.info()
|
||||
|
||||
# Train the model using the COCO8 dataset for 100 epochs
|
||||
model.train(data="coco8.yaml", epochs=100)
|
||||
```
|
||||
|
||||
## Pre-trained Model Architectures
|
||||
|
||||
Ultralytics supports many model architectures. Visit https://docs.ultralytics.com/models to view detailed information and usage. Any of these models can be used by loading their configs or pretrained checkpoints if available.
|
||||
Ultralytics supports many model architectures. Visit [Ultralytics Models](https://docs.ultralytics.com/models) to view detailed information and usage. Any of these models can be used by loading their configurations or pretrained checkpoints if available.
|
||||
|
||||
## Contribute New Models
|
||||
|
||||
|
|
|
|||
|
|
@@ -126,7 +126,7 @@ def gd_outputs(gd):
|
|||
|
||||
|
||||
def try_export(inner_func):
|
||||
"""YOLOv8 export decorator, i..e @try_export."""
|
||||
"""YOLOv8 export decorator, i.e. @try_export."""
|
||||
inner_args = get_default_args(inner_func)
|
||||
|
||||
def outer_func(*args, **kwargs):
|
||||
|
|
|
|||
|
|
@@ -269,8 +269,13 @@ class BaseTrainer:
|
|||
self.stride = gs # for multiscale training
|
||||
|
||||
# Batch size
|
||||
if self.batch_size == -1 and RANK == -1: # single-GPU only, estimate best batch size
|
||||
self.args.batch = self.batch_size = check_train_batch_size(self.model, self.args.imgsz, self.amp)
|
||||
if self.batch_size < 1 and RANK == -1: # single-GPU only, estimate best batch size
|
||||
self.args.batch = self.batch_size = check_train_batch_size(
|
||||
model=self.model,
|
||||
imgsz=self.args.imgsz,
|
||||
amp=self.amp,
|
||||
batch=self.batch_size,
|
||||
)
|
||||
|
||||
# Dataloaders
|
||||
batch_size = self.batch_size // max(world_size, 1)
|
||||
|
|
|
|||
|
|
@@ -92,7 +92,7 @@ class YOLOWorld(Model):
|
|||
Set classes.
|
||||
|
||||
Args:
|
||||
classes (List(str)): A list of categories i.e ["person"].
|
||||
classes (List(str)): A list of categories i.e. ["person"].
|
||||
"""
|
||||
self.model.set_classes(classes)
|
||||
# Remove background if it's given
|
||||
|
|
|
|||
|
|
@@ -10,9 +10,9 @@ from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr
|
|||
from ultralytics.utils.torch_utils import profile
|
||||
|
||||
|
||||
def check_train_batch_size(model, imgsz=640, amp=True):
|
||||
def check_train_batch_size(model, imgsz=640, amp=True, batch=-1):
|
||||
"""
|
||||
Check YOLO training batch size using the autobatch() function.
|
||||
Compute optimal YOLO training batch size using the autobatch() function.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): YOLO model to check batch size for.
|
||||
|
|
@@ -24,7 +24,7 @@ def check_train_batch_size(model, imgsz=640, amp=True):
|
|||
"""
|
||||
|
||||
with torch.cuda.amp.autocast(amp):
|
||||
return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
|
||||
return autobatch(deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6)
|
||||
|
||||
|
||||
def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
|
||||
|
|
@@ -43,7 +43,7 @@ def autobatch(model, imgsz=640, fraction=0.60, batch_size=DEFAULT_CFG.batch):
|
|||
|
||||
# Check device
|
||||
prefix = colorstr("AutoBatch: ")
|
||||
LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz}")
|
||||
LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz} at {fraction * 100}% CUDA memory utilization.")
|
||||
device = next(model.parameters()).device # get model device
|
||||
if device.type == "cpu":
|
||||
LOGGER.info(f"{prefix}CUDA not detected, using default CPU batch-size {batch_size}")
|
||||
|
|
|
|||
|
|
@@ -146,11 +146,17 @@ def select_device(device="", batch=0, newline=False, verbose=True):
|
|||
if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
|
||||
devices = device.split(",") if device else "0" # range(torch.cuda.device_count()) # i.e. 0,1,6,7
|
||||
n = len(devices) # device count
|
||||
if n > 1 and batch > 0 and batch % n != 0: # check batch_size is divisible by device_count
|
||||
raise ValueError(
|
||||
f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
|
||||
f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
|
||||
)
|
||||
if n > 1: # multi-GPU
|
||||
if batch < 1:
|
||||
raise ValueError(
|
||||
"AutoBatch with batch<1 not supported for Multi-GPU training, "
|
||||
"please specify a valid batch size, i.e. batch=16."
|
||||
)
|
||||
if batch >= 0 and batch % n != 0: # check batch_size is divisible by device_count
|
||||
raise ValueError(
|
||||
f"'batch={batch}' must be a multiple of GPU count {n}. Try 'batch={batch // n * n}' or "
|
||||
f"'batch={batch // n * n + n}', the nearest batch sizes evenly divisible by {n}."
|
||||
)
|
||||
space = " " * (len(s) + 1)
|
||||
for i, d in enumerate(devices):
|
||||
p = torch.cuda.get_device_properties(i)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue