Metrics and loss structure (#28)
Co-authored-by: Ayush Chaurasia <ayush.chuararsia@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
d0b3c9812b
commit
c5cb76b356
12 changed files with 183 additions and 43 deletions
|
|
@ -28,20 +28,11 @@ DEFAULT_CONFIG = "defaults.yaml"
|
|||
|
||||
class BaseTrainer:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
data: str,
|
||||
criterion, # Should we create our own base loss classes? yolo.losses -> v8.losses.clfLoss
|
||||
validator=None,
|
||||
config=CONFIG_PATH_ABS / DEFAULT_CONFIG):
|
||||
def __init__(self, config=CONFIG_PATH_ABS / DEFAULT_CONFIG):
|
||||
self.console = LOGGER
|
||||
self.model = model
|
||||
self.data = data
|
||||
self.criterion = criterion # ComputeLoss object TODO: create yolo.Loss classes
|
||||
self.validator = val # Dummy validator
|
||||
self.model, self.data, self.train, self.hyps = self._get_config(config)
|
||||
self.validator = None
|
||||
self.callbacks = defaultdict(list)
|
||||
self.train, self.hyps = self._get_config(config)
|
||||
self.console.info(f"Training config: \n train: \n {self.train} \n hyps: \n {self.hyps}") # to debug
|
||||
# Directories
|
||||
self.save_dir = utils.increment_path(Path(self.train.project) / self.train.name, exist_ok=self.train.exist_ok)
|
||||
|
|
@ -57,7 +48,7 @@ class BaseTrainer:
|
|||
self.console.info(f"running on device {self.device}")
|
||||
self.scaler = amp.GradScaler(enabled=self.device.type != 'cpu')
|
||||
|
||||
# Model and Dataloaders. TBD: Should we move this inside trainer?
|
||||
# Model and Dataloaders.
|
||||
self.trainset, self.testset = self.get_dataset() # initialize dataset before as nc is needed for model
|
||||
self.model = self.get_model()
|
||||
self.model = self.model.to(self.device)
|
||||
|
|
@ -80,9 +71,9 @@ class BaseTrainer:
|
|||
try:
|
||||
if isinstance(config, (str, Path)):
|
||||
config = OmegaConf.load(config)
|
||||
return config.train, config.hyps
|
||||
return config.model, config.data, config.train, config.hyps
|
||||
except KeyError as e:
|
||||
raise Exception("Missing key(s) in config") from e
|
||||
raise KeyError("Missing key(s) in config") from e
|
||||
|
||||
def add_callback(self, onevent: str, callback):
|
||||
"""
|
||||
|
|
@ -131,10 +122,9 @@ class BaseTrainer:
|
|||
self.train_loader = self.get_dataloader(self.trainset, batch_size=self.train.batch_size, rank=rank)
|
||||
if rank in {0, -1}:
|
||||
print(" Creating testloader rank :", rank)
|
||||
# self.test_loader = self.get_dataloader(self.testset,
|
||||
# batch_size=self.train.batch_size*2,
|
||||
# rank=rank)
|
||||
# print("created testloader :", rank)
|
||||
self.test_loader = self.get_dataloader(self.testset, batch_size=self.train.batch_size * 2, rank=rank)
|
||||
self.validator = self.get_validator()
|
||||
print("created testloader :", rank)
|
||||
|
||||
def _do_train(self, rank, world_size):
|
||||
if world_size > 1:
|
||||
|
|
@ -235,11 +225,8 @@ class BaseTrainer:
|
|||
"""
|
||||
pass
|
||||
|
||||
def set_criterion(self, criterion):
|
||||
"""
|
||||
:param criterion: yolo.Loss object.
|
||||
"""
|
||||
self.criterion = criterion
|
||||
def get_validator(self):
|
||||
pass
|
||||
|
||||
def optimizer_step(self):
|
||||
self.scaler.unscale_(self.optimizer) # unscale gradients
|
||||
|
|
@ -265,6 +252,12 @@ class BaseTrainer:
|
|||
if not self.best_fitness or self.best_fitness < self.fitness:
|
||||
self.best_fitness = self.fitness
|
||||
|
||||
def build_targets(self, preds, targets):
|
||||
pass
|
||||
|
||||
def criterion(self, preds, targets):
|
||||
pass
|
||||
|
||||
def progress_string(self):
|
||||
"""
|
||||
Returns progress string depending on task type.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue