Add docformatter to pre-commit (#5279)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Burhan <62214284+Burhan-Q@users.noreply.github.com>
This commit is contained in:
parent
c7aa83da31
commit
7517667a33
90 changed files with 1396 additions and 497 deletions
|
|
@ -1,6 +1,6 @@
|
|||
# Ultralytics YOLO 🚀, AGPL-3.0 license
|
||||
"""
|
||||
Train a model on a dataset
|
||||
Train a model on a dataset.
|
||||
|
||||
Usage:
|
||||
$ yolo mode=train model=yolov8n.pt data=coco128.yaml imgsz=640 epochs=100 batch=16
|
||||
|
|
@ -37,7 +37,7 @@ from ultralytics.utils.torch_utils import (EarlyStopping, ModelEMA, de_parallel,
|
|||
|
||||
class BaseTrainer:
|
||||
"""
|
||||
BaseTrainer
|
||||
BaseTrainer.
|
||||
|
||||
A base class for creating trainers.
|
||||
|
||||
|
|
@ -143,15 +143,11 @@ class BaseTrainer:
|
|||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
def add_callback(self, event: str, callback):
|
||||
"""
|
||||
Appends the given callback.
|
||||
"""
|
||||
"""Appends the given callback."""
|
||||
self.callbacks[event].append(callback)
|
||||
|
||||
def set_callback(self, event: str, callback):
|
||||
"""
|
||||
Overrides the existing callbacks with the given callback.
|
||||
"""
|
||||
"""Overrides the existing callbacks with the given callback."""
|
||||
self.callbacks[event] = [callback]
|
||||
|
||||
def run_callbacks(self, event: str):
|
||||
|
|
@ -207,9 +203,7 @@ class BaseTrainer:
|
|||
world_size=world_size)
|
||||
|
||||
def _setup_train(self, world_size):
|
||||
"""
|
||||
Builds dataloaders and optimizer on correct rank process.
|
||||
"""
|
||||
"""Builds dataloaders and optimizer on correct rank process."""
|
||||
|
||||
# Model
|
||||
self.run_callbacks('on_pretrain_routine_start')
|
||||
|
|
@ -450,14 +444,14 @@ class BaseTrainer:
|
|||
@staticmethod
|
||||
def get_dataset(data):
|
||||
"""
|
||||
Get train, val path from data dict if it exists. Returns None if data format is not recognized.
|
||||
Get train, val path from data dict if it exists.
|
||||
|
||||
Returns None if data format is not recognized.
|
||||
"""
|
||||
return data['train'], data.get('val') or data.get('test')
|
||||
|
||||
def setup_model(self):
|
||||
"""
|
||||
load/create/download model for any task.
|
||||
"""
|
||||
"""Load/create/download model for any task."""
|
||||
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
||||
return
|
||||
|
||||
|
|
@ -482,14 +476,14 @@ class BaseTrainer:
|
|||
self.ema.update(self.model)
|
||||
|
||||
def preprocess_batch(self, batch):
|
||||
"""
|
||||
Allows custom preprocessing model inputs and ground truths depending on task type.
|
||||
"""
|
||||
"""Allows custom preprocessing model inputs and ground truths depending on task type."""
|
||||
return batch
|
||||
|
||||
def validate(self):
|
||||
"""
|
||||
Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
|
||||
Runs validation on test set using self.validator.
|
||||
|
||||
The returned dict is expected to contain "fitness" key.
|
||||
"""
|
||||
metrics = self.validator(self)
|
||||
fitness = metrics.pop('fitness', -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
|
||||
|
|
@ -506,26 +500,20 @@ class BaseTrainer:
|
|||
raise NotImplementedError('get_validator function not implemented in trainer')
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
|
||||
"""
|
||||
Returns dataloader derived from torch.data.Dataloader.
|
||||
"""
|
||||
"""Returns dataloader derived from torch.data.Dataloader."""
|
||||
raise NotImplementedError('get_dataloader function not implemented in trainer')
|
||||
|
||||
def build_dataset(self, img_path, mode='train', batch=None):
|
||||
"""Build dataset"""
|
||||
"""Build dataset."""
|
||||
raise NotImplementedError('build_dataset function not implemented in trainer')
|
||||
|
||||
def label_loss_items(self, loss_items=None, prefix='train'):
|
||||
"""
|
||||
Returns a loss dict with labelled training loss items tensor
|
||||
"""
|
||||
"""Returns a loss dict with labelled training loss items tensor."""
|
||||
# Not needed for classification but necessary for segmentation & detection
|
||||
return {'loss': loss_items} if loss_items is not None else ['loss']
|
||||
|
||||
def set_model_attributes(self):
|
||||
"""
|
||||
To set or update model parameters before training.
|
||||
"""
|
||||
"""To set or update model parameters before training."""
|
||||
self.model.names = self.data['names']
|
||||
|
||||
def build_targets(self, preds, targets):
|
||||
|
|
@ -632,8 +620,8 @@ class BaseTrainer:
|
|||
|
||||
def build_optimizer(self, model, name='auto', lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
|
||||
"""
|
||||
Constructs an optimizer for the given model, based on the specified optimizer name, learning rate,
|
||||
momentum, weight decay, and number of iterations.
|
||||
Constructs an optimizer for the given model, based on the specified optimizer name, learning rate, momentum,
|
||||
weight decay, and number of iterations.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The model for which to build an optimizer.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue