update segment training (#57)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: ayush chaurasia <ayush.chaurarsia@gmail.com>
This commit is contained in:
parent
d0b0fe2592
commit
3a241e4cea
14 changed files with 460 additions and 144 deletions
|
|
@ -9,30 +9,18 @@ from ultralytics.yolo.engine.trainer import DEFAULT_CONFIG, BaseTrainer
|
|||
from ultralytics.yolo.utils.metrics import FocalLoss, bbox_iou, smooth_BCE
|
||||
from ultralytics.yolo.utils.modeling.tasks import SegmentationModel
|
||||
from ultralytics.yolo.utils.ops import crop_mask, xywh2xyxy
|
||||
from ultralytics.yolo.utils.plotting import plot_images_and_masks, plot_results_with_masks
|
||||
from ultralytics.yolo.utils.torch_utils import de_parallel
|
||||
|
||||
|
||||
# BaseTrainer python usage
|
||||
class SegmentationTrainer(BaseTrainer):
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size, rank=0):
|
||||
def get_dataloader(self, dataset_path, batch_size, mode="train", rank=0):
|
||||
# TODO: manage splits differently
|
||||
# calculate stride - check if model is initialized
|
||||
gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
|
||||
return build_dataloader(
|
||||
img_path=dataset_path,
|
||||
img_size=self.args.img_size,
|
||||
batch_size=batch_size,
|
||||
single_cls=self.args.single_cls,
|
||||
cache=self.args.cache,
|
||||
image_weights=self.args.image_weights,
|
||||
stride=gs,
|
||||
rect=self.args.rect,
|
||||
rank=rank,
|
||||
workers=self.args.workers,
|
||||
shuffle=self.args.shuffle,
|
||||
use_segments=True,
|
||||
)[0]
|
||||
return build_dataloader(self.args, batch_size, img_path=dataset_path, stride=gs, rank=rank, mode=mode)[0]
|
||||
|
||||
def preprocess_batch(self, batch):
|
||||
batch["img"] = batch["img"].to(self.device, non_blocking=True).float() / 255
|
||||
|
|
@ -58,7 +46,10 @@ class SegmentationTrainer(BaseTrainer):
|
|||
self.model.names = self.data["names"]
|
||||
|
||||
def get_validator(self):
|
||||
return v8.segment.SegmentationValidator(self.test_loader, self.device, logger=self.console)
|
||||
return v8.segment.SegmentationValidator(self.test_loader,
|
||||
save_dir=self.save_dir,
|
||||
logger=self.console,
|
||||
args=self.args)
|
||||
|
||||
def criterion(self, preds, batch):
|
||||
head = de_parallel(self.model).model[-1]
|
||||
|
|
@ -218,6 +209,8 @@ class SegmentationTrainer(BaseTrainer):
|
|||
else:
|
||||
mask_gti = masks[tidxs[i]][j]
|
||||
lseg += single_mask_loss(mask_gti, pmask[j], proto[bi], mxyxy[j], marea[j])
|
||||
else:
|
||||
lseg += (proto * 0).sum()
|
||||
|
||||
obji = BCEobj(pi[..., 4], tobj)
|
||||
lobj += obji * balance[i] # obj loss
|
||||
|
|
@ -234,15 +227,33 @@ class SegmentationTrainer(BaseTrainer):
|
|||
loss = lbox + lobj + lcls + lseg
|
||||
return loss * bs, torch.cat((lbox, lseg, lobj, lcls)).detach()
|
||||
|
||||
def label_loss_items(self, loss_items):
|
||||
def label_loss_items(self, loss_items=None, prefix="train"):
|
||||
# We should just use named tensors here in future
|
||||
keys = ["lbox", "lseg", "lobj", "lcls"]
|
||||
return dict(zip(keys, loss_items))
|
||||
keys = [f"{prefix}/lbox", f"{prefix}/lseg", f"{prefix}/lobj", f"{prefix}/lcls"]
|
||||
return dict(zip(keys, loss_items)) if loss_items is not None else keys
|
||||
|
||||
def progress_string(self):
|
||||
return ('\n' + '%11s' * 7) % \
|
||||
('Epoch', 'GPU_mem', 'box_loss', 'seg_loss', 'obj_loss', 'cls_loss', 'Size')
|
||||
|
||||
def plot_training_samples(self, batch, ni):
|
||||
images = batch["img"]
|
||||
masks = batch["masks"]
|
||||
cls = batch["cls"].squeeze(-1)
|
||||
bboxes = batch["bboxes"]
|
||||
paths = batch["im_file"]
|
||||
batch_idx = batch["batch_idx"]
|
||||
plot_images_and_masks(images,
|
||||
batch_idx,
|
||||
cls,
|
||||
bboxes,
|
||||
masks,
|
||||
paths,
|
||||
fname=self.save_dir / f"train_batch{ni}.jpg")
|
||||
|
||||
def plot_metrics(self):
|
||||
plot_results_with_masks(file=self.csv) # save results.png
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path=DEFAULT_CONFIG.parent, config_name=DEFAULT_CONFIG.name)
|
||||
def train(cfg):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue