[Docs]: Add customization tutorial and address feedback (#155)

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Ayush Chaurasia 2023-01-08 18:31:22 +05:30 committed by GitHub
parent c985eaba0d
commit d387359f74
8 changed files with 133 additions and 74 deletions

@@ -39,7 +39,7 @@ class BaseTrainer:
"""
BaseTrainer
A base class for creating trainers.
> A base class for creating trainers.
Attributes:
args (OmegaConf): Configuration for the trainer.
@@ -74,7 +74,7 @@ class BaseTrainer:
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
"""
Initializes the BaseTrainer class.
> Initializes the BaseTrainer class.
Args:
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
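
As a usage sketch of this constructor (the import path and override keys below are assumptions, not part of this diff), a subclass can be built from DEFAULT_CONFIG with a small overrides dict:

# Sketch only: import path and override keys are assumed, not confirmed by this commit.
from ultralytics.yolo.engine.trainer import BaseTrainer

class CustomTrainer(BaseTrainer):
    pass  # abstract hooks such as get_dataloader/criterion are shown further down

trainer = CustomTrainer(overrides={"epochs": 3, "imgsz": 640})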
@@ -148,13 +148,13 @@ class BaseTrainer:
def add_callback(self, event: str, callback):
"""
Appends the given callback. TODO: unused, consider removing
> Appends the given callback.
"""
self.callbacks[event].append(callback)
def set_callback(self, event: str, callback):
"""
Overrides the existing callbacks with the given callback. TODO: unused, consider removing
> Overrides the existing callbacks with the given callback.
"""
self.callbacks[event] = [callback]
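
A minimal sketch contrasting the two methods, assuming a trainer instance like the one above; "on_pretrain_routine_start" is an event name that appears below in _setup_train:

def announce(trainer):
    print("pretrain routine starting")

trainer.add_callback("on_pretrain_routine_start", announce)  # appended alongside existing callbacks
trainer.set_callback("on_pretrain_routine_start", announce)  # replaces all callbacks for this event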
@@ -185,7 +185,7 @@ class BaseTrainer:
def _setup_train(self, rank, world_size):
"""
Builds dataloaders and optimizer on correct rank process
> Builds dataloaders and optimizer on correct rank process.
"""
# model
self.run_callbacks("on_pretrain_routine_start")
@@ -373,13 +373,13 @@ class BaseTrainer:
def get_dataset(self, data):
"""
Get train, val path from data dict if it exists. Returns None if data format is not recognized
> Get train, val path from data dict if it exists. Returns None if data format is not recognized.
"""
return data["train"], data.get("val") or data.get("test")
def setup_model(self):
"""
load/create/download model for any task
> load/create/download model for any task.
"""
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
return
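
For reference, get_dataset above only needs a dict with a "train" key and an optional "val" or "test" key; the paths in this sketch are illustrative:

data = {"train": "datasets/coco128/train", "val": "datasets/coco128/val"}  # illustrative paths
train_path, val_path = trainer.get_dataset(data)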
@@ -405,15 +405,13 @@ class BaseTrainer:
def preprocess_batch(self, batch):
"""
Allows custom preprocessing model inputs and ground truths depending on task type
> Allows custom preprocessing model inputs and ground truths depending on task type.
"""
return batch
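
A hedged sketch of a task-specific override of preprocess_batch; the "img" batch key and the uint8-to-float scaling are assumptions, not taken from this diff:

class CustomTrainer(BaseTrainer):
    def preprocess_batch(self, batch):
        # Assumed batch layout: move images to the trainer's device and scale to [0, 1].
        batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
        return batch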
def validate(self):
"""
Runs validation on test set using self.validator.
# TODO: discuss validator class. Enforce that a validator metrics dict should contain
"fitness" metric.
> Runs validation on test set using self.validator. The returned dict is expected to contain a "fitness" key.
"""
metrics = self.validator(self)
fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
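
As the rewritten docstring states, the dict returned by self.validator should carry a "fitness" key; otherwise the detached loss is used as the fitness measure. A sketch of a minimal validator callable (names and values are illustrative):

class DummyValidator:
    def __call__(self, trainer):
        # Illustrative metrics dict in the expected shape.
        return {"metrics/accuracy": 0.91, "fitness": 0.91}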
@@ -423,9 +421,11 @@ class BaseTrainer:
def log(self, text, rank=-1):
"""
Logs the given text to given ranks process if provided, otherwise logs to all ranks
:param text: text to log
:param rank: List[Int]
> Logs the given text to the given rank's process if provided, otherwise logs to all ranks.
Args:
text (str): text to log
rank (List[Int]): process rank
"""
if rank in {-1, 0}:
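
Usage sketch: only ranks -1 and 0 actually emit the message, so per-rank text is suppressed on other processes:

trainer.log("Starting training run")        # emitted on the main process
trainer.log("rank-local details", rank=2)   # suppressed, since rank is not -1 or 0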
@@ -439,13 +439,13 @@ class BaseTrainer:
def get_dataloader(self, dataset_path, batch_size=16, rank=0):
"""
Returns dataloader derived from torch.data.Dataloader
> Returns dataloader derived from torch.data.Dataloader.
"""
raise NotImplementedError("get_dataloader function not implemented in trainer")
def criterion(self, preds, batch):
"""
Returns loss and individual loss items as Tensor
> Returns loss and individual loss items as Tensor.
"""
raise NotImplementedError("criterion function not implemented in trainer")
@@ -531,7 +531,7 @@ class BaseTrainer:
@staticmethod
def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
"""
Builds an optimizer with the specified parameters and parameter groups.
> Builds an optimizer with the specified parameters and parameter groups.
Args:
model (nn.Module): model to optimize