Segmentation support & other enchancements (#40)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
This commit is contained in:
parent
c617ee1c79
commit
f56c9bcc26
17 changed files with 1320 additions and 47 deletions
|
|
@ -1,12 +1,17 @@
|
|||
"""
|
||||
Simple training loop; Boilerplate that could apply to any arbitrary neural network,
|
||||
"""
|
||||
# TODOs
|
||||
# 1. finish _set_model_attributes
|
||||
# 2. allow num_class update for both pretrained and csv_loaded models
|
||||
# 3. save
|
||||
|
||||
import os
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from telnetlib import TLS
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
|
|
@ -52,6 +57,8 @@ class BaseTrainer:
|
|||
|
||||
# Model and Dataloaders.
|
||||
self.trainset, self.testset = self.get_dataset(self.args.data)
|
||||
if self.args.cfg is not None:
|
||||
self.model = self.load_cfg(self.args.cfg)
|
||||
if self.args.model is not None:
|
||||
self.model = self.get_model(self.args.model, self.args.pretrained).to(self.device)
|
||||
|
||||
|
|
@ -133,6 +140,20 @@ class BaseTrainer:
|
|||
self.test_loader = self.get_dataloader(self.testset, batch_size=self.args.batch_size * 2, rank=rank)
|
||||
self.validator = self.get_validator()
|
||||
print("created testloader :", rank)
|
||||
self.console.info(self.progress_string())
|
||||
|
||||
def _set_model_attributes(self):
|
||||
# TODO: fix and use after self.data_dict is available
|
||||
'''
|
||||
head = utils.torch_utils.de_parallel(self.model).model[-1]
|
||||
self.args.box *= 3 / head.nl # scale to layers
|
||||
self.args.cls *= head.nc / 80 * 3 / head.nl # scale to classes and layers
|
||||
self.args.obj *= (self.args.img_size / 640) ** 2 * 3 / nl # scale to image size and layers
|
||||
model.nc = nc # attach number of classes to model
|
||||
model.hyp = hyp # attach hyperparameters to model
|
||||
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc # attach class weights
|
||||
model.names = names
|
||||
'''
|
||||
|
||||
def _do_train(self, rank, world_size):
|
||||
if world_size > 1:
|
||||
|
|
@ -153,13 +174,17 @@ class BaseTrainer:
|
|||
pbar = tqdm(enumerate(self.train_loader),
|
||||
total=len(self.train_loader),
|
||||
bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}')
|
||||
tloss = 0
|
||||
for i, (images, labels) in pbar:
|
||||
tloss = None
|
||||
for i, batch in pbar:
|
||||
# img, label (classification)/ img, targets, paths, _, masks(detection)
|
||||
# callback hook. on_batch_start
|
||||
# forward
|
||||
images, labels = self.preprocess_batch(images, labels)
|
||||
self.loss = self.criterion(self.model(images), labels)
|
||||
tloss = (tloss * i + self.loss.item()) / (i + 1)
|
||||
batch = self.preprocess_batch(batch)
|
||||
|
||||
# TODO: warmup, multiscale
|
||||
preds = self.model(batch["img"])
|
||||
self.loss, self.loss_items = self.criterion(preds, batch)
|
||||
tloss = (tloss * i + self.loss_items) / (i + 1) if tloss is not None else self.loss_items
|
||||
|
||||
# backward
|
||||
self.model.zero_grad(set_to_none=True)
|
||||
|
|
@ -170,9 +195,13 @@ class BaseTrainer:
|
|||
self.trigger_callbacks('on_batch_end')
|
||||
|
||||
# log
|
||||
mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
|
||||
mem = (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
|
||||
loss_len = tloss.shape[0] if len(tloss.size()) else 1
|
||||
losses = tloss if loss_len > 1 else torch.unsqueeze(tloss, 0)
|
||||
if rank in {-1, 0}:
|
||||
pbar.desc = f"{f'{epoch + 1}/{self.args.epochs}':>10}{mem:>10}{tloss:>12.3g}" + ' ' * 36
|
||||
pbar.set_description(
|
||||
(" {} " + "{:.3f} " * (2 + loss_len)).format(f'{epoch + 1}/{self.args.epochs}', mem, *losses,
|
||||
batch["img"].shape[-1]))
|
||||
|
||||
if rank in [-1, 0]:
|
||||
# validation
|
||||
|
|
@ -240,6 +269,9 @@ class BaseTrainer:
|
|||
|
||||
return model
|
||||
|
||||
def load_cfg(self, cfg):
|
||||
raise NotImplementedError("This task trainer doesn't support loading cfg files")
|
||||
|
||||
def get_validator(self):
|
||||
pass
|
||||
|
||||
|
|
@ -250,11 +282,11 @@ class BaseTrainer:
|
|||
self.scaler.update()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
def preprocess_batch(self, images, labels):
|
||||
def preprocess_batch(self, batch):
|
||||
"""
|
||||
Allows custom preprocessing model inputs and ground truths depending on task type
|
||||
"""
|
||||
return images.to(self.device, non_blocking=True), labels.to(self.device)
|
||||
return batch
|
||||
|
||||
def validate(self):
|
||||
"""
|
||||
|
|
@ -270,14 +302,17 @@ class BaseTrainer:
|
|||
def build_targets(self, preds, targets):
|
||||
pass
|
||||
|
||||
def criterion(self, preds, targets):
|
||||
def criterion(self, preds, batch):
|
||||
"""
|
||||
Returns loss and individual loss items as Tensor
|
||||
"""
|
||||
pass
|
||||
|
||||
def progress_string(self):
|
||||
"""
|
||||
Returns progress string depending on task type.
|
||||
"""
|
||||
pass
|
||||
return ''
|
||||
|
||||
def usage_help(self):
|
||||
"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue