ultralytics 8.0.81 single-line docstring updates (#2061)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
5bce1c3021
commit
a38f227672
64 changed files with 620 additions and 58 deletions
|
|
@ -159,6 +159,7 @@ class BaseTrainer:
|
|||
self.callbacks[event] = [callback]
|
||||
|
||||
def run_callbacks(self, event: str):
|
||||
"""Run all existing callbacks associated with a particular event."""
|
||||
for callback in self.callbacks.get(event, []):
|
||||
callback(self)
|
||||
|
||||
|
|
@ -190,6 +191,7 @@ class BaseTrainer:
|
|||
self._do_train(world_size)
|
||||
|
||||
def _setup_ddp(self, world_size):
|
||||
"""Initializes and sets the DistributedDataParallel parameters for training."""
|
||||
torch.cuda.set_device(RANK)
|
||||
self.device = torch.device('cuda', RANK)
|
||||
LOGGER.info(f'DDP settings: RANK {RANK}, WORLD_SIZE {world_size}, DEVICE {self.device}')
|
||||
|
|
@ -259,6 +261,7 @@ class BaseTrainer:
|
|||
self.run_callbacks('on_pretrain_routine_end')
|
||||
|
||||
def _do_train(self, world_size=1):
|
||||
"""Train completed, evaluate and plot if specified by arguments."""
|
||||
if world_size > 1:
|
||||
self._setup_ddp(world_size)
|
||||
|
||||
|
|
@ -392,6 +395,7 @@ class BaseTrainer:
|
|||
self.run_callbacks('teardown')
|
||||
|
||||
def save_model(self):
|
||||
"""Save model checkpoints based on various conditions."""
|
||||
ckpt = {
|
||||
'epoch': self.epoch,
|
||||
'best_fitness': self.best_fitness,
|
||||
|
|
@ -436,6 +440,7 @@ class BaseTrainer:
|
|||
return ckpt
|
||||
|
||||
def optimizer_step(self):
|
||||
"""Perform a single step of the training optimizer with gradient clipping and EMA update."""
|
||||
self.scaler.unscale_(self.optimizer) # unscale gradients
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0) # clip gradients
|
||||
self.scaler.step(self.optimizer)
|
||||
|
|
@ -461,9 +466,11 @@ class BaseTrainer:
|
|||
return metrics, fitness
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
"""Get model and raise NotImplementedError for loading cfg files."""
|
||||
raise NotImplementedError("This task trainer doesn't support loading cfg files")
|
||||
|
||||
def get_validator(self):
|
||||
"""Returns a NotImplementedError when the get_validator function is called."""
|
||||
raise NotImplementedError('get_validator function not implemented in trainer')
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode='train'):
|
||||
|
|
@ -492,19 +499,24 @@ class BaseTrainer:
|
|||
self.model.names = self.data['names']
|
||||
|
||||
def build_targets(self, preds, targets):
|
||||
"""Builds target tensors for training YOLO model."""
|
||||
pass
|
||||
|
||||
def progress_string(self):
|
||||
"""Returns a string describing training progress."""
|
||||
return ''
|
||||
|
||||
# TODO: may need to put these following functions into callback
|
||||
def plot_training_samples(self, batch, ni):
|
||||
"""Plots training samples during YOLOv5 training."""
|
||||
pass
|
||||
|
||||
def plot_training_labels(self):
|
||||
"""Plots training labels for YOLO model."""
|
||||
pass
|
||||
|
||||
def save_metrics(self, metrics):
|
||||
"""Saves training metrics to a CSV file."""
|
||||
keys, vals = list(metrics.keys()), list(metrics.values())
|
||||
n = len(metrics) + 1 # number of cols
|
||||
s = '' if self.csv.exists() else (('%23s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # header
|
||||
|
|
@ -512,9 +524,11 @@ class BaseTrainer:
|
|||
f.write(s + ('%23.5g,' * n % tuple([self.epoch] + vals)).rstrip(',') + '\n')
|
||||
|
||||
def plot_metrics(self):
|
||||
"""Plot and display metrics visually."""
|
||||
pass
|
||||
|
||||
def final_eval(self):
|
||||
"""Performs final evaluation and validation for object detection YOLO model."""
|
||||
for f in self.last, self.best:
|
||||
if f.exists():
|
||||
strip_optimizer(f) # strip optimizers
|
||||
|
|
@ -525,6 +539,7 @@ class BaseTrainer:
|
|||
self.run_callbacks('on_fit_epoch_end')
|
||||
|
||||
def check_resume(self):
|
||||
"""Check if resume checkpoint exists and update arguments accordingly."""
|
||||
resume = self.args.resume
|
||||
if resume:
|
||||
try:
|
||||
|
|
@ -539,6 +554,7 @@ class BaseTrainer:
|
|||
self.resume = resume
|
||||
|
||||
def resume_training(self, ckpt):
|
||||
"""Resume YOLO training from given epoch and best fitness."""
|
||||
if ckpt is None:
|
||||
return
|
||||
best_fitness = 0.0
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue