General refactoring and improvements (#373)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
ac628c0d3e
commit
583eac0e80
18 changed files with 304 additions and 309 deletions
|
|
@ -40,7 +40,7 @@ class BaseTrainer:
|
|||
"""
|
||||
BaseTrainer
|
||||
|
||||
> A base class for creating trainers.
|
||||
A base class for creating trainers.
|
||||
|
||||
Attributes:
|
||||
args (OmegaConf): Configuration for the trainer.
|
||||
|
|
@ -75,7 +75,7 @@ class BaseTrainer:
|
|||
|
||||
def __init__(self, config=DEFAULT_CONFIG, overrides=None):
|
||||
"""
|
||||
> Initializes the BaseTrainer class.
|
||||
Initializes the BaseTrainer class.
|
||||
|
||||
Args:
|
||||
config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
|
||||
|
|
@ -149,13 +149,13 @@ class BaseTrainer:
|
|||
|
||||
def add_callback(self, event: str, callback):
|
||||
"""
|
||||
> Appends the given callback.
|
||||
Appends the given callback.
|
||||
"""
|
||||
self.callbacks[event].append(callback)
|
||||
|
||||
def set_callback(self, event: str, callback):
|
||||
"""
|
||||
> Overrides the existing callbacks with the given callback.
|
||||
Overrides the existing callbacks with the given callback.
|
||||
"""
|
||||
self.callbacks[event] = [callback]
|
||||
|
||||
|
|
@ -194,7 +194,7 @@ class BaseTrainer:
|
|||
|
||||
def _setup_train(self, rank, world_size):
|
||||
"""
|
||||
> Builds dataloaders and optimizer on correct rank process.
|
||||
Builds dataloaders and optimizer on correct rank process.
|
||||
"""
|
||||
# model
|
||||
self.run_callbacks("on_pretrain_routine_start")
|
||||
|
|
@ -383,13 +383,13 @@ class BaseTrainer:
|
|||
|
||||
def get_dataset(self, data):
|
||||
"""
|
||||
> Get train, val path from data dict if it exists. Returns None if data format is not recognized.
|
||||
Get train, val path from data dict if it exists. Returns None if data format is not recognized.
|
||||
"""
|
||||
return data["train"], data.get("val") or data.get("test")
|
||||
|
||||
def setup_model(self):
|
||||
"""
|
||||
> load/create/download model for any task.
|
||||
load/create/download model for any task.
|
||||
"""
|
||||
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
||||
return
|
||||
|
|
@ -415,13 +415,13 @@ class BaseTrainer:
|
|||
|
||||
def preprocess_batch(self, batch):
|
||||
"""
|
||||
> Allows custom preprocessing model inputs and ground truths depending on task type.
|
||||
Allows custom preprocessing model inputs and ground truths depending on task type.
|
||||
"""
|
||||
return batch
|
||||
|
||||
def validate(self):
|
||||
"""
|
||||
> Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
|
||||
Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
|
||||
"""
|
||||
metrics = self.validator(self)
|
||||
fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
|
||||
|
|
@ -431,7 +431,7 @@ class BaseTrainer:
|
|||
|
||||
def log(self, text, rank=-1):
|
||||
"""
|
||||
> Logs the given text to given ranks process if provided, otherwise logs to all ranks.
|
||||
Logs the given text to given ranks process if provided, otherwise logs to all ranks.
|
||||
|
||||
Args"
|
||||
text (str): text to log
|
||||
|
|
@ -449,13 +449,13 @@ class BaseTrainer:
|
|||
|
||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0):
|
||||
"""
|
||||
> Returns dataloader derived from torch.data.Dataloader.
|
||||
Returns dataloader derived from torch.data.Dataloader.
|
||||
"""
|
||||
raise NotImplementedError("get_dataloader function not implemented in trainer")
|
||||
|
||||
def criterion(self, preds, batch):
|
||||
"""
|
||||
> Returns loss and individual loss items as Tensor.
|
||||
Returns loss and individual loss items as Tensor.
|
||||
"""
|
||||
raise NotImplementedError("criterion function not implemented in trainer")
|
||||
|
||||
|
|
@ -543,7 +543,7 @@ class BaseTrainer:
|
|||
@staticmethod
|
||||
def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
|
||||
"""
|
||||
> Builds an optimizer with the specified parameters and parameter groups.
|
||||
Builds an optimizer with the specified parameters and parameter groups.
|
||||
|
||||
Args:
|
||||
model (nn.Module): model to optimize
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue