General refactoring and improvements (#373)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
commit 583eac0e80 (parent ac628c0d3e)
Glenn Jocher, 2023-01-15 14:44:25 +01:00, committed by GitHub
18 changed files with 304 additions and 309 deletions

@@ -40,7 +40,7 @@ class BaseTrainer:
     """
     BaseTrainer
-    > A base class for creating trainers.
+    A base class for creating trainers.

     Attributes:
         args (OmegaConf): Configuration for the trainer.
@@ -75,7 +75,7 @@ class BaseTrainer:
     def __init__(self, config=DEFAULT_CONFIG, overrides=None):
         """
-        > Initializes the BaseTrainer class.
+        Initializes the BaseTrainer class.

         Args:
             config (str, optional): Path to a configuration file. Defaults to DEFAULT_CONFIG.
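For context on the constructor contract above, a minimal usage sketch. It assumes a concrete subclass such as DetectionTrainer and illustrative override keys; neither the import path nor the key names are taken from this diff.

    # Hedged usage sketch: DEFAULT_CONFIG supplies baseline settings and
    # `overrides` patches individual keys. Subclass name and keys are
    # illustrative, not part of this commit.
    trainer = DetectionTrainer(overrides={"epochs": 3, "batch_size": 16})
    trainer.train()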
@@ -149,13 +149,13 @@ class BaseTrainer:
     def add_callback(self, event: str, callback):
         """
-        > Appends the given callback.
+        Appends the given callback.
         """
         self.callbacks[event].append(callback)

     def set_callback(self, event: str, callback):
         """
-        > Overrides the existing callbacks with the given callback.
+        Overrides the existing callbacks with the given callback.
         """
         self.callbacks[event] = [callback]
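The two registration methods above differ only in whether they extend or replace the handler list for an event. A minimal sketch, assuming a BaseTrainer subclass instance and using the event name that appears later in this diff:

    # Hedged sketch: `trainer` is an instance of a BaseTrainer subclass.
    def announce(trainer_obj):  # hypothetical callback
        print("pretrain routine starting")

    trainer.add_callback("on_pretrain_routine_start", announce)  # appended to existing handlers
    trainer.set_callback("on_pretrain_routine_start", announce)  # replaces all existing handlers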
@@ -194,7 +194,7 @@ class BaseTrainer:
     def _setup_train(self, rank, world_size):
         """
-        > Builds dataloaders and optimizer on correct rank process.
+        Builds dataloaders and optimizer on correct rank process.
         """
         # model
         self.run_callbacks("on_pretrain_routine_start")
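The "correct rank process" wording refers to the usual DDP split: every rank builds its own training dataloader, while the validation loader is built only on the main process. A sketch of that pattern; the attribute names and batch sizes are assumptions, as the method body is elided in this hunk.

    # Hedged sketch of rank-aware loader construction inside _setup_train.
    self.train_loader = self.get_dataloader(self.trainset, batch_size=batch_size, rank=rank)
    if rank in (-1, 0):  # main process, or a single-process run
        self.test_loader = self.get_dataloader(self.testset, batch_size=batch_size * 2, rank=-1)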
@@ -383,13 +383,13 @@ class BaseTrainer:
     def get_dataset(self, data):
         """
-        > Get train, val path from data dict if it exists. Returns None if data format is not recognized.
+        Get train, val path from data dict if it exists. Returns None if data format is not recognized.
         """
         return data["train"], data.get("val") or data.get("test")

     def setup_model(self):
         """
-        > load/create/download model for any task.
+        load/create/download model for any task.
         """
         if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
             return
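get_dataset() consumes a plain dict; note the fallback from "val" to "test" in the return expression. A self-contained sketch with illustrative paths:

    # The "val" key falls back to "test" via data.get("val") or data.get("test").
    data = {"train": "images/train", "test": "images/test"}  # paths are illustrative
    train_path, val_path = data["train"], data.get("val") or data.get("test")
    assert val_path == "images/test"  # "test" is used because "val" is absent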
@@ -415,13 +415,13 @@ class BaseTrainer:
     def preprocess_batch(self, batch):
         """
-        > Allows custom preprocessing model inputs and ground truths depending on task type.
+        Allows custom preprocessing of model inputs and ground truths depending on task type.
         """
         return batch

     def validate(self):
         """
-        > Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
+        Runs validation on test set using self.validator. The returned dict is expected to contain "fitness" key.
         """
         metrics = self.validator(self)
         fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy())  # use loss as fitness measure if not found
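The fitness fallback in validate() is worth spelling out: when the validator's metrics dict lacks a "fitness" key, the negated training loss stands in, so lower loss still ranks as higher fitness. A small self-contained sketch of the dict.pop() fallback:

    import numpy as np

    metrics = {"mAP50": 0.6}                 # no "fitness" key returned by the validator
    loss = np.float32(0.42)                  # stand-in for self.loss.detach().cpu().numpy()
    fitness = metrics.pop("fitness", -loss)  # falls back to negated loss
    print(fitness)                           # -0.42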
@@ -431,7 +431,7 @@ class BaseTrainer:
     def log(self, text, rank=-1):
         """
-        > Logs the given text to given ranks process if provided, otherwise logs to all ranks.
+        Logs the given text to the given rank's process if provided, otherwise logs to all ranks.

         Args:
             text (str): text to log
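A usage sketch for the rank-filtered logger described above; the messages are illustrative, and the rank semantics follow the docstring rather than the elided method body.

    # Hedged sketch: rank=-1 conventionally means "not running under DDP".
    self.log("Starting training...")           # default: all ranks
    self.log("main-process-only message", rank=0)  # per the docstring, only rank 0 logs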
@@ -449,13 +449,13 @@ class BaseTrainer:
     def get_dataloader(self, dataset_path, batch_size=16, rank=0):
         """
-        > Returns dataloader derived from torch.data.Dataloader.
+        Returns dataloader derived from torch.data.Dataloader.
         """
         raise NotImplementedError("get_dataloader function not implemented in trainer")

     def criterion(self, preds, batch):
         """
-        > Returns loss and individual loss items as Tensor.
+        Returns loss and individual loss items as Tensor.
         """
         raise NotImplementedError("criterion function not implemented in trainer")
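get_dataloader() and criterion() are the two hooks a task-specific subclass must supply. A hedged sketch under assumed shapes; the dataset contents and the loss expression are placeholders, not the commit's detection implementation.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    class SketchTrainer(BaseTrainer):  # BaseTrainer as defined in this file
        def get_dataloader(self, dataset_path, batch_size=16, rank=0):
            # Placeholder data; a real trainer would load from dataset_path.
            dataset = TensorDataset(torch.randn(64, 3, 32, 32), torch.randn(64, 10))
            return DataLoader(dataset, batch_size=batch_size, shuffle=True)

        def criterion(self, preds, batch):
            # Placeholder loss; returns (total loss, detached loss items) as Tensors.
            loss = torch.nn.functional.mse_loss(preds, batch[1])
            return loss, loss.detach()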
@@ -543,7 +543,7 @@ class BaseTrainer:
     @staticmethod
     def build_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
         """
-        > Builds an optimizer with the specified parameters and parameter groups.
+        Builds an optimizer with the specified parameters and parameter groups.

         Args:
             model (nn.Module): model to optimize
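The "parameter groups" in the docstring follow the common YOLO convention of exempting biases and normalization weights from weight decay. A sketch of that grouping for the Adam case; the exact group logic of the commit may differ.

    import torch
    import torch.nn as nn

    def build_optimizer_sketch(model, lr=0.001, momentum=0.9, decay=1e-5):
        # Split parameters: decay applies to conv/linear weights only.
        no_decay, with_decay = [], []
        for module in model.modules():
            if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
                no_decay.append(module.bias)
            if isinstance(module, (nn.BatchNorm2d, nn.LayerNorm)):
                no_decay.append(module.weight)  # norm weights: no decay
            elif hasattr(module, "weight") and isinstance(module.weight, nn.Parameter):
                with_decay.append(module.weight)
        optimizer = torch.optim.Adam(no_decay, lr=lr, betas=(momentum, 0.999))
        optimizer.add_param_group({"params": with_decay, "weight_decay": decay})
        return optimizer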