ultralytics 8.0.158 add benchmarks to coverage (#4432)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Yonghye Kwon <developer.0hye@gmail.com>
This commit is contained in:
Glenn Jocher 2023-08-20 20:52:30 +02:00 committed by GitHub
parent 495806565d
commit 87ce15d383
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
51 changed files with 352 additions and 482 deletions

View file

@ -8,12 +8,24 @@ from ultralytics.data import build_dataloader, build_yolo_dataset
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.models import yolo
from ultralytics.nn.tasks import DetectionModel
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
from ultralytics.utils import LOGGER, RANK
from ultralytics.utils.plotting import plot_images, plot_labels, plot_results
from ultralytics.utils.torch_utils import de_parallel, torch_distributed_zero_first
class DetectionTrainer(BaseTrainer):
"""
A class extending the BaseTrainer class for training based on a detection model.
Example:
```python
from ultralytics.models.yolo.detect import DetectionTrainer
args = dict(model='yolov8n.pt', data='coco8.yaml', epochs=3)
trainer = DetectionTrainer(overrides=args)
trainer.train()
```
"""
def build_dataset(self, img_path, mode='train', batch=None):
"""
@ -102,22 +114,3 @@ class DetectionTrainer(BaseTrainer):
boxes = np.concatenate([lb['bboxes'] for lb in self.train_loader.dataset.labels], 0)
cls = np.concatenate([lb['cls'] for lb in self.train_loader.dataset.labels], 0)
plot_labels(boxes, cls.squeeze(), names=self.data['names'], save_dir=self.save_dir, on_plot=self.on_plot)
def train(cfg=DEFAULT_CFG, use_python=False):
"""Train and optimize YOLO model given training data and device."""
model = cfg.model or 'yolov8n.pt'
data = cfg.data or 'coco8.yaml' # or yolo.ClassificationDataset("mnist")
device = cfg.device if cfg.device is not None else ''
args = dict(model=model, data=data, device=device)
if use_python:
from ultralytics import YOLO
YOLO(model).train(**args)
else:
trainer = DetectionTrainer(overrides=args)
trainer.train()
if __name__ == '__main__':
train()