Docstring additions (#122)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent c9f3e469cb
commit df4fc14c10
10 changed files with 291 additions and 73 deletions
@@ -33,9 +33,53 @@ from ultralytics.yolo.utils.torch_utils import ModelEMA, de_parallel, init_seeds
 
 
 class BaseTrainer:
+    """
+    BaseTrainer
+
+    A base class for creating trainers.
+
+    Attributes:
+        args (OmegaConf): Configuration for the trainer.
+        check_resume (method): Method to check if training should be resumed from a saved checkpoint.
+        console (logging.Logger): Logger instance.
+        validator (BaseValidator): Validator instance.
+        model (nn.Module): Model instance.
+        callbacks (defaultdict): Dictionary of callbacks.
+        save_dir (Path): Directory to save results.
+        wdir (Path): Directory to save weights.
+        last (Path): Path to last checkpoint.
+        best (Path): Path to best checkpoint.
+        batch_size (int): Batch size for training.
+        epochs (int): Number of epochs to train for.
+        start_epoch (int): Starting epoch for training.
+        device (torch.device): Device to use for training.
+        amp (bool): Flag to enable AMP (Automatic Mixed Precision).
+        scaler (amp.GradScaler): Gradient scaler for AMP.
+        data (str): Path to data.
+        trainset (torch.utils.data.Dataset): Training dataset.
+        testset (torch.utils.data.Dataset): Testing dataset.
+        ema (nn.Module): EMA (Exponential Moving Average) of the model.
+        lf (nn.Module): Loss function.
+        scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
+        best_fitness (float): The best fitness value achieved.
+        fitness (float): Current fitness value.
+        loss (float): Current loss value.
+        tloss (float): Total loss value.
+        loss_names (list): List of loss names.
+        csv (Path): Path to results CSV file.
+    """
 
-    def __init__(self, cfg=DEFAULT_CONFIG, overrides={}):
-        self.args = get_config(cfg, overrides)
+    def __init__(self, config=DEFAULT_CONFIG, overrides=None):
+        """
+        Initializes the BaseTrainer class.
+
+        Args:
+            config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
+            overrides (dict, optional): Configuration overrides. Defaults to None.
+        """
+        if overrides is None:
+            overrides = {}
+        self.args = get_config(config, overrides)
         self.check_resume()
         init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
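Beyond the docstring, this hunk also fixes a classic Python pitfall: the old signature def __init__(self, cfg=DEFAULT_CONFIG, overrides={}) used a mutable default argument, which Python evaluates once at function definition time and then shares across every call. A minimal sketch of the failure mode, using a hypothetical stand-in class rather than the real BaseTrainer:

class Trainer:
    def __init__(self, overrides={}):  # one dict object, created once, reused by every call
        self.args = overrides
        self.args.setdefault('epochs', 100)  # silently mutates the shared default

a = Trainer()
b = Trainer()
assert a.args is b.args  # both instances share (and mutate) the same dict

The new overrides=None default, together with the "if overrides is None: overrides = {}" guard, gives each call a fresh dict, so per-instance mutations can no longer leak between trainers.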
@@ -464,6 +508,19 @@ class BaseTrainer:
 
     @staticmethod
     def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
+        """
+        Builds an optimizer with the specified parameters and parameter groups.
+
+        Args:
+            model (nn.Module): model to optimize
+            name (str): name of the optimizer to use
+            lr (float): learning rate
+            momentum (float): momentum
+            decay (float): weight decay
+
+        Returns:
+            torch.optim.Optimizer: the built optimizer
+        """
         g = [], [], []  # optimizer parameter groups
         bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # normalization layers, i.e. BatchNorm2d()
         for v in model.modules():
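The hunk ends just as the module loop begins, so the grouping logic itself is not shown here. Below is a sketch of how a three-group optimizer builder of this shape typically works; the loop body and optimizer branches are assumptions modeled on the YOLOv5-style smart optimizer (biases and normalization weights exempted from weight decay), not a verbatim copy of this file:

import torch
import torch.nn as nn

def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
    g = [], [], []  # g[0]: decayed weights, g[1]: no-decay norm weights, g[2]: biases
    bn = tuple(v for k, v in nn.__dict__.items() if 'Norm' in k)  # BatchNorm2d, LayerNorm, ...
    for v in model.modules():
        if hasattr(v, 'bias') and isinstance(v.bias, nn.Parameter):
            g[2].append(v.bias)  # biases: never decayed
        if isinstance(v, bn):
            g[1].append(v.weight)  # normalization weights: never decayed
        elif hasattr(v, 'weight') and isinstance(v.weight, nn.Parameter):
            g[0].append(v.weight)  # all other weights: decayed

    if name == 'Adam':
        optimizer = torch.optim.Adam(g[2], lr=lr, betas=(momentum, 0.999))
    elif name == 'SGD':
        optimizer = torch.optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
    else:
        raise NotImplementedError(f'Optimizer {name} not implemented.')

    optimizer.add_param_group({'params': g[0], 'weight_decay': decay})  # weights with decay
    optimizer.add_param_group({'params': g[1], 'weight_decay': 0.0})  # norm weights, no decay
    return optimizer

Routing biases and normalization parameters into zero-decay groups is standard practice: decaying them adds no useful regularization and is widely reported to hurt accuracy.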