ultralytics 8.0.196 instance-mean Segment loss (#5285)
Co-authored-by: Andy <39454881+yermandy@users.noreply.github.com>
This commit is contained in:
parent
7517667a33
commit
e7f0658744
72 changed files with 369 additions and 493 deletions
|
|
@@ -212,7 +212,6 @@ class v8SegmentationLoss(v8DetectionLoss):
|
|||
|
||||
def __init__(self, model):  # model must be de-paralleled
    """Initialize the segmentation loss, caching head settings needed by `__call__`."""
    super().__init__(model)
    # Number of mask coefficients predicted per anchor, read from the Segment head (last module).
    self.nm = model.model[-1].nm
    # Whether ground-truth instance masks are stored index-coded in a single overlapped channel.
    self.overlap = model.args.overlap_mask
|
||||
def __call__(self, preds, batch):
|
||||
|
|
@@ -268,38 +267,108 @@ class v8SegmentationLoss(v8DetectionLoss):
|
|||
if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
|
||||
masks = F.interpolate(masks[None], (mask_h, mask_w), mode='nearest')[0]
|
||||
|
||||
for i in range(batch_size):
|
||||
if fg_mask[i].sum():
|
||||
mask_idx = target_gt_idx[i][fg_mask[i]]
|
||||
if self.overlap:
|
||||
gt_mask = torch.where(masks[[i]] == (mask_idx + 1).view(-1, 1, 1), 1.0, 0.0)
|
||||
else:
|
||||
gt_mask = masks[batch_idx.view(-1) == i][mask_idx]
|
||||
xyxyn = target_bboxes[i][fg_mask[i]] / imgsz[[1, 0, 1, 0]]
|
||||
marea = xyxy2xywh(xyxyn)[:, 2:].prod(1)
|
||||
mxyxy = xyxyn * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=self.device)
|
||||
loss[1] += self.single_mask_loss(gt_mask, pred_masks[i][fg_mask[i]], proto[i], mxyxy, marea) # seg
|
||||
|
||||
# WARNING: lines below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
|
||||
else:
|
||||
loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss
|
||||
loss[1] = self.calculate_segmentation_loss(fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto,
|
||||
pred_masks, imgsz, self.overlap)
|
||||
|
||||
# WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
|
||||
else:
|
||||
loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss
|
||||
|
||||
loss[0] *= self.hyp.box # box gain
|
||||
loss[1] *= self.hyp.box / batch_size # seg gain
|
||||
loss[1] *= self.hyp.box # seg gain
|
||||
loss[2] *= self.hyp.cls # cls gain
|
||||
loss[3] *= self.hyp.dfl # dfl gain
|
||||
|
||||
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
|
||||
|
||||
@staticmethod
def single_mask_loss(gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor,
                     area: torch.Tensor) -> torch.Tensor:
    """
    Compute the instance segmentation loss for a single image.

    Args:
        gt_mask (torch.Tensor): Ground truth masks of shape (n, H, W), where n is the number of objects.
        pred (torch.Tensor): Predicted mask coefficients of shape (n, 32).
        proto (torch.Tensor): Prototype masks of shape (32, H, W).
        xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (n, 4).
        area (torch.Tensor): Area of each ground truth bounding box of shape (n,).

    Returns:
        (torch.Tensor): Scalar mask loss summed over instances; the caller normalizes by the total
            number of foreground anchors in the batch.

    Notes:
        Predicted masks are assembled as pred_mask = torch.einsum('in,nhw->ihw', pred, proto), i.e. a
        per-instance linear combination of the prototype masks.
    """
    # Combine prototypes with each instance's coefficients: (n, 32) x (32, H, W) -> (n, H, W) logits.
    assembled = torch.einsum('in,nhw->ihw', pred, proto)
    per_pixel = F.binary_cross_entropy_with_logits(assembled, gt_mask, reduction='none')
    # Restrict loss to each instance's box, average over its pixels, then normalize by box area so
    # small objects are not dominated by large ones.
    per_instance = crop_mask(per_pixel, xyxy).mean(dim=(1, 2)) / area
    return per_instance.sum()
|
||||
def calculate_segmentation_loss(
    self,
    fg_mask: torch.Tensor,
    masks: torch.Tensor,
    target_gt_idx: torch.Tensor,
    target_bboxes: torch.Tensor,
    batch_idx: torch.Tensor,
    proto: torch.Tensor,
    pred_masks: torch.Tensor,
    imgsz: torch.Tensor,
    overlap: bool,
) -> torch.Tensor:
    """
    Calculate the instance segmentation loss over a whole batch.

    Args:
        fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.
        masks (torch.Tensor): Ground truth masks of shape (BS, H, W) if `overlap` is False, otherwise (BS, ?, H, W).
        target_gt_idx (torch.Tensor): Indexes of ground truth objects for each anchor of shape (BS, N_anchors).
        target_bboxes (torch.Tensor): Ground truth bounding boxes for each anchor of shape (BS, N_anchors, 4).
        batch_idx (torch.Tensor): Batch indices of shape (N_labels_in_batch, 1).
        proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).
        pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).
        imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).
        overlap (bool): Whether the masks in `masks` tensor overlap (index-coded in one channel).

    Returns:
        (torch.Tensor): The calculated loss, normalized by the total number of foreground anchors.

    Notes:
        The batch loss could be computed in one shot for speed at higher memory cost, e.g.
        pred_mask = torch.einsum('in,nhw->ihw', pred, proto) over the whole batch at once.
    """
    _, _, mask_h, mask_w = proto.shape
    total = 0

    # Boxes normalized to [0, 1] relative to the input image; imgsz is (H, W), boxes are xyxy.
    boxes_norm = target_bboxes / imgsz[[1, 0, 1, 0]]

    # Normalized area of each target box, used to weight per-instance losses.
    areas = xyxy2xywh(boxes_norm)[..., 2:].prod(2)

    # Boxes rescaled to the prototype mask resolution.
    boxes_mask_space = boxes_norm * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=proto.device)

    per_image = zip(fg_mask, target_gt_idx, pred_masks, proto, boxes_mask_space, areas, masks)
    for i, (fg_i, gt_idx_i, coef_i, proto_i, box_i, area_i, masks_i) in enumerate(per_image):
        if fg_i.any():
            obj_idx = gt_idx_i[fg_i]
            if overlap:
                # Index-coded channel: pixel value k+1 marks object k for this image.
                gt_mask = (masks_i == (obj_idx + 1).view(-1, 1, 1)).float()
            else:
                gt_mask = masks[batch_idx.view(-1) == i][obj_idx]

            total += self.single_mask_loss(gt_mask, coef_i[fg_i], proto_i, box_i[fg_i], area_i[fg_i])

        # WARNING: the line below prevents Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
        else:
            total += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss

    # Average over all foreground anchors in the batch (callers only invoke this when fg_mask has positives).
    return total / fg_mask.sum()
|
||||
|
||||
|
||||
class v8PoseLoss(v8DetectionLoss):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue