[Docs]: Add customization tutorial and address feedback (#155)
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
c985eaba0d
commit
d387359f74
8 changed files with 133 additions and 74 deletions
|
|
@ -39,7 +39,7 @@ class BaseTrainer:
|
|||
"""
|
||||
BaseTrainer
|
||||
|
||||
A base class for creating trainers.
|
||||
> A base class for creating trainers.
|
||||
|
||||
Attributes:
|
||||
args (OmegaConf): Configuration for the trainer.
|
||||
|
|
@ -74,7 +74,7 @@ class BaseTrainer:
|
|||
|
||||
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
|
||||
"""
|
||||
Initializes the BaseTrainer class.
|
||||
> Initializes the BaseTrainer class.
|
||||
|
||||
Args:
|
||||
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
||||
|
|
@ -148,13 +148,13 @@ class BaseTrainer:
|
|||
|
||||
def add_callback(self, event: str, callback):
|
||||
"""
|
||||
Appends the given callback. TODO: unused, consider removing
|
||||
> Appends the given callback.
|
||||
"""
|
||||
self.callbacks[event].append(callback)
|
||||
|
||||
def set_callback(self, event: str, callback):
|
||||
"""
|
||||
Overrides the existing callbacks with the given callback. TODO: unused, consider removing
|
||||
> Overrides the existing callbacks with the given callback.
|
||||
"""
|
||||
self.callbacks[event] = [callback]
|
||||
|
||||
|
|
@ -185,7 +185,7 @@ class BaseTrainer:
|
|||
|
||||
def _setup_train(self, rank, world_size):
|
||||
"""
|
||||
Builds dataloaders and optimizer on correct rank process
|
||||
> Builds dataloaders and optimizer on correct rank process.
|
||||
"""
|
||||
# model
|
||||
self.run_callbacks("on_pretrain_routine_start")
|
||||
|
|
@ -373,13 +373,13 @@ class BaseTrainer:
|
|||
|
||||
def get_dataset(self, data):
|
||||
"""
|
||||
Get train, val path from data dict if it exists. Returns None if data format is not recognized
|
||||
> Get train, val path from data dict if it exists. Returns None if data format is not recognized.
|
||||
"""
|
||||
return data["train"], data.get("val") or data.get("test")
|
||||
|
||||
def setup_model(self):
|
||||
"""
|
||||
load/create/download model for any task
|
||||
> load/create/download model for any task.
|
||||
"""
|
||||
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
||||
return
|
||||
|
|
@ -405,15 +405,13 @@ class BaseTrainer:
|
|||
|
||||
def preprocess_batch(self, batch):
|
||||
"""
|
||||
Allows custom preprocessing model inputs and ground truths depending on task type
|
||||
> Allows custom preprocessing model inputs and ground truths depending on task type.
|
||||
"""
|
||||
return batch
|
||||
|
||||
def validate(self):
|
||||
"""
|
||||
Runs validation on test set using self.validator.
|
||||
# TODO: discuss validator class. Enforce that a validator metrics dict should contain
|
||||
"fitness" metric.
|
||||
> Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
|
||||
"""
|
||||
metrics = self.validator(self)
|
||||
fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
|
||||
|
|
@ -423,9 +421,11 @@ class BaseTrainer:
|
|||
|
||||
def log(self, text, rank=-1):
|
||||
"""
|
||||
Logs the given text to given ranks process if provided, otherwise logs to all ranks
|
||||
:param text: text to log
|
||||
:param rank: List[Int]
|
||||
> Logs the given text to given ranks process if provided, otherwise logs to all ranks.
|
||||
|
||||
Args"
|
||||
text (str): text to log
|
||||
rank (List[Int]): process rank
|
||||
|
||||
"""
|
||||
if rank in {-1, 0}:
|
||||
|
|
@ -439,13 +439,13 @@ class BaseTrainer:
|
|||
|
||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0):
|
||||
"""
|
||||
Returns dataloader derived from torch.data.Dataloader
|
||||
> Returns dataloader derived from torch.data.Dataloader.
|
||||
"""
|
||||
raise NotImplementedError("get_dataloader function not implemented in trainer")
|
||||
|
||||
def criterion(self, preds, batch):
|
||||
"""
|
||||
Returns loss and individual loss items as Tensor
|
||||
> Returns loss and individual loss items as Tensor.
|
||||
"""
|
||||
raise NotImplementedError("criterion function not implemented in trainer")
|
||||
|
||||
|
|
@ -531,7 +531,7 @@ class BaseTrainer:
|
|||
@staticmethod
|
||||
def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
|
||||
"""
|
||||
Builds an optimizer with the specified parameters and parameter groups.
|
||||
> Builds an optimizer with the specified parameters and parameter groups.
|
||||
|
||||
Args:
|
||||
model (nn.Module): model to optimize
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue